回顾原型网络代码②计算模型的精度和损失

文章目录

定义

few_shot.py中定义了函数模型。文件中有两个类和一个函数:

  • def load_protonet_conv(**kwargs):根据传进来的参数kwargs,建立模型
  • class Protonet(nn.Module):主要是计算episodelossacc
  • class Flatten(nn.Module):展平。但是torch不是有这个层吗,不明白为啥作者还要自己写呢

def load_protonet_conv(**kwargs)

回顾原型网络代码②计算模型的精度和损失
当输入为(200,1,28,28)tensor时,模型的输出为(200,64),也就是每个28*28的图片样本转换为一个64维度的向量

使用

line123中使用了模型

engine.train(
   model = model,
   loader = train_loader,
   optim_method = getattr(optim, opt['train.optim_method']),  # Adam
   optim_config = { 'lr': opt['train.learning_rate'], # 学习率0.001
                    'weight_decay': opt['train.weight_decay'] },  # 0.0
   max_epoch = opt['train.epochs']  # 10000
)

计算loss

以 10 w a y − 5 s h o t − 15 q u e r y 10way - 5shot - 15query 10way−5shot−15query为例,在line 40中调用Protonetloss计算了模型的lossacc

def loss(self, sample): 
	"""sample的格式为:
	{
		"class":长度为10的list,    
		"xs":torch.Size([10, 5, 1, 28, 28]),       
		"xq":torch.Size([10, 15, 1, 28, 28])
   }
	"""
    xs = Variable(sample['xs'])  # support  torch.Size([10, 5, 1, 28, 28])
    xq = Variable(sample['xq'])  # query  torch.Size([10, 15, 1, 28, 28])

    n_class = xs.size(0)  # 10
    assert xq.size(0) == n_class
    n_support = xs.size(1)   # 5
    n_query = xq.size(1)  # 15
    
    #  生成query的标签        torch.Size([10]) ->torch.Size([10, 1, 1])  -> torch.Size([10, 15, 1])
    target_inds = torch.arange(0, n_class).view(n_class, 1, 1).expand(n_class, n_query, 1).long()
    
    target_inds = Variable(target_inds, requires_grad=False)
    if xq.is_cuda:
        target_inds = target_inds.cuda()
        
	# 把 support和query合并到一起进行特征提取,maybe并行化更快	
    x = torch.cat([xs.view(n_class * n_support, *xs.size()[2:]),  # torch.Size([10, 5, 1, 28, 28]) -> torch.Size([50, 1, 28, 28])
                   xq.view(n_class * n_query, *xq.size()[2:])], 0)  # torch.Size([10, 15, 1, 28, 28]) -> torch.Size([150, 1, 28, 28])
    # x.Size([200, 1, 28, 28])
    
    z = self.encoder.forward(x)  # z.size([200, 64])
    z_dim = z.size(-1)
    
    # 求原型
    # torch.Size([50, 64]) -> torch.Size([10, 5, 64]) -> torch.Size([10, 64])
    z_proto = z[:n_class*n_support].view(n_class, n_support, z_dim).mean(1)  
    
    # query的特征 torch.Size([150, 64])
    zq = z[n_class*n_support:]  
    
    # 计算query的特征到原型的距离,等下说这个函数
    dists = euclidean_dist(zq, z_proto)  # torch.Size([150, 10])
    
    # F.log_softmax作用:在softmax的结果上再做多一次log运算,不用softmax据说是为了防止溢出
    # 注意 -dist,下面要考
    log_p_y = F.log_softmax(-dists, dim=1).view(n_class, n_query, -1)  
    #              torch.Size([150, 10]) -> torch.Size([10, 5, 10])

    # 计算损失,先在第2+1个维度上寻找对应标签的距离,例如类1的样本2标签是5,取出它距离原型1的距离,这就是这个样本的产生的loss,然后对所有样本求平均loss
    loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()    
    #             torch.Size([10, 5, 1]) -> torch.Size([10, 5]) -> torch.Size([50]) -> tenser:()
    
    # 计算预测query的标签。根据距离结果,选最小的距离,但是之前-dist,所有这里就是max。
    _, y_hat = log_p_y.max(2)  # y_hat.size() ->([10, 15])
    
    # 根据预测的结果和标签是不是相等来计算精度
    # equal的结果是bool,先转换为float,才能计算平均值
    acc_val = torch.eq(y_hat, target_inds.squeeze()).float().mean()
    # y_hat和target_inds.squeeze()都是([10, 15]),这样eq的结果也是([10, 15]),求完平均值就是tenser:()
    
    # 返回结果
    return loss_val, {
        'loss': loss_val.item(),
        'acc': acc_val.item()
    }
上一篇:2021-11-02


下一篇:《PyTorch深度学习实践》第二节