U-Net++原作者博客总结笔记
U-Net++笔记
参考:
本文是根据原论文作者在某乎上的完美演讲进行的一些总结:
原文作者大牛的思路讲解:https://zhuanlan.zhihu.com/p/44958351(研究思路十分值得学习和借鉴)
要解决的问题:
原本U-Net的结构引出的问题
原本的结构用了4次下采样,但不得不问,四次下采样的结果是最好的吗?
答案是否定的,或者说,在U-Net作者所进行的task上,它表现的不错。
但换一个task呢?
所以这个下采样的次数事实上是和task有关的,或者说,和你需要多大的感受野有关,和你的task、dataset直接相关。
下采样有什么作用?
“它可以增加对输入图像的一些小扰动的鲁棒性,比如图像平移,旋转等,减少过拟合的风险,降低运算量,和增加感受野的大小。”
而且,在U-Net的结构中,是经过了3次卷积之后才进行了下采样,所以,4次下采样其实已经是比较深层的全局信息了,而前面1次、2次下采样可能还是比较浅层的局部信息,但这两个信息是同样重要的。
对于一些task,可能只需要1到2次的上采样就可以了,对于一些task,可能需要3次,4次甚至更多次,才有比较好的编码,最终得到比较好的模型。
如何去解决这些问题
你要去确定多少次下采样比较适合你的task,就需要尝试,通过数据来说话。
即你需要尝试以上四种情况,但这显然是不科学的,或者是不高效的。
去训练这四种网络,分别测试性能,这样的时间成本和资源成本是很高的
那么有什么办法呢?怎么在原本U-Net结构中加上这些test来构成一个完整的网络呢?
填空!!!
可以看到,U-Net网络结构中间是采用长连接的,中间的部分是空出来的,而这些1次,2次,3次下采样的网络结构恰可以填补这样的空缺,结构如下:
这样的结构就实现了4种下采样次数的综合,但仍存在一个问题:
图中红色区域,这些区域是没有连接到loss的,所以在反向传播的时候,误差是不会反向传播到这些区域的,也就是,这些区域是无法训练的!
如何去解决呢?
短连接替换长连接
即这样的拓扑结构,原本两个大V对应深度的长连接,被这样一根根的短连接所取代,这样的话,所有节点都连接到loss,就可以进行训练啦!
但仍存在一个问题:短连接完全替代长连接,这好吗?或者说,长连接的作用是什么?
事实上,我们再看一下原本U-Net的结构:
这个长连接其实就类似与ResNet残差神经网络里的residual block残差单元:
即实现了:x+f(x)的操作(这个f就是长连接的起点x到终点中间经历所有层次,抽象成了一个映射)
所以我认为它的作用和residual block的作用是相似的:
①防止梯度消失 ②一定程度上解决了网络深度给训练带来的困难 ③将浅层的局部信息和深层的全局信息做了一个综合,用浅层的信息来恢复一些细节性的信息
所以,长连接是十分有必要的,不应该被彻底删除
所以如何解决这个问题呢?
长连接、短连接并存
第一眼看到这个结构,我非常激动,为什么呢?
因为我前段时间看了DenseNet,那篇2017年的神论文
横向的看这个网络,每一行都是一个DenseNet的Densely Connection
原本的U-Net类似于ResNet
而现在的U-Net++则类似于DenseNet
作者后面也提到,U-Net++比U-Net在分割上的优越可能可以类比DenseNet比ResNet在分类上的优越
作者做到这里也发现了,他们和另一路人在研究同一个东西,对,没错,撞车啦!
但作者也有更优越的地方,将在后面介绍。
测试
U-Net++相对于原本的U-Net有一个很明显的问题:参数量变多了
所以作者将U-Net++和增加过Channels数目(更宽大的)U-Net进行性能的比较
证明了,参数多不一定就好,参数,要用在刀刃上!
而U-Net++就用在了刀刃上!(自夸,夸得好哇)
高潮(完整的U-Net++出炉!):增加选择性(prune剪枝) (引入deep supervision)
前面说过,虽然作者和另一批人马撞车了,但他有更加优越的地方!
就是增加了选择性,即在测试时,你可以选择用1次下采样?2次下采样?3次下采样?还是4次。
这里被称为剪枝(prune)
先看一下网络结构吧:
可以看到,最上层都加入了1x1conv,并联在一起连接到最后的output
这样就使得,这四种情况都有可能是最后的输出(我认为与YOLOv3的三种grid是相似的思想)
这个1x1 conv后直接连接loss,则可以认为,1x1的conv监控着每一个level(level指下采样次数),即监控每一个U-Net分支的输出。
下面就要谈prune剪枝了。
那么就有三个问题:
①为什么U-Net++可以被剪枝
②如何剪枝
③剪枝的优势
三个问题的解答以及在什么时候剪枝
图中最右侧的四次下采样的“大U型”被剪枝剪掉了
我们来观察剪枝剪掉的部分:
①测试阶段:测试阶段神经网络只有前向传播,没有反向传播,我们可以看到,剪枝部分对其他3种U型的前向传播是没有影响的,即Level4(下面简称L4)对L1、L2、L3的输出是没有影响的。
②训练阶段:训练阶段神经网络既有前向传播,也有反向传播,那我们来看看反向传播,剪枝部分L4对L1,L2,L3的反向传播是有贡献的,即L4可以帮助L1,L2,L3进行训练!
这意味着什么?剪枝的L4不会影响L1,L2,L3的测试输出,却可以对它们的训练产生贡献!
所以剪枝是在测试的时候用的,训练时我自然不愿意拿掉这一贡献,还让我测试时候少了一种测试的结果。
怎么决定剪枝剪多少
接下来就主要是引用知乎文章的内容啦,因为也没什么好总结的啦嘿嘿:
“
如何去决定剪多少,还是比较好回答的。因为在训练模型的时候会把数据分为训练集,验证集和测试集,训练集上是一定拟合的很好的,测试集是我们不能碰的,所以我们会根据子网络在验证集的结果来决定剪多少。所谓的验证集就是一开始从训练集中分出来的数据,用来监测训练过程用的。
先看看L1~L4的网络参数量,差了好多,L1只有0.1M,而L4有9M,也就是理论上如果L1的结果我是满意的,那么模型可以被剪掉的参数达到98.8%。不过根据我们的四个数据集,L1的效果并不会那么好,因为太浅了嘛。但是其中有三个数据集显示L2的结果和L4已经非常接近了,也就是说对于这三个数据集,在测试阶段,我们不需要用9M的网络,用半M的网络足够了。
我们统计了用不同的模型,1秒钟可以分割多少的图。如果用L2来代替L4的话,速度确实能提升三倍。
剪枝应用最多的就是在移动手机端了,根据模型的参数量,如果L2得到的效果和L4相近,模型的内存可以省18倍。还是非常可观的数目。
”
总结
个人认为U-Net++对U-Net的改观有两个方面比较牛掰!
①将长连接改为长短连接并存(ResNet到DenseNet的转变),以及加入了deep supervision使得输出可监控,这也是后面最关键创新————剪枝的结构基础
②剪枝使得原本刻板的网络结构在测试时变得十分灵活,使得网络在测试时的可变性大大增加,flexibility大大增加(很好的利用了不用层级的特征)
作者的简单总结
“简单的总结一下,UNet++的第一个优势就是精度的提升,这个应该它整合了不同层次的特征所带来的,第二个是灵活的网络结构配合深监督,让参数量巨大的深度网络在可接受的精度范围内大幅度的缩减参数量。”