原文链接:https://arxiv.org/abs/2003.04390
源代码链接:https://github.com/yinboc/few-shot-meta-baseline
背景知识:
- meta-learning(元学习)
本质是一种“learning to learn”的学习过程,不同于常用的深度学习模型(依据数据集去学习如何预测或者分类),meta-learning是学习“如何更快学习一个模型”的过程
- MAML算法:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
模型无关框架,定义一种架构,使用meta-learning去训练base-learner。
- base-learner(基学习器):从足够样本的公共类(基类)中训练过的的深度学习模型
问题定义:
小样本分类:给定一个基类Cbase的标记数据集,每个类中有大量的图像,目标是学习新类Cnovel中每个类中有几个样本的概念,其中Cbase∩Cnovel=∅。(N-way指训练数据中有N个类别,K-shot指每个类别下有K个被标记数据)在一个N类K样本的小样本任务中,支持集(基类)包含每个类中具有K样本的N个类,查询集(新类)包含相同的N个类,每个类中有Q个样本,其目标是将查询集中的N个×Q未标记的样本正确分类为N个类。
文章成果:
- 定义了一个Classifier-Baseline(基线分类器),在基类上预先训练一个分类器来学习视觉表示,删除最后一个全连接(FC)层,得到特征映射。然后在少量样本的新类上进行训练,计算新类的样本平均特征,利用特征空间中的余弦距离,用最近质心对查询样本(验证集)进行分类,即为余弦最近质心分类法。只是用于估计新类的最后FC权重,不需要重新训练所有模型参数。
- 在1的基础上,用元学习改进Classifier-Baseline,提出Meta-Beaseline。在Meta-Baseline中,使用预先训练的Classifier-Baseline初始化模型,并使用余弦最近轮廓度量执行元学习,这是Clssififier-Baseline中的评估度量 。(一种基于度量的元学习)
- 文章观察到训练过程中测试性能下降,在元学习的背景下对两种泛化类型进行了评估,基类泛化和新类泛化。
- 文章还对数据集因素对元学习的影响进行了研究。
文章具体研究内容:
- Classifier-Baseline
分类器-基线是指在所有基类上训练一个具有分类功能的分类器,并使用余弦最近质心方法执行小样本任务。具体来说,我们在所有具有标准交叉熵损失的基类上训练一个分类器,然后删除它的最后一个FC层,得到编码器fθ,将输入映射到特征空间。给定一个具有支持集S的小样本任务,让Sc表示c类中的小样本,我们计算平均特征Wc作为c类的质心:
然后,对于小样本任务中的查询样本x,我们预测样本x属于c类的概率作为样本x的特征向量与c类质心之间的余弦相似度:
其中<.,.>表示两个向量的余弦相似度。 请注意,Wc也可以看作是新FC层对新概念的预测权重。
- Meta-Baseline
一般来说,Meta-Baseline包含两个训练阶段。
第一阶段是预训练阶段,即训练Classifier-Baseline(即在所有基类上训练分类器,并删除其最后一个FC层以获得fθ)。
第二阶段是元学习阶段,在元学习阶段同样使用基类中的数据分成多个task,在每个task中对support-set用fθ编码,然后用(1)式求每个类的平均特征表示。同时对query-set也进行编码操作,利用(2)式余弦相似度求query-set和support set之间的距离,使用softmax进行分类。
我们在分类器-基线评估算法上对模型进行优化。具体来说,给定预先训练的特征编码器fθ,我们从基类中的训练数据中采样N类K样本任务(具有N×Q查询样本)
为了计算每个任务的损失,在支持集中,我们在方程1中计算定义的N个类的质心,然后用于方程2中计算定义的查询集中每个样本的预测概率分布。损失是由p和查询集中样本的标签计算的交叉熵损失。 请注意,我们将每个任务视为训练中的数据点,每批可能包含多个任务,并计算平均损失。
未完待续