文章目录
论文主要工作
通过遗传算法进行神经网络架构搜索,论文首先提出了一种能表示神经网络架构的编码方案,在此编码方案上初始化种群,对种群进行选择、变异、交叉,从而抛弃性能差的神经网络架构并产生新的神经网络架构,论文将训练好的架构在验证集上的准确率作为评判网络性能好坏的指标
本博客首先介绍论文中提到的方法,接着阐明搜索空间,最后总结论文实验结果。
遗传算法简介
传统的遗传算法往往具有下列步骤
- 定义个体的基因编码方案
- 初始化种群
- 衡量个体生存竞争能力的适应度(通常是一个函数,函数值表示个体的生存竞争能力)
- 淘汰适应度低的个体,选择适应度高的个体构成种群下一代的成员(选择)
- 按一定概率对下一代成员进行基因的交叉与变异(交叉与变异),产生新个体的基因编码方案
- 评估新种群的适应度
可以看到,遗传算法其实就是模仿生物进化的过程
将遗传算法用于NAS
神经网络架构的基因编码方案
许多常见的state-of-the-art神经网络架构都可以分为几个stage,两个stage之间通过池化操作连接,同个stage内的所有卷积操作都有相同的过滤器数目,基于上述观察,论文得出了神经网络架构的基因编码方案,具体如下:
- 每个神经网络架构由S个阶段构成,阶段与阶段之间通过池化操作连接
- 每个阶段都有Ks个节点组成,这Ks个节点其实是卷积+BN+ReLU的操作,Ks个节点编号并由小到大顺序排序
- 使用有向边连接一个阶段中的节点,每个节点只能连接比自己编号要大的节点,使用Vs,ks表示第s阶段的第ks个节点,其中ks=1,2,…,Ks
在每个阶段,我们使用1+2+3+…+(Ks-1)=21Ks(Ks−1)位来编码阶段内部节点之间的有向边,s阶段编码的第一位表示节点Vs,1和节点Vs,2之间是否有Vs,1指向Vs,2的边连接,第二位与第三位表示V_{s,1}和节点V_{s,3}、V_{s,2}和节点V_{s,3}$之间是否有有向边连接,1表示有有向边连接,0表示没有,以此类推,有边连接表示前一节点输出的特征图会作为后一连接节点的输入,可能会有多个边指向同一个节点,此时element-wise相加,此时是一个残差结构(由于位于一个阶段内部,因此特征图的channel数一致,只需要padding即可相加),一个具体的例子如下:
为了让每一个编码均有效,论文在每个阶段中额外定义了两个默认节点,对于第s个阶段来说,这两个节点分别为Vs,0和Vs,Ks+1,分别为上图中的红色和绿色节点,红色节点接收前一节点的输出,对其进行卷积,并将输出的特征图送往所有没有前驱但是有后继的节点,而绿色节点接受所有有前驱但是没后继节点输出的特征图,对其进行element-wise相加后,输入到池化层,这两个节点不做编码,如果一个阶段的编码为0,则红色节点直接和绿色节点连接,即经过以此卷积操作后立刻进行池化操作。
上述编码方案可以编码一些知名网络架构,因此在搜索空间中,也包含知名的手工设计的网络架构,例如:
算法的超参数
S:阶段个数
Ks(s=1,2,…,S):第s阶段的节点个数
通过定义上述参数,对应的所有编码方案就构成了神经网络架构的搜索空间
初始化
对于每个阶段的编码,通过从伯努利分布(B(0.5))中进行采样,依据编码初始化神经网络架构,对其进行训练,将验证集上的准确率作为适应度。
但是论文发现遗传算法对不同的初始化策略并不敏感,即使所有的个体的编码都是0,遗传算法仍然能够发现有竞争力的架构。
选择
论文使用Russian roulette(具体查看轮盘赌算法)来选择下一代种群的成员,适应度为成员在验证集上的准确率,由于选择过程需要保证种群的大小与上一代一致,所以同一个成员可能被选择多次,导致下一代种群中有多个相同的网络架构,通过选择操作,可以淘汰准确率低的网络架构
变异与交叉
变异
变异操作的单位为基因编码的每一位。
每个个体发生变异的概率为pM。如果发生变异,每一个个体的编码的每一位有qM概率发生翻转(1变为0,0变为1)
交叉
交叉操作的单位为一个阶段的编码,如果发生交叉,那么交换的片段是一个阶段的编码。
种群中的每一对个体(由两个神经网络架构组成)发生交叉的概率是pC。如果发生交叉,两个个体相同阶段发生交叉的概率为qC
评估新种群的适应度
对产生的新种群中每个模型进行重新训练,对于旧模型,采用其历史准确率的平均值作为适应度,对于新模型,采用其准确率作为适应度,这么做是为了对抗训练噪声,即有些模型之所以有较高的准确率,是因为“幸运”,而不是这类模型的架构优秀
算法的流程
实验及其结果
该论文也是选择在小数据集上探索网络架构,在将探索到的网络架构应用于大数据集,相应的结果不在此呈现,论文通过在CIFRA10上运行算法,通过运行两次算法,得到的两个架构如下:
上述网络在ILSVRC2012上训练的结果如下:
虽然发现的网络架构不及某些手工设计的网络架构,但至少证明该算法发现的网络架构具有较好的性能