【学习笔记】xProtoNet和ProtoNet:可解释的图像分类

xProtoNet和ProtoNet:可解释的图像分类

ProtoNet

论文地址:This Looks Like That: Deep Learning for Interpretable Image Recognition
代码地址:https://github.com/cfchen-duke/ProtoPNet

之前的分类模型

之前的分类模型几乎都Encoder-Decoder的形式,通过Encoder获取图像特征向量之后在通过Decoder进行分类。比如下图,encoder一般是预训练的模型,分类模型主要是训练deocder
【学习笔记】xProtoNet和ProtoNet:可解释的图像分类
而所谓的可解释也都是在这个部分之外再添加诸如attention的部分,和人的思考模式还是有点区别的。
我们人的思考可能是这样的模式(原文是用鸟来举例子,我觉得很没代入感,我用个二次元的例子):
【学习笔记】xProtoNet和ProtoNet:可解释的图像分类
我们看到一张图片(左),如果想识别它是不是雷电将军,我们可能会怎么想呢?我们心中肯定有一个雷电将军(右),这个图片说白了是我们心中对于她记忆种特征的一个具象化图。如果和我们心中的将军像到某种程度的话,那她肯定就是将军了。
比如:

  • 左边的刘海和雷电将军的刘海一样;
  • 俩人都用单手剑;
  • 腰部的右侧都有饰品;
  • 俩人的size差不多

所以我们得出结论:芽衣就是雷电将军(误)
以这种思考模式为基础,作者提出了一个模型,其中灰色都不需要训练的或者是预训练的。
【学习笔记】xProtoNet和ProtoNet:可解释的图像分类
不再是简单的Encoder-Decoder的模式,而是在分类前加了一个对比的过程。也就是刚刚的,和每个类的最典型的特征进行相似度的计算,并且最后通过加权计算出应该归属的类。这个prototypes也就是每个类的典型特征。我在最开始看到这篇文章的时候一直纠结在这个prototype是怎么得到的,看了训练过程才慢慢明白。
【学习笔记】xProtoNet和ProtoNet:可解释的图像分类
训练过程基本如图,首先是f,也就是一个训练好的encoder。
输入的图像为x,图像的label为y。
经过f以后得到特征图像z=f(x)。z的形状在论文中是(7,7,128/256/512)。
每个类别都会有m个P,这个m是自己设定的,论文中设定为10。也就是对于一个类来说有十个特征来表示。
训练的过程分为三个部分

最后一层h前的随机梯度下降

最关键的其实是P的学习,我想直接从loss的公式来看会比较好理解
【学习笔记】xProtoNet和ProtoNet:可解释的图像分类
CreEnt也就是比较熟悉的cross entropy loss,在学习P的过程中h是固定的,也就是只通过P的改变来完成识别。
Clst和Sep其实就是两个距离。Clst取正最小,注意min下标是 ∈ \in ∈,也就是寻找特征图像中与该类相似;Sep取负最小其实就是求最大,下标是 ∉ \notin ∈/​,也就是和其他类不相似。
这个loss基本就确定了prototype的意义,也让它的学习成为可能。

prototype的投影

通过流程图可以看到这个prototype是个特征向量啊,电脑能看懂我们可看不懂。我们看懂才是可解释的模型。所以需要把prototype投影到图像上,做到可视化。还是从loss看吧:
【学习笔记】xProtoNet和ProtoNet:可解释的图像分类
看起来好像和上面的差不多,但可以看到min的下标变了,从p变成了z,也就是说之前是寻找p的过程,现在变成了寻找z的过程,也就是所谓的投影。这部分的数学原理看起来就很复杂,说实话我没看进去,只是大概懂了idea。

h层的学习

结果

【学习笔记】xProtoNet和ProtoNet:可解释的图像分类
大概就是这种感觉

xProtoNet

论文链接:XProtoNet: Diagnosis in Chest Radiography with Global and Local Explanations

首先protonet是日常生活场景的图片,而xProtoNet将其迁移到了医疗图像领域,去进行CT图像的分类。
我觉得改的地方不多,但很有效果。主要是针对Patch的设计进行了更改。之前的Patch也可以看到大小是固定的,这就带来了一个问题,这个大小该怎么设定,实际的特征是不是真的就这么大。
比如雷电将军身上的特征和鸡蛋的特征肯定大小(或者说是占整体的比例)是差别很大的。所以固定的大小肯定是不行的。作者就设计了动态变化Patch。
【学习笔记】xProtoNet和ProtoNet:可解释的图像分类
流程图如上,可以看到有两个部分,分别是FeatureMap和OccurrenceMap。OM就是推测的P映射到图片上最有可能出现的区域。文章不再使用单纯的feature而是使用两个的结合来进行比较。(OM的训练过程目前我还有点没看懂)
【学习笔记】xProtoNet和ProtoNet:可解释的图像分类

上一篇:读书笔记-设计模式-可复用版-Prototype 原型模式


下一篇:#es6数组的扩展