https://zhuanlan.zhihu.com/p/342105673
特征处理部分比较好理解,点的self、cross注意力机制实现建议看下源码(MultiHeadedAttention),
def attention(query, key, value): dim = query.shape[1] scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5 prob = torch.nn.functional.softmax(scores, dim=-1) return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob class MultiHeadedAttention(nn.Module): """ Multi-head attention to increase model expressivitiy """ def __init__(self, num_heads: int, d_model: int): super().__init__() assert d_model % num_heads == 0 self.dim = d_model // num_heads self.num_heads = num_heads self.merge = nn.Conv1d(d_model, d_model, kernel_size=1) self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)]) def forward(self, query, key, value): batch_dim = query.size(0) query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1) for l, x in zip(self.proj, (query, key, value))] x, prob = attention(query, key, value) self.prob.append(prob) return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))
这里直接跳到最后的逻辑部分,这部分论文写的比较粗略,需要看下源码才知道在讲啥(也许有大佬不用看)。
看这里,即是说推理时检出的匹配关系是不考虑最后一行和最后一列的,而是设定阈值,将不合格的匹配过滤掉
# Get the matches with score above "match_threshold". max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) indices0, indices1 = max0.indices, max1.indices mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0) # [0,0...,1,..0] mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1) zero = scores.new_tensor(0) mscores0 = torch.where(mutual0, max0.values.exp(), zero) mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero) valid0 = mutual0 & (mscores0 > self.config['match_threshold']) valid1 = mutual1 & valid0.gather(1, indices1) indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
推理时代码如下,可见图A和图B互相匹配的结果(按照score的行列取最大值的index)不必完全一致:
kpts0, kpts1 = pred['keypoints0'].cpu().numpy()[0], pred['keypoints1'].cpu().numpy()[0] matches, conf = pred['matches0'].cpu().detach().numpy(), pred['matching_scores0'].cpu().detach().numpy() image0 = read_image_modified(image0, opt.resize, opt.resize_float) image1 = read_image_modified(image1, opt.resize, opt.resize_float) valid = matches > -1 mkpts0 = kpts0[valid] mkpts1 = kpts1[matches[valid]] mconf = conf[valid]
然后看求解分配矩阵的部分,couplings为相似度得分矩阵,为其添加了最后一行一列,并赋值为1,在原文提到的约束下,使用sinkhorn(待看)算法求解,求出分配矩阵Z,
# b(m+1)(n+1), b(m+1), b(n+1) def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int): """ Perform Sinkhorn Normalization in Log-space for stability""" u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu) for _ in range(iters): # [log(m+n) ..., log(n)+log(m+n)] - [] u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2) # b(m+1) v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1) return Z + u.unsqueeze(2) + v.unsqueeze(1) def log_optimal_transport(scores, alpha, iters: int): """ Perform Differentiable Optimal Transport in Log-space for stability""" b, m, n = scores.shape one = scores.new_tensor(1) ms, ns = (m*one).to(scores), (n*one).to(scores) bins0 = alpha.expand(b, m, 1) # only a new view bins1 = alpha.expand(b, 1, n) alpha = alpha.expand(b, 1, 1) # b(m+1)(n+1), 额外行列值为1 couplings = torch.cat([torch.cat([scores, bins0], -1), # bmn,bm1->bm(n+1) torch.cat([bins1, alpha], -1)], 1) # b1n,b11->b1(n+1) norm = - (ms + ns).log() log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm]) # m+1: [log(m+n) ..., log(n)+log(m+n)] log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm]) # n+1: [log(m+n) ..., log(m)+log(m+n)] log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1) # b(m+1), b(n+1) Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters) Z = Z - norm # multiply probabilities by M+N return Z
损失函数就是最大化这个分配矩阵Z,即下面的scores矩阵,匹配对中肯定不包含dustbin点的,也就是说对dustbin的考量蕴含在sinkhorn中,注意下面的函数调用的参数self.bin_score,这是superglue网络的一个可学习的参数:
bin_score = torch.nn.Parameter(torch.tensor(1.)) self.register_parameter('bin_score', bin_score) 回头看上面的log_optimal_transport代码,每次给couplings的额外行列赋的值就是这个值。all_matches = data['all_matches'].permute(1,2,0) # shape=torch.Size([1, 87, 2]) …… # Run the optimal transport. scores = log_optimal_transport( scores, self.bin_score, iters=self.config['sinkhorn_iterations']) …… # check if indexed correctly loss = [] for i in range(len(all_matches[0])): x = all_matches[0][i][0] y = all_matches[0][i][1] loss.append(-torch.log( scores[0][x][y].exp() )) # check batch size == 1 ?
原文里对分配矩阵的约束如下,这个应该是引入sinkhorn的作用,在代码中分配矩阵P_head并没有显式出现,所以没法辅助我理解这个公式:
相对应的,P的约束就很好理解: