对比损失的PyTorch实现详解

对比损失的PyTorch实现详解

本文以SiT代码中对比损失的实现为例作介绍。

论文:https://arxiv.org/abs/2104.03602
代码:https://github.com/Sara-Ahmed/SiT

对比损失简介

作为一种经典的自监督损失,对比损失就是对一张原图像做不同的图像扩增方法,得到来自同一原图的两张输入图像,由于图像扩增不会改变图像本身的语义,因此,认为这两张来自同一原图的输入图像的特征表示应该越相似越好(通常用余弦相似度来进行距离测度),而来自不同原图像的输入图像应该越远离越好。来自同一原图的输入图像可做正样本,同一个batch内的不同输入图像可用作负样本。如下图所示(粗箭头向上表示相似度越高越好,向下表示越低越好)。
对比损失的PyTorch实现详解

论文中的公式

l c o n t r x i , x j ( W ) = e s i m ( S i T c o n t r ( x i ) , S i T c o n t r ( x j ) ) / τ ∑ k = 1 , k ≠ i 2 N e s i m ( S i T c o n t r ( x i ) , S i T c o n t r ( x k ) ) / τ                    ( 1 ) l^{x_i,x_j}_{contr}(W)=\frac{e^{sim(SiT_{contr}(x_i),SiT_{contr}(x_j))/\tau}}{\sum_{k=1,k\ne i}^{2N}e^{sim(SiT_{contr}(x_i),SiT_{contr}(x_k))/\tau}} \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (1) lcontrxi​,xj​​(W)=∑k=1,k​=i2N​esim(SiTcontr​(xi​),SiTcontr​(xk​))/τesim(SiTcontr​(xi​),SiTcontr​(xj​))/τ​                  (1)

L = − 1 N ∑ j = 1 N l o g l x j , x j ˉ ( W )                    ( 2 ) \mathcal{L}=-\frac{1}{N}\sum_{j=1}^Nlogl^{x_j,x_{\bar{j}}}(W) \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (2) L=−N1​j=1∑N​loglxj​,xjˉ​​(W)                  (2)

SiT论文中的对比损失公式如上所示。其中 x i x_i xi​, x j x_j xj​分别表示两个不同的输入图像, s i m ( ⋅ , ⋅ ) sim(\cdot,\cdot) sim(⋅,⋅)表示余弦相似度,即归一化之后的点积, τ \tau τ是超参数温度, x j x_j xj​和KaTeX parse error: Got function '\bar' with no arguments as subscript at position 3: x_\̲b̲a̲r̲{j}是来自同一原图的两种不同数据增强的输入图像, S i T c o n t r ( ⋅ ) SiT_{contr}(\cdot) SiTcontr​(⋅) 表示从对比头中得到的图像表示,没看过原文的话,就直接理解为输入图像经过一系列神经网络,得到一个 d i m dim dim 维度的特征向量作为图像的特征表示,网络不是本文的重点,重点是怎样根据得到的特征向量计算对比损失

与最近很火的infoNCE对比损失基本一样,只是写法不同。

代码实现

class ContrastiveLoss(nn.Module):
    def __init__(self, batch_size, device='cuda', temperature=0.5):
        super().__init__()
        self.batch_size = batch_size
        self.register_buffer("temperature", torch.tensor(temperature).to(device))			# 超参数 温度
        self.register_buffer("negatives_mask", (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool).to(device)).float())		# 主对角线为0,其余位置全为1的mask矩阵
        
    def forward(self, emb_i, emb_j):		# emb_i, emb_j 是来自同一图像的两种不同的预处理方法得到
        z_i = F.normalize(emb_i, dim=1)     # (bs, dim)  --->  (bs, dim)
        z_j = F.normalize(emb_j, dim=1)     # (bs, dim)  --->  (bs, dim)

        representations = torch.cat([z_i, z_j], dim=0)          # repre: (2*bs, dim)
        similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)      # simi_mat: (2*bs, 2*bs)
        
        sim_ij = torch.diag(similarity_matrix, self.batch_size)         # bs
        sim_ji = torch.diag(similarity_matrix, -self.batch_size)        # bs
        positives = torch.cat([sim_ij, sim_ji], dim=0)                  # 2*bs
        
        nominator = torch.exp(positives / self.temperature)             # 2*bs
        denominator = self.negatives_mask * torch.exp(similarity_matrix / self.temperature)             # 2*bs, 2*bs
    
        loss_partial = -torch.log(nominator / torch.sum(denominator, dim=1))        # 2*bs
        loss = torch.sum(loss_partial) / (2 * self.batch_size)
        return loss

以下是SiT论文的对比损失代码实现,笔者已经将debug过程中得到的张量形状在注释中标注了出来,供大家参考,其中dim是得到的特征向量的维度,bs是批尺寸batch size。

笔者简单画了一张similarity_matrix的图示来说明整个过程。本图以bs==4为例, a , b , c , d a,b,c,d a,b,c,d分别代表同一个batch内的不同样本,下表0和1表示两种不同的图像扩增方法。图中每个方格则是对应行列的图像特征(dim维的向量)表示计算相似度的结果值。

对比损失的PyTorch实现详解

  1. emb_i,emb_j 是来自同一图像的两种不同的预处理方法得到的输入图像的特征表示。首先是通过F.normalize()emb_iemb_j进行归一化。

  2. 然后将二者拼接起来的到维度为2*bs的representations。再将representations分别转换为列向量和行向量计算相似度矩阵similarity_matrix(见图)。

  3. 在通过偏移的对角线(图中蓝线)的到sim_ijsim_ji,并拼接的到positives。请注意蓝线对应的行列坐标,分别是 a 0 , a 1 a_0,a_1 a0​,a1​、 b 0 , b 1 b_0,b_1 b0​,b1​等,即蓝线对应的网格即是来自同一张原图的不同处理的输入图像。这在损失的设计中即是我们的正样本。

  4. 然后nominator(分子)即可根据公式计算的到。

  5. 而在计算denominator时需注意要乘上self.negatives_mask。该变量在__init__中定义,是对2*bs的方针对角阵取反,即主对角线全是0,其余位置全是1 。这是为了在负样本中屏蔽自己与自己的相似度结果(图中红线),即使得similarity_matrix的主对角钱全为0。因为自己与自己的相似度肯定是1,加入到计算中没有意义。

  6. 再到后面loss_partial的计算(第22行)其实是计算出公式(1),torch.sum()计算的是(1)中分母上的 ∑ \sum ∑符号。

  7. 第23行就是计算公式(2),其中与公式相比分母上多了除了个2,是因为本实现为了方便将similarity_matrix的维度扩展为2*bs。即相当于将公式(2)中的 l c o n t r x j , x j ˉ l_{contr}^{x_j,x_{\bar{j}}} lcontrxj​,xjˉ​​​ 和 l c o n t r x j ˉ , x j l_{contr}^{x_{\bar{j}},x_j} lcontrxjˉ​​,xj​​ 分别计算了一遍。所以要多除个2。

自行验证

大家可以将上面的ContrastiveLoss类复制到自己的测试的文件中,并构造几个输入进行测试,打印中间结果,验证自己是否真正地理解了对比损失的代码实现计算过程。

loss_func = losses.ContrastiveLoss(batch_size=4)
emb_i = torch.rand(4, 512).cuda()
emb_j = torch.rand(4, 512).cuda()

loss_contra = loss_func(emb_i, emb_j)
print(loss_contra)
上一篇:task04:卷积情感分析


下一篇:如何入门强化学习