Contrastive Learning (2)

"Supervised Contrastive Learning"

This work extends the contrastive learning idea from self-supervised learning to fully supervised learning. Compared with SimCLR from the previous post, SupCon keeps essentially the same design for data augmentation, the encoder, and the projection network; the difference lies in how positives and negatives are defined: within a batch, for each chosen anchor, every sample belonging to the same class as the anchor is treated as a positive, and every sample from a different class is treated as a negative.

The supervised contrastive loss function
$$\mathcal{L}_{out}^{sup}=\sum_{i\in I}\mathcal{L}_{out, i}^{sup}=\sum_{i\in I}\frac{-1}{|P(i)|}\sum_{p\in P(i)}\log \frac{\exp (z_i \cdot z_p/\tau)}{\sum_{\alpha \in A(i)}\exp (z_i \cdot z_{\alpha}/\tau)}$$
Slightly different from the self-supervised case, when computing each $\mathcal{L}_{out, i}^{sup}$ the supervised loss evaluates $\log \frac{\exp (z_i \cdot z_p/\tau)}{\sum_{\alpha \in A(i)}\exp (z_i \cdot z_{\alpha}/\tau)}$ once for every sample $p$ that belongs to the same class as the current anchor $i$, sums these terms, and divides by the cardinality $|P(i)|$. This is where the label information is used: samples of the same class are pulled together in the feature space, while samples of different classes are pushed apart.
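
As a small worked illustration, using the same example batch as the code walkthrough below (three samples with labels [7, 7, 6] and two augmented views per sample), the multiviewed batch contains six embeddings $z_{a_1}, z_{b_1}, z_{c_1}, z_{a_2}, z_{b_2}, z_{c_2}$, where $a$ and $b$ carry label 7 and $c$ carries label 6. For the anchor $a_1$, the positive set is $P(a_1)=\{b_1, a_2, b_2\}$ and $A(a_1)$ contains the other five embeddings, so its contribution to the loss is

$$\mathcal{L}_{out, a_1}^{sup} = \frac{-1}{3}\sum_{p\in \{b_1, a_2, b_2\}}\log \frac{\exp (z_{a_1} \cdot z_p/\tau)}{\exp (z_{a_1} \cdot z_{b_1}/\tau)+\exp (z_{a_1} \cdot z_{c_1}/\tau)+\exp (z_{a_1} \cdot z_{a_2}/\tau)+\exp (z_{a_1} \cdot z_{b_2}/\tau)+\exp (z_{a_1} \cdot z_{c_2}/\tau)}$$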

Code and annotations for supervised contrastive learning (the loss part)

We take an input of shape features = torch.rand((3, 2, 5)) with labels = [7, 7, 6] as a running example; working through the implementation details helps build intuition for the loss function.

import torch
import torch.nn as nn


class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf."""

    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    # Running example with batch_size = 3 and two views. features is:
    # [ a1, a2
    #   b1, b2
    #   c1, c2 ]
    # labels is [l1, l2, l3]
    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. 
        Args:
            bsz: batch size.
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            # fall back to the self-supervised (SimCLR-style) case: each sample is
            # only treated as positive with its own other views
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            # torch.eq compares elements at the same positions: it returns True where
            # the two entries are equal and False otherwise
            mask = torch.eq(labels, labels.T).float().to(device)
            '''
                labels = [7, 7, 6]
                after view(-1, 1): labels.shape = [3, 1]
                mask is then:  [1., 1., 0.],
                               [1., 1., 0.],
                               [0., 0., 1.]
            '''
        else:
            mask = mask.float().to(device)
        # contrast_count is the number of views
        # concatenate the features of all views along the batch dimension, view by view
        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
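        # For the running example: contrast_feature = [a1, b1, c1, a2, b2, c2] with shape [6, 5]
        # (all view-1 features first, then all view-2 features)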

        # contrast mode: use only one view of each sample as anchors, or use all views
        if self.contrast_mode == 'one':
            # use only the first view of each sample in the batch as anchors
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            # use every view of every sample in the batch as an anchor
            anchor_feature = contrast_feature
            anchor_count = contrast_count
            '''
                For the running example:
                anchor_feature.shape = [6, 5]
                contrast_feature is the same tensor as anchor_feature
                anchor_count = 2
            '''
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
        # compute the similarity matrix of scaled dot products
        '''
        From   [ a1, a2
                 b1, b2
                 c1, c2 ]
        we obtain the [6, 6] matrix of pairwise dot products (each divided by the temperature):
               [a1a1 a1b1 a1c1 a1a2 a1b2 a1c2
                b1a1 b1b1 b1c1 b1a2 b1b2 b1c2
                c1a1 c1b1 c1c1 c1a2 c1b2 c1c2
                a2a1 a2b1 a2c1 a2a2 a2b2 a2c2
                b2a1 b2b1 b2c1 b2a2 b2b2 b2c2
                c2a1 c2b1 c2c1 c2a2 c2b2 c2c2]
        '''
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)

        # for numerical stability: for each anchor (row), find the maximum dot product
        # over all contrast features and subtract it from that row
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()
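        # Subtracting a per-row constant C does not change the log-probability computed
        # below: (logit - C) - log(sum_j exp(logit_j - C)) = logit - log(sum_j exp(logit_j)),
        # so the shift only prevents overflow in torch.exp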

        # tile mask: anchor_count x contrast_count copies of the original mask
        # (here 2 x 2 = 4 copies, giving a [6, 6] mask)
        mask = mask.repeat(anchor_count, contrast_count)

        # mask out self-contrast cases: a [6, 6] matrix with zeros on the diagonal,
        # so that an anchor is never contrasted with itself
        '''
        logits_mask = 
                    [0., 1., 1., 1., 1., 1.],
                    [1., 0., 1., 1., 1., 1.],
                    [1., 1., 0., 1., 1., 1.],
                    [1., 1., 1., 0., 1., 1.],
                    [1., 1., 1., 1., 0., 1.],
                    [1., 1., 1., 1., 1., 0.]
        mask = [[1., 1., 0., 1., 1., 0.],
                [1., 1., 0., 1., 1., 0.],
                [0., 0., 1., 0., 0., 1.],
                [1., 1., 0., 1., 1., 0.],
                [1., 1., 0., 1., 1., 0.],
                [0., 0., 1., 0., 0., 1.]]
        mask * logits_mask = [[0., 1., 0., 1., 1., 0.],
                              [1., 0., 0., 1., 1., 0.],
                              [0., 0., 0., 0., 0., 1.],
                              [1., 1., 0., 0., 1., 0.],
                              [1., 1., 0., 1., 0., 0.],
                              [0., 0., 1., 0., 0., 0.]]
        '''
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
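        # (for contrast_mode='all' this is the same as 1 - torch.eye(batch_size * anchor_count))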
        # remove the anchor-with-itself entries from the tiled mask
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        # logits is already in log space, so no extra log of it is needed:
        # log( exp(logits) / sum_k exp(logits_k) ) = logits - log( sum_k exp(logits_k) )
        # The denominator sums over every contrast feature except the anchor itself
        # (logits_mask zeroes out the self-contrast term), i.e. over A(i) in the formula
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
        
        # Sum log_prob over the positives of each anchor (the inner sum in the loss) and
        # divide by the number of positives |P(i)|. The per-row max subtracted earlier was
        # only for numerical stability and has already cancelled out here.
        # compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
        '''
        For each anchor: take log( exp(anchor·p/τ) / D ) for every positive p, where D sums
        exp(anchor·α/τ) over all other samples α (positives and negatives alike); these terms
        are summed and divided by the number of positives for that anchor.
        For the running example, writing xy as shorthand for exp(x·y/τ):
            mean_log_prob_pos =
            [(log(a1b1/(a1b1 + a1c1 + a1a2 + a1b2 + a1c2)) + log(a1a2/(a1b1 + a1c1 + a1a2 + a1b2 + a1c2)) + log(a1b2/(a1b1 + a1c1 + a1a2 + a1b2 + a1c2)))/3,
             (log(b1a1/(b1a1 + b1c1 + b1a2 + b1b2 + b1c2)) + log(b1a2/(b1a1 + b1c1 + b1a2 + b1b2 + b1c2)) + log(b1b2/(b1a1 + b1c1 + b1a2 + b1b2 + b1c2)))/3,
             ..........]
        '''

        # loss: scale by temperature / base_temperature, then average over all anchors
        # (anchor_count views x batch_size samples)
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
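
A quick usage sketch for the running example above (this snippet is not part of the original code; the SupCon paper normalizes embeddings onto the unit hypersphere, so the random features are L2-normalized here before the loss is applied):

# Hypothetical usage example; the printed value is only illustrative because the features are random.
import torch
import torch.nn.functional as F

features = torch.rand((3, 2, 5))             # [bsz, n_views, feature_dim]
features = F.normalize(features, dim=-1)     # project each view onto the unit sphere
labels = torch.tensor([7, 7, 6])

criterion = SupConLoss(temperature=0.07)
loss = criterion(features, labels)
print(loss)                                  # a scalar tensor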