目录
论文地址
ResNet :https://arxiv.org/pdf/1512.03385.pdf
ResNeXt: https://arxiv.org/abs/1611.05431
ResNet基本思想
在训练深层网络时,一般会遇到三个问题:
- 过拟合。这一点其实很好理解,因为训练的损失函数是训练集上的loss,当训练数据比较少,网络参数比较多的时候很容易为了减小训练集loss而过拟合。但过拟合问题不在本文的讨论范围之内,这里不做赘述。
- 梯度消失/爆炸。我们考虑一个简单的三层网络,
f
(
x
)
f(x)
f(x)为激活函数,
x
x
x为输入,
f
i
(
x
)
f_i(x)
fi(x)表示第
i
i
i层输出结果。那么首先,显然有
f
i
+
1
=
f
(
ω
i
⋅
f
i
+
b
)
∂
f
i
+
1
=
ω
i
⋅
∂
f
i
⋅
f
′
f_{i+1}=f(\omega_i\cdot f_i +b)\\ \partial{f_{i+1}}=\omega_i \cdot \partial{f_i}\cdot f^{'}
fi+1=f(ωi⋅fi+b)∂fi+1=ωi⋅∂fi⋅f′这里不妨令
b
=
0
b=0
b=0,那么根据梯度下降的原理,我们会对参数
ω
\omega
ω进行更新:
ω
i
:
=
ω
i
−
α
⋅
∂
L
∂
ω
i
\omega_i:=\omega_i-\alpha\cdot \dfrac{\partial{L}}{\partial{\omega_i}}
ωi:=ωi−α⋅∂ωi∂L以
ω
1
\omega_1
ω1为例,由链式法则我们有
∂
L
∂
ω
1
=
∂
L
∂
f
3
⋅
∂
f
3
∂
f
2
⋅
∂
f
2
∂
f
1
⋅
∂
f
1
∂
ω
1
=
∂
L
∂
f
3
⋅
∂
f
3
∂
f
2
⋅
ω
3
⋅
ω
3
⋅
∂
f
2
∂
f
1
⋅
ω
2
⋅
ω
2
⋅
∂
f
1
∂
ω
1
x
⋅
x
=
∂
L
∂
f
3
⋅
f
2
′
⋅
f
1
′
⋅
f
0
′
⋅
ω
3
⋅
ω
2
⋅
x
\begin{aligned} \dfrac{\partial{L}}{\partial\omega_1}&=\dfrac{\partial{L}}{\partial{f_3}}\cdot \dfrac{\partial{f_3}}{\partial{f_2}}\cdot \dfrac{\partial{f_2}}{\partial{f_1}}\cdot \dfrac{\partial{f_1}}{\partial{\omega_1}}\\ &=\dfrac{\partial{L}}{\partial{f_3}}\cdot \dfrac{\partial{f_3}}{\partial{f_2\cdot \omega_3}}\cdot \omega_3 \cdot \dfrac{\partial{f_2}}{\partial{f_1\cdot \omega_2}}\cdot \omega_2 \cdot \dfrac{\partial{f_1}}{\partial{\omega_1x}}\cdot x\\ &=\dfrac{\partial{L}}{\partial{f_3}}\cdot f^{'}_2\cdot f^{'}_1\cdot f^{'}_0 \cdot \omega_3 \cdot \omega_2\cdot x \end{aligned}
∂ω1∂L=∂f3∂L⋅∂f2∂f3⋅∂f1∂f2⋅∂ω1∂f1=∂f3∂L⋅∂f2⋅ω3∂f3⋅ω3⋅∂f1⋅ω2∂f2⋅ω2⋅∂ω1x∂f1⋅x=∂f3∂L⋅f2′⋅f1′⋅f0′⋅ω3⋅ω2⋅x
可以看到,最终反向传播的梯度主要由激活函数各节点处的导数 f i ′ f_i^{'} fi′和权重系数 ω i \omega_i ωi的连乘组成。因此,当网络层数很深,梯度就会因为这些连乘指数型增长/下降,从而导致梯度爆炸/消失。但是,这一问题已经在很多方面得到了一定解决,就拿激活函数来说,如果用ReLU作为激活函数,在大于0的部分导数恒为1,那么至少在激活函数导数连乘这一部分上就不会有很大的值;同时,ResNet的论文中也提到,批归一化(BatchNorm)也很大程度上解决了梯度爆炸/消失的问题,因为在反向传播的过程中,BN中的 σ \sigma σ项对权重系数作了一定的放缩,解决了权重部分的梯度爆炸/消失。 - 网络退化。这也是ResNet作者认为主要想解决的一个问题。论文中首先在同一数据集下做了一组关于网络层数的对照试验:
发现当网络层数增加后,训练集和测试集上的loss反而增加了,这在当时是很反直觉的。因为对于56层的网络,你只要让后36层都是恒等映射至少也可以达到和20层网络一样的效果,为什么效果反而会变差呢?但实际上,由于非线性激活函数的存在,想要令其成为恒等映射是几乎不可能的,那也就是说,一旦特征经过了试图成为恒等映射的层,都会对特征的信息进行一定损耗,从而导致深层的网络效果变差。
在知道上述问题后,作者给出了十分经典的Residual Block来解决退化问题:
从我们的直观理解上,Res模块最主要的一个操作就是他的一个跳跃连接,让输入进入两个分支,分支A如同其他层一样进行特征提取,分支B直接让原始输入与提取后的特征相加。这样一来,网络只要令
F
(
x
)
=
0
F(x)=0
F(x)=0就可以令Res模块达到恒等映射的效果。而这里的相加操作,本质上因为卷积是保加法的线性映射,等同于在后续层中用相同的卷积核参数对两个相加的特征进行操作。
作者在后文中使用Res模块设计了各种层数的ResNet网络(如下图),并对其效果进行测试。
其中,ResNet34和ResNet50的残差模块的数量一致,但是结构不同,ResNet50使用
n
n
n个
1
×
1
1\times1
1×1的卷积核接
n
n
n个
3
×
3
3\times3
3×3的卷积核再接上
4
n
4n
4n个
1
×
1
1\times1
1×1的卷积核替代两个级联的
n
n
n个
3
×
3
3\times3
3×3的卷积核。从结果上来看,ResNet50的表现优于ResNet34;从参数量上来看,我们不妨假设输入维数相同并忽略,则ResNet34中一个残差模块的参数量为
(
3
×
3
×
n
)
×
2
=
18
n
(3\times3\times n)\times2=18n
(3×3×n)×2=18n,ResNet50中一个残差模块的参数量为
(
1
×
1
×
n
+
3
×
3
×
n
+
1
×
1
×
4
n
)
=
14
n
(1\times1\times n +3\times 3 \times n+1\times1\times 4n)=14n
(1×1×n+3×3×n+1×1×4n)=14n。那为什么ResNet50的总参数量反而更高一点呢?实际上是因为在跳跃链接(ShortCut)时,有些层会遇到维数不对应的情况,这里的解决办法是通过
1
×
1
1\times1
1×1的卷积核扩增维数或是补0,因此由于残差模块的最后将通道数翻了四倍,ResNet50跳跃连接所需的参数量增加了很多,导致ResNet50总参数量比ResNet34略高一点。
ResNeXt基本思想
ResNeXt用分组的思想对ResNet(50)的残差模块进行了改进,如下图所示:
可以看到,ResNext50中将通道数翻倍,但是将卷积拆分成了32组,这样一来,我们不妨假设输入的长宽相同并忽略,那么图中模块的参数量可以计算为
(
256
×
1
×
1
×
4
)
×
32
+
(
4
×
3
×
3
×
4
)
×
32
+
(
4
×
1
×
1
×
256
)
×
32
=
53760
(256\times1\times1\times 4)\times 32+(4\times 3\times3\times4)\times 32+(4\times1\times1\times 256)\times32=53760
(256×1×1×4)×32+(4×3×3×4)×32+(4×1×1×256)×32=53760,而ResNet50中模块的参数量可计算为
(
256
×
1
×
1
×
64
)
+
(
64
×
3
×
3
×
64
)
+
(
64
×
1
×
1
×
256
)
=
69632
(256\times1\times1\times64)+(64\times3\times3\times64)+(64\times1\times1\times256)=69632
(256×1×1×64)+(64×3×3×64)+(64×1×1×256)=69632。因此可以看到下图中ResNeXt的参数量是比ResNet50的还要小的
同时,论文中还给出了ResNeXt残差模块的几种等价形式(如下图)
作者在提到,这三种结构是严格等价的,论文中实验采用的是第三种,因为运行的效率比较高。我们可以看到,第三种结构也说明了ResNeXt对于ResNet本质上的改进,实际就是通过中间的分组卷积减少了计算量,使得相比于ResNet通道数翻倍的情况下,计算量反而更小。同时,按照我的理解,分组卷积还能起到一定正则的作用,因为如果不对特征通道进行分组,全部的参数都用于训练一种过滤方式,如果参数过多而能提取到的特征又不够复杂的话很容易过拟合我。而进行分组之后,每个group希望学习到不同的特征(这一点在alexnet的实验中有印证),而对于每个group来说,参数量又比较小,不容易过拟合,因此才会有论文中分组数越多结果表现越好的情况,我个人认为是本身参数量对于实验的数据集就已经足够了,分组起到了一定的正则作用。