蒸馏论文二(PAYING MORE ATTENTION TO ATTENTION)

本系列文章介绍一些知识蒸馏领域的经典文章。

知识蒸馏:提取复杂模型有用的先验知识,并与简单模型特征结合算出他们的距离,以此来优化简单模型的参数,让简单模型学习复杂模型,从而帮助简单模型提高性能。

1. Attention Transfer原理

论文Paying more attention to attention主要通过提取复杂模型生成的注意力图来指导简单模型,使简单模型生成的注意力图与复杂模型相似。这样,简单模型不仅可以学到特征信息,还能够了解如何提炼特征信息。使得简单模型生成的特征更加灵活,不局限于复杂模型。

蒸馏论文二(PAYING MORE ATTENTION TO ATTENTION)
其中,图a是输入,b是相应的空间注意力图,它可以表现出网络为了分类所给图片所需要注意的地方。所谓空间注意力图,其实就是将特征图[C, H , W]通过映射变换成特征[H, W]。作者将每层通道平方相加获得特征图对应的注意力图。

蒸馏论文二(PAYING MORE ATTENTION TO ATTENTION)
上图是人脸识别任务中,对不同维度的特征图进行变换求得的注意力图,可以发现高维注意力图会对整个脸作出反应。

蒸馏论文二(PAYING MORE ATTENTION TO ATTENTION)

2. 损失函数

论文中,作者将损失分为两部分:蒸馏论文二(PAYING MORE ATTENTION TO ATTENTION)
第一部分是分类损失是简单的交叉熵损失函数,作用是实现分类。

第二部分是衡量复杂模型于简单模型注意力图差异的函数。首先注意力图进行归一化,即除以自身的模值,然后计算两个注意力图的p范数。

第三部分也是衡量复杂模型于简单模型注意力图差异的函数,在实际代码实现中使用KL散度实现。KL散度是用来衡量两个概率分布之间的相似性的函数。不了解KL散度的见这里
蒸馏论文二(PAYING MORE ATTENTION TO ATTENTION)

3. 代码解读

参考官方代码,以下是使用attention transfer技巧的关键部分代码。

def f(inputs, params, mode):
    '''
    网络结构中使用到的函数
    返回:
        y_t: 学生模型输出结果
        y_t: 老师模型输出结果
        loss: 学生输出和老师输出结果归一化后的欧式距离,即第三部分loss
    '''
    # f_s和f_t分别是定义的学生和老师网络的网络
    y_s, g_s = f_s(inputs, params, mode, 'student.')
    with torch.no_grad():
        y_t, g_t = f_t(inputs, params, False, 'teacher.')
    return y_s, y_t, [utils.at_loss(x, y) for x, y in zip(g_s, g_t)]

def distillation(y, teacher_scores, labels, T, alpha):
    '''
    arguments:
	    y: 学生模型归一化后的输出
	    teacher_scores:老师模型归一化后的输出
	    labels:学生模型的标签
	    T, alpha:超参数
	returns:
		loss: 简单交叉熵函数和KL函数的加权和,即前两部分loss
    '''
    # 学生网络软化后结果
    p = F.log_softmax(y/T, dim=1)

    # 老师网络软化后结果
    q = F.softmax(teacher_scores/T, dim=1)

    # 两个模型之间的距离损失
    l_kl = F.kl_div(p, q, size_average=False) * (T**2) / y.shape[0]

    # 学生模型的分类损失
    l_ce = F.cross_entropy(y, labels)

    return l_kl * alpha + l_ce * (1. - alpha)

def at(x):
    '''归一化'''
    return F.normalize(x.pow(2).mean(1).view(x.size(0), -1))

def at_loss(x, y):
    '''距离函数'''
    return (at(x) - at(y)).pow(2).mean()

def h(sample):
    '''网络结构'''
    inputs = utils.cast(sample[0], opt.dtype).detach()
    targets = utils.cast(sample[1], 'long')
    if opt.teacher_id != '':
    	# 分配到几个GPU上并行训练
        y_s, y_t, loss_groups = utils.data_parallel(f, inputs, params, sample[2], range(opt.ngpu))
        # 欧式距离
        loss_groups = [v.sum() for v in loss_groups]
        [m.add(v.item()) for m, v in zip(meters_at, loss_groups)]

        # 总的loss
        return utils.distillation(y_s, y_t, targets, opt.temperature, opt.alpha) \
               + opt.beta * sum(loss_groups), y_s

engine.train(h, train_loader, opt.epochs, optimizer)

论文理解部分参考文献:
知识蒸馏论文详解之:PAYING MORE ATTENTION TO ATTENTION

上一篇:Attention Is All You Need


下一篇:paper 4:Attention is all you need