无标签学习的知识蒸馏——Learning Student Networks in the Wild

网上看到一篇关于这个论文的博客,居然还要花钱订阅,这就不能忍了,中国人不能欺负中国人。所以自己写博客,论文其实很好懂

论文地址

Learning Student Networks in the Wild (thecvf.com)

代码地址

GitHub - huawei-noah/Efficient-Computing: Efficient-Computing

整体流程

无标签学习的知识蒸馏——Learning Student Networks in the Wild

 

主要问题

使用未标记的数据完成知识蒸馏,解决teacher 网络中训练集不可用的问题。

创新点

1、Noisy adaptation matrix Q

2、提出DFND模型

主要方法

预备知识——知识蒸馏中学生网络的损失函数

无标签学习的知识蒸馏——Learning Student Networks in the Wild

无标签学习的知识蒸馏——Learning Student Networks in the Wild是训练数据, 无标签学习的知识蒸馏——Learning Student Networks in the Wild是gt label, 无标签学习的知识蒸馏——Learning Student Networks in the Wild无标签学习的知识蒸馏——Learning Student Networks in the Wild是交叉熵损失, 无标签学习的知识蒸馏——Learning Student Networks in the Wild是KL距离,无标签学习的知识蒸馏——Learning Student Networks in the Wild是调和参数

等号右边的第二项是最小化老师和学生网络的距离,可以看作是为了帮助训练学生网络而做的强正则化器

传统的知识蒸馏需要原始的数据集去训练老师网络,但是这些数据集有时是不可用的,虽然有一些利用老师网络产生图像的data-free compression 方法,但是现有的方法的性能受限于视觉质量和计算成本。为解决以上问题,最直接的方法是使用公开的无标签数据进行知识蒸馏

无标签学习的知识蒸馏——Learning Student Networks in the Wild

无标签学习的知识蒸馏——Learning Student Networks in the Wild:无标签数据

然而和公式(1)相比公式(2) 有两个缺点:第一,公式(2)的目标是在无标签数据集无标签学习的知识蒸馏——Learning Student Networks in the Wild上最小化老师和学生网络的 ouput 距离,而不是在原始的数据集无标签学习的知识蒸馏——Learning Student Networks in the Wild上, 通过公式(2)训练出来的学生网络 原始数据集上的性能是无法保证的。 第二、由于无标签数据没有真实标签所以 分类损失 无法计算。由于这两部分使公式(2)无法获得合适的目标函数 使得学生网络在原始数据集上的性能可以接受,所以提出了DFND算法,帮助学生网络从老师网络那里学到有用且正确的信息。简单说,从无标签数据中选择最有价值的样本,供学生网络学习使学生网络在原始数据集上也能有良好的性能,并为无标签数据标注伪标签,使公式(2)中有分类损失。

1、数据收集

目的是在巨大的无标签数据中收集有用的数据,以保证训练出来的学生网络在原始数据集上也有良好的性能,目标可以被公式化为 以下公式 

无标签学习的知识蒸馏——Learning Student Networks in the Wild

无标签学习的知识蒸馏——Learning Student Networks in the Wild原始的训练数据, 无标签学习的知识蒸馏——Learning Student Networks in the Wild是由公式(2)选择的无标签数据无标签学习的知识蒸馏——Learning Student Networks in the Wild训练出来的学生网络。然后公式(3)描述的样本选择原则 在原始数据无标签学习的知识蒸馏——Learning Student Networks in the Wild不可用的情况下是很棘手的。所以提出了一个可替代的原则。

为了替代公式(3)中的KL距离,将会分析老师和学生网络输出的MSE loss(L2距离)去收集有用数据

无标签学习的知识蒸馏——Learning Student Networks in the Wild

借助公式(4),在命题1中提出了选择样本的替代原则 

 命题1:

给定一个预训练的教师网络无标签学习的知识蒸馏——Learning Student Networks in the Wild和一个巨大的无标签数据集无标签学习的知识蒸馏——Learning Student Networks in the Wild,无标签样本的噪声值可以被定义为

无标签学习的知识蒸馏——Learning Student Networks in the Wild

 无标签学习的知识蒸馏——Learning Student Networks in the Wild是被老师网络预测的 无标签学习的知识蒸馏——Learning Student Networks in the Wild的伪标签,有用的样本具有较小的噪声值。

证明:(有时间在补上)

根据命题1,老师网络中具有高置信度的样本更有可能被选为训练数据。数据收集方法背后的直觉很简单。首先,使用原始数据集训练的老师网络,置信分数较低的样本在原始分布中有较低的概率值。因此,命题1可以防止选择大部分分布外的样本。更重要的是,老师提供的关于置信度较高的样本的信息不太可能是错误的。因此,命题1中选择值比较大的样本。应该注意的是,尽管在半监督学习中有时会使用低熵伪标签选择未标记数据[19,35],但我们是第一个应用的该技术适用于无数据知识提取环境。此外,我们还进行了深入的分析,从理论上保证了这种数据采集方法的有效性。

 2、无标签数据的噪声蒸馏

使用老师网络中产生的伪标签作为one-hot labels 根据公式(1)来生成标签。知识蒸馏中的损失函数被定义为

无标签学习的知识蒸馏——Learning Student Networks in the Wild

 无标签学习的知识蒸馏——Learning Student Networks in the Wild是训练数据, 无标签学习的知识蒸馏——Learning Student Networks in the Wild,对于知识蒸馏中的无标签学习的知识蒸馏——Learning Student Networks in the Wild 学生可以使用未标记的数据直接向老师学习, 对于交叉熵项,更准确的标签y肯定有助于学生网络的训练。然而,教师预测的标签yˆ是噪声标签,因为无标签学习的知识蒸馏——Learning Student Networks in the Wild没有学习来捕获未标记数据中的信息。为了解决这个问题,我们利用下面的函数来发现噪声标签和真实标签之间的概率,

无标签学习的知识蒸馏——Learning Student Networks in the Wild

 k是y类别数,无标签学习的知识蒸馏——Learning Student Networks in the Wild是噪声适应矩阵Q,这里无标签学习的知识蒸馏——Learning Student Networks in the Wild 类别真实类别为j的时候伪标签为i的概率。它将真实标签的概率无标签学习的知识蒸馏——Learning Student Networks in the Wild转换为噪声概率无标签学习的知识蒸馏——Learning Student Networks in the Wild,因此可以通过在学生网络的softmax 层之后增加噪声适应矩阵Q来学习真实标签。然后计算转换后的输出和伪标签之间的交叉熵损失。公式(3)中的噪声蒸馏可以被定义为

无标签学习的知识蒸馏——Learning Student Networks in the Wild

实际上,Q的真实值是不知道的,我们根据老师网络的先验知识来初始化Q,定义老师网络在原始数据集上第i类的准确率为无标签学习的知识蒸馏——Learning Student Networks in the Wild, Q的对角线值可以被定义为无标签学习的知识蒸馏——Learning Student Networks in the Wild, 因此Q的行可以被视为概率分布

无标签学习的知识蒸馏——Learning Student Networks in the Wild

因此Q被初始化为

无标签学习的知识蒸馏——Learning Student Networks in the Wild

 

算法流程

无标签学习的知识蒸馏——Learning Student Networks in the Wild

 

上一篇:「李宏毅机器学习」机器学习介绍


下一篇:Q-learning++ DQN系列论文小梳理