代码:
github.com/cbfinn/maml
github.com/cbfinn/maml_rl
Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
Abstract
我们提出了一种与模型无关的元学习算法,因为它与任何用梯度下降训练过的模型兼容,适用于各种不同的学习问题,包括分类、回归和强化学习。元学习的目标是在各种学习任务上训练一个模型,这样它就可以用少量的训练样本解决新的学习任务。在我们的方法中,模型的参数被明确地训练,这样一个新的任务只用少量的梯度步长和少量的训练数据就能产生良好的泛化性能。实际上,我们的方法使模型易于微调。我们证明了该方法在两个few-shot图像分类(即每类的图像比较少)基准上取得了最先进的性能,在few-shot回归上产生了良好的结果,并加速了神经网络策略梯度增强学习的微调。
1. Introduction
快速学习是人类智能的一个标志,无论是从几个例子中识别物体,还是在短短几分钟的经历后快速学习新技能。我们的人工代理应该能够做同样的事情,即从少数例子中快速学习和适应,并随着越来越多的数据可用而能够继续适应。这种快速而灵活的学习具有挑战性,因为代理必须将其以往的经验与少量的新信息相结合,同时避免对新数据的过拟合。此外,以往经验和新数据的形式将取决于任务。因此,为了获得最大的适用性,学习(或元学习)的机制应该对任务和完成任务所需的计算形式具有普遍性。
在这项工作中,我们提出了一个通用的和模型无关的元学习算法,因为它可以直接应用于任何学习问题和用梯度下降进程训练的模型。我们的重点是深度神经网络模型,但我们说明了我们的方法是如何以最小的修改来轻松处理不同的架构和不同的问题设置的,包括分类,回归,和政策梯度强化学习。在元学习中,经过训练的模型的目标是从少量的新数据中快速学习新的任务,经过元学习者的训练,模型能够对大量不同的任务进行学习。我们的方法的关键思想是训练模型的初始参数,这样当参数使用新任务的少量数据计算的一个或多个梯度步骤更新后,模型在新任务上有最优的性能。不同于之前学习更新函数或学习规则的元学习方法(Schmidhuber, 1987; Bengio et al.1992; Andrychowicz et al.2016; Ravi & Larochelle, 2017),我们的算法不扩大学习参数的数量,在模型架构上也没有设置限制(例如,要求周期模型(Santoro et al ., 2016)或Siamese网络(Koch, 2015)),它很容易与全连接层,卷积或循环神经网络结合。它也可以用于各种损失函数,包括可微监督损失和不可微强化学习目标。
从特征学习的角度来看,训练一个模型的参数以使几个梯度步骤,甚至是单个梯度步骤在一个新任务上产生好的结果的过程,可以看作是建立一个广泛适用于许多任务的内部表征。如果内部表征适用于许多任务,只需稍微微调参数(例如,在前馈模型中主要修改顶层权重)就可以产生良好的结果。实际上,我们的程序优化了易于快速调整的模型,允许适应发生在快速学习的正确空间。从一个动态系统的观点来看,我们的学习过程可以被看作是最大化新任务的损失函数对参数的敏感性:当敏感性很高时,对参数的微小局部变化可以导致任务损失的巨大改善。
这项工作的主要贡献是一个用来训练模型的参数的简单的模型和任务不确定的元学习算法,这样少量的梯度更新将导致对新任务的快速学习。我们演示了不同模型类型的算法,包括全连接和卷积网络,并在几个不同的领域应用,包括few-shot回归,图像分类,和强化学习。我们评估表明,元学习算法与最先进的专门用于监督分类的one-shot学习方法对比,其使用了更少的参数,也可以应用于回归并且在任务变化时可以加速强化学习的任务可变性,大大优于作为初始化直接预训练。
2. Model-Agnostic Meta-Learning
我们的目标是训练能够实现快速适应的模型,这是一种经常被形式化为few-shot学习的问题设置。在本节中,我们将定义问题设置并给出算法的一般形式。
2.1. Meta-Learning Problem Set-Up
少量元学习的目标是训练一个模型,使其能够使用少量的数据点和训练迭代数来快速适应新的任务。为此,模型或学习者在元学习阶段对一组任务进行训练,这样经过训练的模型只需使用少量的例子或试验就能快速适应新的任务。实际上,元学习问题把整个任务当作训练样本。在本节中,我们以一般的方式将元学习问题设置形式化,包括不同学习领域的简单例子。我们将在第3节详细讨论两个不同的学习领域。
我们考虑了一个模型, 记为f,它将观察值x映射到输出值a。在元学习过程中,该模型被训练成能够适应大量或无限数量的任务。由于我们想将我们的框架应用于各种各样的学习问题,从分类到强化学习,我们在下面介绍一个学习任务的一般概念。形式上来说,每个任务包含一个损失函数L、一个初始观测值q(X1)的分布、一个转换分布q(Xt+1 |Xt, at)以及一个eisode长度H。在监督学习问题中,长度H=1。模型可能在每个时间t通过选择输出at生成长度H的样本。该损失提供特定任务的反馈,其可能是误分类损失或马尔科夫决策过程中的代价函数的形式。
在元学习的场景中,我们考虑一个我们希望我们的模型能够适应的任务的分布(即这里面有多个不同的任务,用来训练模型)。在K-shot学习设置中,模型被训练去学习一个来自任务分布的新任务,该新任务中的训练数据仅有K个来自qi的样本,其反馈由生成。在元训练期间,一个任务从中采样得到,模型使用K个样本训练,并从来自的对应损失中得到反馈,然后在来自的新样本(即非前面使用的K个样本的别的样本)中进行测试。然后通过考虑在来自qi的新数据上的测试error如何根据参数进行改变来改善模型f。实际上,在采样的任务上的测试error将作为元学习过程中的训练error。在元训练的结尾,从中采样新任务,从K个样本中学习后,元性能将用模型的性能来测量。通常,在元训练期间,用于元测试的任务会被搁置(即在训练时不会被使用)。
2.2. A Model-Agnostic Meta-Learning Algorithm
与之前的研究相反,之前的研究试图训练能摄取整个数据集的递归神经网络(Santoro et al., 2016;Duan et al.,2016b)或特征嵌入,在测试时可与非参数方法结合(Vinyals et al.,2016;(Koch, 2015),我们提出了一种方法,可以通过元学习学习任何标准模型的参数,从而为模型的快速适应做好准备。这种方法背后的直觉是一些内部表征比其他的更容易转移。例如,神经网络可以学习广泛适用于中所有任务的内部特征,而不是单个任务。我们如何鼓励这种通用表征的出现?我们对这个问题采取一个明确的方法:由于该模型将在一个新的任务上使用一个基于梯度的学习规则来进行微调,我们将致力于以基于这种梯度学习规则可以快速优化来自的新任务的这样一种方式来学习模型,并没有过拟合的出现。实际上,我们的目标是找到对任务中变化敏感的模型参数,这样当变化的方向在损失的梯度方向上时,小的参数的变化将在来自的任何任务的损失函数上产生大的改进(参见图1)。
我们没有对模型的形式做任何假设,除了假设它是由一些参数向量Θ参数化的,并且损失函数在Θ中是足够平滑的,这样我们可以使用基于梯度的学习技术。
形式上,我们考虑使用一个带有参数Θ的参数化函数来表示模型。当适应了一个新任务,模型的参数将更新为。在我们的方法中,更新的参数向量是使用一个或多个在任务上的梯度下降更新方法计算得到的。比如,当使用一个梯度更新时如下式子所示:
步长α可以被固定为一个超参数或者是可进行元学习得到。为了简化概念,我们在下面的section中都将只考虑一次梯度更新,但是使用多个梯度更新时一种最直接的扩展方式
模型参数通过优化的性能来训练,其中Θ遍及在从中采样得到的任务中。更准确的是,元目标函数表示如下:
请注意,元优化是在模型参数Θ中执行的,而目标函数是使用更新的模型参数进行计算的。实际上,我们所提出的方法旨在优化模型参数,使新任务上的一个或少量梯度步骤将在该任务上产生最大的有效效果。
任务的元优化是通过随机梯度下降法(SGD)实现的,模型参数Θ更新如下所示:
其中表示β步长。整个算法被概述在算法1中:
所以整个的训练过程就是:
1.首先随机初始化整个模型的参数Θ
2.然后从任务分布随机选取batch_size个任务
3.然后循环训练这batch_size个任务
- 首先训练第一个(假设是猫狗分类任务),其中有每个类有K个样本(所以叫做K-shot,即猫和狗的训练数据都只有K张图片)
- 使用这些数据去训练模型,然后使用梯度下降方法去更新参数
- 接着再循环回3去训练另一个任务,再更新参数,直至batch_size个任务都训练完
4.最后使用最终的模型对每个任务都再进行一次梯度下降更新参数
5.然后再循环至2根据设置迭代次数再继续训练
MAML元梯度更新涉及到一个贯穿梯度的梯度。在计算上,这需要额外的通过f的后向传播来计算Hessian-vector乘积,这一点得到了诸如TensorFlow等标准深度学习库的支持(Abadi等,2016)。在我们的实验中,我们还包括了一个与丢弃这个后向传播和使用一个一阶近似的操作比较,我们将在5.2节中讨论。
3. Species of MAML
在本节中,我们将讨论用于监督学习和强化学习的元学习算法的具体实例。这些域在损失函数的形式以及数据如何由任务生成并呈现给模型方面有所不同,但在这两种情况下可以应用相同的基本适应机制。
3.1. Supervised Regression and Classification
few-shot学习在监督任务领域得到了充分的研究,其目标是通过使用该任务的少量输入/输出对来学习一个新函数,并使用来自类似任务的先前数据进行元学习。例如,目标可能是使用的是之前已经看到过许多其他类型对象的模型,在只看到一个或几个Segway示例之后对Segway的图像进行分类。同样地,在few-shot回归中,目标是在对许多具有类似统计特性的函数进行训练后,从该函数采样的少数数据点预测连续值函数的输出。
为了在2.1节中的元学习定义上下文中公式化回归和分类问题,我们可以定义horizon H = 1,并下放时间下标如xt, 因为模型接受一个输入,产生一个输出,而非输入和输出序列。任务从qi中生成K个观测值x,任务损失用模型对x的输出与该观测和任务对应的目标值y之间的误差表示。
用于监督分类和回归的两个常见损失函数是交叉熵和均方误差(MSE),我们将在下面描述;不过,也可以使用其他监督损失函数。对于使用均方误差的回归任务,损失形式为:
其中表示从任务中采样的输入/输出对。在K-shot回归任务中,为每个任务提供K个输入/输出对用于学习
同样的,对于离散的带有交叉熵损失的分类任务,其损失如下所示:
根据往常的术语,K-shot分类任务在每个类中使用K个输入/输出对,对于N个类的分类任务,总共需要NK个数据点。给定任务的分布,这些损失函数能够被直接插入到Section 2.2的等式中去实现元学习,如下面的算法2所示:
3.2. Reinforcement Learning
在强化学习(RL)中,few-shot元学习的目标是使agent仅使用测试设置中的少量经验就能快速获得一个新测试任务的策略。一项新任务可能涉及实现一个新目标,或者在一个新的环境中实现一个之前训练过的目标。例如,一个代理可能学会快速找出如何在迷宫中导航,这样,当面对一个新的迷宫时,它就可以通过少量样本确定如何可靠地到达出口。在本节中,我们将讨论如何将MAML应用于RL的元学习。
每一个强化学习任务包含一个初始状态分布和一个转换分布,损失与(负)奖励函数R相关。因此整个网络是一个带有Horizon H的马尔可夫决策过程(MDP),其中学习器被允许去查询有限数量的样本轨迹用于few-shot学习。在中,MDP的任何方面都可能在不同任务之间发生变化。正在学习的模型fθ使用的策略是在每个时间step 上将状态Xt映射到actions at的分布上。任务的损失和模型fΦ的形式如下:
在K-shot强化学习中,K个rollouts来自fθ和任务,和相应的奖励可能会用于适应新任务
由于未知的动态,预期奖励通常是不可微的,因此我们使用策略梯度方法来估计模型梯度更新和元优化的梯度。因为策略梯度是on-policy的算法,在fθ的适应期间,每个额外的梯度step需要来自当前策略的新样本。我们在上面的算法3详细说明了算法。该算法与算法2有着相同的结构,主要的不同在step 5和8,其需要从相关于任务的环境中采样轨迹。该方法的实际实现可能会使用目前提出的用于策略梯度算法一系列改进,包括状态或行为独立的基线和trust regions(Schulman et al., 2015).
4. Related Work
忽略
5. Experimental Evaluation
...
5.2. Classification
...
在MAML中,当通过元目标中的梯度算子反向传播元梯度时,使用二阶导数会消耗大量的计算量(见方程(1))。在MiniImagenet上,我们与MAML的一阶近似进行了比较,其中省略了这些二阶导数。值得注意的是,结果方法仍然在后向更新参数值后计算元梯度,用于高效的元学习。令人惊讶的是这两种方法(即一个省略了二阶导数,一个没省略)的性能几乎是一样的,获得完整的二阶导数,这表明大部分的改善MAML来自后向更新参数值时目标函数的梯度,而不是二阶更新。过去的研究发现ReLU神经网络在局部几乎是线性的(Goodfellow et al., 2015),这表明二阶导数在大多数情况下可能接近于零,这部分解释了一阶近似的良好性能。这种近似消除了在额外的向后传递中计算Hessian-vector内积的需要,我们发现这样可以使网络计算速度提高大约33%。
计算步骤像是:
详情可见
https://www.bilibili.com/video/BV1pQ4y1K7cw?p=36
李宏毅机器学习—进阶部分
然后省略掉二阶部分:
可见就简化成了只用考虑i=j的情况:
meta learning - Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks - 1 - 论文学习