ResNet从理论到实践(一)ResNet原理

1.在ResNet出现之前

在2015年ResNet出现之前,CNN的结构大多如下图所示,通俗点说,用“卷积-maxpooling-ReLU”的方式,一卷到底,最后使用全连接层完成分类任务。

ResNet从理论到实践(一)ResNet原理

大家普遍认为,卷积神经网络的深度对于网络的性能起着至关重要的作用,所以普遍将网络深度从AlexNet的几层增加到十几层甚至更多,比如VGG16、VGG19,也正如人们所想,增加深度确实增加了模型的性能。

但深度继续增加时,网络的性能逐渐趋于饱和,甚至性能会出现随网络深度急剧下降的现象,一个有力的实验证明如下图所示:

ResNet从理论到实践(一)ResNet原理

上图为在CIFAR-10数据集上训练20层网络和56层网络的对比图,图中横轴为迭代次数,单位为 1 × 1 0 4 1 \times 10^4 1×104,左图纵轴为训练误差(train error),右图纵轴为测试误差(test error)。

左图表示随着迭代的进行,在训练集上完成分类任务的错误率变化;右图表示随着迭代的进行,在测试集上完成分类任务的错误率变化。

从上面2个图可以看出,无论是在训练集还是测试集,更深的网络(56层)错误率总是要高于浅层的网络(20层),即更深的网络性能更差。

ResNet的巧妙之处就是在卷积神经网络中加入“shortcut connection”,使得深层网络拥有比浅层网络更好的性能。

下面我们从残差学习开始讲起,逐渐引入“shortcut connection”的概念。

2.残差学习(Residual Learning)

考虑CNN中的一个小的网络块,网络块的输入为 x \mathbf{x} x,输出为 H ( x ) \mathcal{H}(\mathbf{x}) H(x),即这个网络块完成了非线性映射 H \mathcal{H} H。

作者设计了一个新的网络块,这个网络块的作用不是将输入特征 x \mathbf{x} x映射为 H ( x ) \mathcal{H}(\mathbf{x}) H(x),而是将其映射为 H ( x ) − x \mathcal{H}(\mathbf{x}) - \mathbf{x} H(x)−x,将 H ( x ) − x \mathcal{H}(\mathbf{x}) - \mathbf{x} H(x)−x记作 F ( x ) \mathcal{F}(\mathbf{x}) F(x),即该网络块计算原映射 H ( x ) \mathcal{H}(\mathbf{x}) H(x)与输入特征 x \mathbf{x} x的差值 F ( x ) \mathcal{F}(\mathbf{x}) F(x),也可以称作“使得该网络块学习原网络块 H ( x ) \mathcal{H}(\mathbf{x}) H(x)与输入特征 x \mathbf{x} x的残差”。如下图所示:

ResNet从理论到实践(一)ResNet原理

上图中由2个“weight layer”叠加而成的网络块用来完成从输入特征 x \mathbf{x} x到残差 F ( x ) \mathcal{F}(\mathbf{x}) F(x)的映射,将 F ( x ) \mathcal{F}(\mathbf{x}) F(x)与上图中右侧的shortcut connection 进行元素级别的加法操作,最终该网络块的输出为 F ( x ) + x \mathcal{F}(\mathbf{x})+\mathbf{x} F(x)+x,即上图中整个模块的输出仍为 H ( x ) \mathcal{H}(\mathbf{x}) H(x)。

简而言之,不再让“weight layer”输出最终的feature map,而是让“weight layer”输出最终feature map和输入特征的差值 F ( x ) \mathcal{F}(\mathbf{x}) F(x),将 F ( x ) \mathcal{F}(\mathbf{x}) F(x)与输入特征 x \mathbf{x} x进行元素级别加法操作得到最终的feature map。

上图中的结构可以表示为如下公式:

y = F ( x , { W i } ) + x \mathbf{y}=\mathcal{F}\left(\mathbf{x},\left\{W_{i}\right\}\right)+\mathbf{x} y=F(x,{Wi​})+x

上式中的 F = W 2 σ ( W 1 x ) \mathcal{F}=W_{2} \sigma\left(W_{1} \mathbf{x}\right) F=W2​σ(W1​x), σ \sigma σ表示ReLU操作。在完成 F ( x ) \mathcal{F}(\mathbf{x}) F(x)和 x \mathbf{x} x的元素加法后,上图中还进行了ReLU操作,即上图中结构的最终输出为 σ ( y ) \sigma(\mathbf{y}) σ(y)。

从上述操作可看出,shortcut connection的引入只在原来的基础上增加了元素加法操作,并没有引入大量的额外计算和可学习参数。

考虑一种情况:若在计算残差 F ( x ) \mathcal{F}(\mathbf{x}) F(x)时,除了使用卷积层还用了pooling操作,会使得 F ( x ) \mathcal{F}(\mathbf{x}) F(x)的尺寸与输入特征 x \mathbf{x} x的尺寸不一致,从而无法进行元素加法操作。为了解决这个问题,当 F ( x ) \mathcal{F}(\mathbf{x}) F(x)的尺寸与输入特征 x \mathbf{x} x的尺寸不一致时,使用如下公式代替上文中的公式:

y = F ( x , { W i } ) + W s x \mathbf{y}=\mathcal{F}\left(\mathbf{x},\left\{W_{i}\right\}\right)+W_{s} \mathbf{x} y=F(x,{Wi​})+Ws​x

即对输入特征 x \mathbf{x} x使用操作 W s W_{s} Ws​,使得 W s x W_{s} \mathbf{x} Ws​x的特征尺寸与 F ( x ) \mathcal{F}(\mathbf{x}) F(x)一致。在实际使用时, W s x W_{s} \mathbf{x} Ws​x表示在输入特征 x \mathbf{x} x上做步长为2的 1 × 1 1 \times 1 1×1卷积,使得其输出特征在尺寸和通道数上与 F ( x ) \mathcal{F}(\mathbf{x}) F(x)一致。

需要注意的是,上图中shortcut connection跨接了2个weight layer,实际在使用时,shortcut connection可以灵活跨接多个卷积层。

3.ResNet网络结构

基于上文中提到的结构,作者构建出5个不同结构的卷积神经网络用于ImageNet数据集分类,并根据它们的深度将它们分别命名为ResNet-18、ResNet-34、ResNet-50、ResNet-101、ResNet-152,这些网络的结构如下表所示:

ResNet从理论到实践(一)ResNet原理

在上表所示的5个网络中,conv3_1、conv4_1、conv5_1这3个层使用步长为2的卷积层实现下采样功能。

上面5个网络在结构上可以分为2类:一类为ResNet-18和ResNet-34,它们的基本组件为下图中左边的结构;另一类为ResNet-50、ResNet-101和ResNet152,它们的基本组件为下图中右边的结构:

ResNet从理论到实践(一)ResNet原理

上图中右侧的结构先使用 1 × 1 1 \times 1 1×1卷积降低特征通道数,使用 3 × 3 3 \times 3 3×3卷积完成特征提取,然后再使用 1 × 1 1 \times 1 1×1卷积增加特征通道数,在深层的网络使用这种结构可以减少计算量。

4.实验结果

4.1 CIFAR-10数据集

为便于分析使用普通卷积模块和使用残差模块构建的网络的性能差异,作者在CIFAR-10数据集上训练不同深度的网络,并比较它们的性能,结果如下图所示:

ResNet从理论到实践(一)ResNet原理

上图中左图为不同深度的传统卷积神经网络,右图为不同深度的残差卷积神经网络。横轴为迭代次数,单位为 1 × 1 0 4 1 \times 10^4 1×104,虚线表示训练误差,实线表示测试误差。

左图中,随着网络深度的增加,误差变大,性能变差;右图中,随着网络深度的增加,误差变小,性能变强。

因此,使用ResNet系列网络,能够解决增加网络深度导致的性能变差问题,使更深层的卷积层能够提取更高层次的特征,从而使得整个网络有更高的性能。

4.2 ImageNet数据集

在ImageNet训练集上训练ResNet网络,在验证集上测试,结果如下表所示:

ResNet从理论到实践(一)ResNet原理

从上表中可以看出,ResNet系列网络相比之前的VGG,有更好的性能,且深层网络的性能优于浅层网络。

5.总结

ResNet网络通过使用残差模块,解决了增加CNN深度引起的性能退化问题。通过构建更深的卷积神经网络,提取更高层次的特征,使其拥有更高的性能。

ResNet的出现极大地推进了卷积神经网络在计算机视觉各领域的应用,在多种计算机视觉任务中都可以看到ResNet的身影。

如果你对计算机视觉中的目标检测、跟踪、分割、轻量化神经网络感兴趣,欢迎关注公众号一起学习交流~
ResNet从理论到实践(一)ResNet原理

上一篇:机器学习笔记:LS、Ridge、Lasso、最小一乘法的选择过程推导


下一篇:线性代数学习笔记