文章目录
定义
few_shot.py中定义了函数模型。文件中有两个类和一个函数:
-
def load_protonet_conv(**kwargs)
:根据传进来的参数kwargs
,建立模型 -
class Protonet(nn.Module)
:主要是计算episode
的loss
和acc
-
class Flatten(nn.Module)
:展平。但是torch不是有这个层吗,不明白为啥作者还要自己写呢
def load_protonet_conv(**kwargs)
当输入为(200,1,28,28)
的tensor
时,模型的输出为(200,64)
,也就是每个28*28的图片样本转换为一个64维度的向量
使用
在line123中使用了模型
engine.train(
model = model,
loader = train_loader,
optim_method = getattr(optim, opt['train.optim_method']), # Adam
optim_config = { 'lr': opt['train.learning_rate'], # 学习率0.001
'weight_decay': opt['train.weight_decay'] }, # 0.0
max_epoch = opt['train.epochs'] # 10000
)
计算loss
以
10
w
a
y
−
5
s
h
o
t
−
15
q
u
e
r
y
10way - 5shot - 15query
10way−5shot−15query为例,在line 40中调用Protonet
的loss
计算了模型的loss
和acc
def loss(self, sample):
"""sample的格式为:
{
"class":长度为10的list,
"xs":torch.Size([10, 5, 1, 28, 28]),
"xq":torch.Size([10, 15, 1, 28, 28])
}
"""
xs = Variable(sample['xs']) # support torch.Size([10, 5, 1, 28, 28])
xq = Variable(sample['xq']) # query torch.Size([10, 15, 1, 28, 28])
n_class = xs.size(0) # 10
assert xq.size(0) == n_class
n_support = xs.size(1) # 5
n_query = xq.size(1) # 15
# 生成query的标签 torch.Size([10]) ->torch.Size([10, 1, 1]) -> torch.Size([10, 15, 1])
target_inds = torch.arange(0, n_class).view(n_class, 1, 1).expand(n_class, n_query, 1).long()
target_inds = Variable(target_inds, requires_grad=False)
if xq.is_cuda:
target_inds = target_inds.cuda()
# 把 support和query合并到一起进行特征提取,maybe并行化更快
x = torch.cat([xs.view(n_class * n_support, *xs.size()[2:]), # torch.Size([10, 5, 1, 28, 28]) -> torch.Size([50, 1, 28, 28])
xq.view(n_class * n_query, *xq.size()[2:])], 0) # torch.Size([10, 15, 1, 28, 28]) -> torch.Size([150, 1, 28, 28])
# x.Size([200, 1, 28, 28])
z = self.encoder.forward(x) # z.size([200, 64])
z_dim = z.size(-1)
# 求原型
# torch.Size([50, 64]) -> torch.Size([10, 5, 64]) -> torch.Size([10, 64])
z_proto = z[:n_class*n_support].view(n_class, n_support, z_dim).mean(1)
# query的特征 torch.Size([150, 64])
zq = z[n_class*n_support:]
# 计算query的特征到原型的距离,等下说这个函数
dists = euclidean_dist(zq, z_proto) # torch.Size([150, 10])
# F.log_softmax作用:在softmax的结果上再做多一次log运算,不用softmax据说是为了防止溢出
# 注意 -dist,下面要考
log_p_y = F.log_softmax(-dists, dim=1).view(n_class, n_query, -1)
# torch.Size([150, 10]) -> torch.Size([10, 5, 10])
# 计算损失,先在第2+1个维度上寻找对应标签的距离,例如类1的样本2标签是5,取出它距离原型1的距离,这就是这个样本的产生的loss,然后对所有样本求平均loss
loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()
# torch.Size([10, 5, 1]) -> torch.Size([10, 5]) -> torch.Size([50]) -> tenser:()
# 计算预测query的标签。根据距离结果,选最小的距离,但是之前-dist,所有这里就是max。
_, y_hat = log_p_y.max(2) # y_hat.size() ->([10, 15])
# 根据预测的结果和标签是不是相等来计算精度
# equal的结果是bool,先转换为float,才能计算平均值
acc_val = torch.eq(y_hat, target_inds.squeeze()).float().mean()
# y_hat和target_inds.squeeze()都是([10, 15]),这样eq的结果也是([10, 15]),求完平均值就是tenser:()
# 返回结果
return loss_val, {
'loss': loss_val.item(),
'acc': acc_val.item()
}