本文介绍一种知识蒸馏的方法(Variational Information Distillation)。
1. 核心思想
作者定义了将互信息定义为:
如上式所述,互信息为 = 教师模型的熵值 - 已知学生模型的条件下的教师模型熵值。
我们有如下常识:当学生模型已知,能够使得教师模型的熵很小,这说明学生模型以及获得了能够恢复教师模型所需要的“压缩”知识,间接说明了此时学生模型已经学习的很好了,也就是说明上式中的H(t|s)
很小,从而使得互信息I(t;s)
会很大。因此,就可以通过最大化互信息的方式来进行蒸馏学习。
如图所示,学生网络与教师网络保持高互信息(MI),通过学习并估计教师网络中的分布,激发知识的传递,使相互信息最大化。
2. 损失函数
由于p(t|s)
难以计算,作者根据IM算法,利用一个可变高斯q(t|s)
来模拟p(t|s)
:
上述公式中的大于等于操作用到了KL
散度的非负性。由于蒸馏过程中H(t)
和需要学习的学生模型参数无关,因此最大化互信息就转换为最大化可变高斯分布的问题。
作者利用一个均值,方差可学习的高斯分布来模拟上述的q(t|s)
:
式子中可学习的方差定义如下:
其中阿尔法c是可学习参数。
class VIDLoss(nn.Module):
"""Variational Information Distillation for Knowledge Transfer (CVPR 2019),
code from author: https://github.com/ssahn0215/variational-information-distillation"""
def __init__(self,
num_input_channels,
num_mid_channel,
num_target_channels,
init_pred_var=5.0,
eps=1e-5):
super(VIDLoss, self).__init__()
def conv1x1(in_channels, out_channels, stride=1):
return nn.Conv2d(in_channels, out_channels,kernel_size=1, padding=0,bias=False, stride=stride)
# 通过一个卷积网络来模拟可变均值
self.regressor = nn.Sequential(
conv1x1(num_input_channels, num_mid_channel),
nn.ReLU(),
conv1x1(num_mid_channel, num_mid_channel),
nn.ReLU(),
conv1x1(num_mid_channel, num_target_channels),
)
# 可学习参数
self.log_scale = torch.nn.Parameter(
np.log(np.exp(init_pred_var-eps)-1.0) * torch.ones(num_target_channels)
)
self.eps = eps
def forward(self, input, target):
# pool for dimentsion match
s_H, t_H = input.shape[2], target.shape[2]
if s_H > t_H:
input = F.adaptive_avg_pool2d(input, (t_H, t_H))
elif s_H < t_H:
target = F.adaptive_avg_pool2d(target, (s_H, s_H))
else:
pass
# 均值方差
pred_mean = self.regressor(input)
pred_var = torch.log(1.0+torch.exp(self.log_scale))+self.eps
pred_var = pred_var.view(1, -1, 1, 1)
# 利用均值和方差可学习的高斯分布来模拟概率
neg_log_prob = 0.5*(
(pred_mean-target)**2/pred_var + torch.log(pred_var)
)
loss = torch.mean(neg_log_prob)
return loss
源代码
参考文献:CVPR 2019 | VID_最大化互信息知识蒸馏