这篇文章初看有些迷惑,用模型来训练一个source-free image translation(SFIT)方法。
仔细看方法才知道是怎么回事:
实际上这个image translation的训练有点像perceptual loss那篇,generator完成图像翻译,后面虚线的部分是固定的,用来计算loss。
这个source free是怎么来的呢?意思是完成这个训练不需要任何来自source domian的图像。Source CNN是在源域上训练的,target CNN是通过unsupervised domain adaptation学的。
这样就完了?事情当然没有这么简单。实际上作者是通过SFIT来看一看DA的过程中神经网络到底学到了什么知识,适应到了什么内容。这种source free的方式避免了generator接触到source domain 的图像,保证只从两个模型学到知识,避免了image本身带来的“分心”。
这也是一个知识蒸馏问题。作者表示,不同于data-free的知识蒸馏,SFIT必须可以表现出有哪些知识通过DA被转移过去了,而data-free只需要生成满足teacher的数据。
作者的目的就是把两个模型之间的知识差异通关generator来表现出来,并且可视化出来。
右边的knowledge distllation loss 就是用同样的classifier求出两个label 的KL散度,不需要额外解释。
Relationship preserving loss是作者提出来的。实际上就是对两个图像特征先求GRAM矩阵,然后对GRAM矩阵做L2标准化,最后计算MSE loss。
为什么用这种row-wise normalized Gram matrix 而不是原来的Gram matrix?(这里的row实际上是channel,因为F是D*HW的)
这个loss更关注channel之间的相对关系,而不是传统的用绝对值的方式描述channel自相关的绝对值。
这是对应的两个gram矩阵。更深的颜色代表更大的差异,因此带来的监督效果更强。传统的gram矩阵只有四个角落的通道的监督更强,而作者提出的更加均匀一些。
实验部分:
训练的细节:generator是初始化为一个transparent filter,输入什么输出就是什么。具体来说就用identity loss和content loss训练的。
实验结果:
首先是定量分析,作者展示了在生成的图像的分类准确率。
在数字数据集上的表现:
在Office-31数据集上的表现:
在VisDA数据集上的表现:
source model一行应该是指source model在target domain的分类准确率,target model一行是target model在target domain上的准确率。
由于两个模型在知识上存在差异,所以在target domain上的准确率也有较大的差异。
generated images一行是指将target domain域的图像翻译回source domain后送入source model的分类准确率。
Fine-tuning一行是使用target images和generated images对target model微调的结果。
作者对此的解读是,generated images缓解了分类性能上的知识差异,表示SFIT方法确实把target model适应到的、学到的知识蒸馏到了generator中。
定型分析:
VisDA数据集的结果:SYN→REAL
数字数据集(SVHN→MNIST)的结果:
注意,MNIST的图像是灰度图像,因此生成的图像中随机添加了颜色。
Office-31数据集的结果:Amazon→Webcam
结论就是generator会保留target image的内容,并且可以学到从来没见过的、来自source domain的style。
数字的数据集使用的是LeNET,loss比较容易优化而不容易学到style。
另外两种数据集使用的是resnet,生成满足loss的图像是比较困难的,因此训练的时间和更严格的限制导致生成图像的质量也更好。
使用不同的UDA方法在VisDA数据集上应用SFIT的结果:
其中a是target image,bcd分别用了DAN、ADDA、SHOT-IM。更强的UDA方法可以更好的把没见过的source domain的风格迁移到target image上。
小小总结:作者提出的SFIT是可视化DA过程中source model与target model知识差异的好工具。
从image translation的角度来看,这个结构并不是新颖的。但是与DA和知识蒸馏相结合便达到了另一层高度,该说不愧是cvpr吗。