本系列文章介绍一些知识蒸馏领域的经典文章。
知识蒸馏:提取复杂模型有用的先验知识,并与简单模型特征结合算出他们的距离,以此来优化简单模型的参数,让简单模型学习复杂模型,从而帮助简单模型提高性能。
1. Attention Transfer原理
论文Paying more attention to attention主要通过提取复杂模型生成的注意力图来指导简单模型,使简单模型生成的注意力图与复杂模型相似。这样,简单模型不仅可以学到特征信息,还能够了解如何提炼特征信息。使得简单模型生成的特征更加灵活,不局限于复杂模型。
其中,图a是输入,b是相应的空间注意力图,它可以表现出网络为了分类所给图片所需要注意的地方。所谓空间注意力图,其实就是将特征图[C, H , W]
通过映射变换成特征[H, W]
。作者将每层通道平方相加获得特征图对应的注意力图。
上图是人脸识别任务中,对不同维度的特征图进行变换求得的注意力图,可以发现高维注意力图会对整个脸作出反应。
2. 损失函数
论文中,作者将损失分为两部分:
第一部分是分类损失是简单的交叉熵损失函数,作用是实现分类。
第二部分是衡量复杂模型于简单模型注意力图差异的函数。首先注意力图进行归一化,即除以自身的模值,然后计算两个注意力图的p范数。
第三部分也是衡量复杂模型于简单模型注意力图差异的函数,在实际代码实现中使用KL散度实现。KL散度是用来衡量两个概率分布之间的相似性的函数。不了解KL散度的见这里。
3. 代码解读
参考官方代码,以下是使用attention transfer
技巧的关键部分代码。
def f(inputs, params, mode):
'''
网络结构中使用到的函数
返回:
y_t: 学生模型输出结果
y_t: 老师模型输出结果
loss: 学生输出和老师输出结果归一化后的欧式距离,即第三部分loss
'''
# f_s和f_t分别是定义的学生和老师网络的网络
y_s, g_s = f_s(inputs, params, mode, 'student.')
with torch.no_grad():
y_t, g_t = f_t(inputs, params, False, 'teacher.')
return y_s, y_t, [utils.at_loss(x, y) for x, y in zip(g_s, g_t)]
def distillation(y, teacher_scores, labels, T, alpha):
'''
arguments:
y: 学生模型归一化后的输出
teacher_scores:老师模型归一化后的输出
labels:学生模型的标签
T, alpha:超参数
returns:
loss: 简单交叉熵函数和KL函数的加权和,即前两部分loss
'''
# 学生网络软化后结果
p = F.log_softmax(y/T, dim=1)
# 老师网络软化后结果
q = F.softmax(teacher_scores/T, dim=1)
# 两个模型之间的距离损失
l_kl = F.kl_div(p, q, size_average=False) * (T**2) / y.shape[0]
# 学生模型的分类损失
l_ce = F.cross_entropy(y, labels)
return l_kl * alpha + l_ce * (1. - alpha)
def at(x):
'''归一化'''
return F.normalize(x.pow(2).mean(1).view(x.size(0), -1))
def at_loss(x, y):
'''距离函数'''
return (at(x) - at(y)).pow(2).mean()
def h(sample):
'''网络结构'''
inputs = utils.cast(sample[0], opt.dtype).detach()
targets = utils.cast(sample[1], 'long')
if opt.teacher_id != '':
# 分配到几个GPU上并行训练
y_s, y_t, loss_groups = utils.data_parallel(f, inputs, params, sample[2], range(opt.ngpu))
# 欧式距离
loss_groups = [v.sum() for v in loss_groups]
[m.add(v.item()) for m, v in zip(meters_at, loss_groups)]
# 总的loss
return utils.distillation(y_s, y_t, targets, opt.temperature, opt.alpha) \
+ opt.beta * sum(loss_groups), y_s
engine.train(h, train_loader, opt.epochs, optimizer)
论文理解部分参考文献:
知识蒸馏论文详解之:PAYING MORE ATTENTION TO ATTENTION