小样本学习记录————MAML++
在MAML中,我们了解了经典的元学习框架MAML,虽然MAML训练快效果好,但是仍然存在一些问题,在How to train your MAML中介绍了MAML的缺点和改进方法。
MAML简单回顾
在看MAML存在问题前先简单回顾一下MAML的内循环和外循环更新公式,如果需要详细了解可以看我的上一篇文章。https://blog.csdn.net/yunlong_G/article/details/121858650 。
内循环:对于一个子任务的b的第i次更新表示为:
θ i b = θ i − 1 b − α ∇ θ L S b ( f θ i − 1 b ) \theta_i^b=\theta_{i-1}^b-\alpha\nabla_\theta \mathcal L_{S_b}(f_{\theta_{i-1}^b}) θib=θi−1b−α∇θLSb(fθi−1b)
外循环:对于完成第N个任务后更新参数表示为:
θ 0 = θ 0 − β ∇ θ ∑ b = 1 B L T b ( f θ N b ( θ 0 ) ) \theta_0 = \theta_0 - \beta\nabla_\theta \sum_{b=1}^B \mathcal L_{T_b}(f_{\theta_N^b(\theta_0)}) θ0=θ0−β∇θb=1∑BLTb(fθNb(θ0))
MAML存在的问题
从上图我们可以看到MAML在训练过程中不稳定,存在泛化能力弱的问题,以下分别介绍作者认为的几点原因。
训练不稳定:
模型没有跳跃连接的话,通过多个卷积层,梯度多次被相同的参数相乘,容易出现梯度爆炸和梯度消失。
二阶导数代价
MAML仅使用了一阶梯度,虽然迭代加快但是泛化能力不稳定。经尝试了使用一阶方法的进一步尝试(Nicholl等人,2018年),其中作者在基础模型上应用标准SGD,然后在N步之后从他们的初始化参数向基础模型的参数迈进一步。Reptile的结果各不相同,有的超过MAML,有的低于MAML。
缺少批量归一化统计累计
泛化能力弱的另一原因:使用当前批次的统计信息进行批次归一化,而不是累积运行统计信息。这导致批量归一化效率较低.因为学习的基数必须适应各种不同的均值和标准偏差,而不是单一的均值和标准偏差。另一方面,如果批量归一化使用累积的运行统计数据,它最终将收敛到某个全局平均值和标准差。
共享(跨步骤)批次标准化偏差
在每一个子任务的训练过程中偏差一致,认为所有的子任务都是一样的特征分布,但实际上不会全是一样的分布。
共享内环(跨步和跨参数)学习速率
使用了一个共享的学习率给所有的参数和更新步骤,这导致了两个问题
- 下降方向不一定是特定数据集的方向
- 搜索方法可能花费的时间长。
有为每一个网络单独设置学习率和更新方向,但是又计算工作量和存储器问题。
国定外循环学习率
使用了固定的外循环循环率也会降低MAML算法的泛化性能,优化变慢的原因。
对MAML的改进
梯度不稳定性→多步损耗优化(MSL)
将外循环的更新由原来完成内循环再进行更新 变为——》内循环每进行一步就计算损失,利用每一步损失的加权和更新。
θ = θ − β ∇ θ ∑ b = 1 B ∑ i = 0 N v i L T b ( f θ i b ) \theta = \theta - \beta\nabla_\theta \sum_{b=1}^B \sum_{i=0}^N v_i\mathcal L_{T_b}(f_{\theta_i^b}) θ=θ−β∇θb=1∑Bi=0∑NviLTb(fθib)
- β \beta β表示学习率
- B任务总数
- N内循环更新步数
- v i v_i vi表示第i步的重要程度
- L \mathcal L L表示损失函数
二阶导数代价→导数顺序退火(DA)
MAML中为了提高计算效率使用一阶导数来进行整个训练,这是影响模型泛化的主要原因之一。MAML++作者通过实验发现前五十epoch使用一阶导数,然后再转为使用二阶导数可以得到不错的效果,而且没有梯度爆炸和梯度消失,相比只是用二阶导数更稳定。
缺少批次归一化统计累积→每步批次归一化运行统计(BNRS)
按当前批次的统计信息标准化 变为 ——》按步骤收集统计数据。为了每一步收集运行统计数据,需要实例化网络中每个批归一化层的运行均值和运行标准差集合N组(其中N是内环更新步骤的总数),并根据优化过程中采取的步骤分别更新运行统计数据。每一步的批归一化方法应该可以加速MAML的优化,同时潜在地提高泛化性能。
共享(跨步骤)批次归一化偏差→每步批次归一化权重和偏差(BNWB)
所有子任务共用一个偏差 变为——》内循环更新过程中每一步学习一组偏差。
共享内环学习速率(跨步和跨参数)逐层→学习逐步学习率和梯度方向(LSLR)
一样的学习率 变为——》 为网络中的每一层学习一个学习速率和方向,以及在执行步骤时学习基础网络的每个自适应的不同的学习速率。学习每个层的学习速率和方向,而不是每个参数,应该会减少所需的内存和计算量,同时在更新步骤中提供额外的灵活性。此外,对于学习到的每个学习率,将有N个该学习率的实例,每个步骤对应一个实例。通过这样做,参数可以*地学习降低每一步的学习率,这可能有助于减轻过拟合。
固定外环学习率→余弦退火元优化器学习率(CA)
事实证明:余弦函数(LoshChilov&Hutter,2016)退火学习率在具有更高泛化能力的学习模型中是至关重要的。MAML++将余弦退火应用到元学习中,来代替有原来的固定学习率。
结果比对
- MSL可用于提升MAML的稳定性以及泛化性能,但会降低算法速率以及增加计算资源消耗。
- DA可以增加MAML训练效率以及缓解梯度爆炸、消失问题。
- LSLR可以减少Meta-SGD给MAML算法带来增加的存储消耗问题,同时可以有一个可学习的内循环学习率 α \alpha α。
- CA可以为MAML带来动态调整的外更新学习率 β \beta β,可以帮助算法缓解次陷入次优解的问题。
- BNWB+BNRS可以提升MAML的泛化能力。
从下表中可以看出,上文提出的几种措施都对模型效果有了提高。MAML++在mini-imagenet数据集上内部一次循环就别其他方法都好了。