https://arxiv.org/abs/1706.03762
--------------------------------------------------------
2021-06-03
encoder-decoder: the overall architecture keeps the encoder-decoder structure, with recurrence replaced entirely by attention
attention: for the output y at a given time step, the weight it places on each part of the input x (attention can be read as a set of weights)
self-attention: the output sequence is the input sequence itself, so Q, K, and V all come from the same sequence
scaled dot-product attention: use the similarity between Q and K to decide how to weight V
divide by a scaling factor: when the dimension is large, the dot products become large in magnitude, which pushes the softmax into regions where its gradient is very small
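The formula from the paper is Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V, i.e. the scaling factor is 1/sqrt(d_k). A tiny sanity check of the claim above (toy random vectors, purely for illustration):

    import torch

    d_k = 64
    q, k = torch.randn(1000, d_k), torch.randn(1000, d_k)
    scores = (q * k).sum(dim=-1)          # dot products of 1000 random q/k pairs
    print(scores.std())                   # grows roughly like sqrt(d_k), about 8 here
    print((scores / d_k ** 0.5).std())    # rescaled back to roughly 1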
import numpy as np
import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    # Sinusoidal positional encoding, precomputed for positions up to max_len.
    def __init__(self, d_model, max_len, device):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model, device=device)
        self.encoding.requires_grad = False  # fixed table, not a learnable parameter
        pos = torch.arange(0, max_len, device=device).float().unsqueeze(dim=1)  # (max_len, 1)
        _2i = torch.arange(0, d_model, step=2, device=device).float()           # even dimensions 2i
        # PE(pos, 2i) = sin(pos / 10000^(2i/d_model)), PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))

    def forward(self, x):
        # x: (batch, seq_len) token ids; return the encodings for the first seq_len positions
        batch, seq_len = x.size()
        return self.encoding[:seq_len, :]


def pad_mask(seq_q, seq_k):
    # True where the key position is padding (token id 0); shape (batch, len_q, len_k)
    len_q = seq_q.size(1)
    mask = seq_k.eq(0)
    mask = mask.unsqueeze(1).expand(-1, len_q, -1)
    return mask


def sequence_mask(seq):
    # Upper-triangular "subsequent" mask: position i may not attend to any j > i
    batch, seq_len = seq.size()
    mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.bool), diagonal=1)
    mask = mask.unsqueeze(0).expand(batch, -1, -1)
    return mask


class ScaledDotProductAttention(nn.Module):
    def __init__(self, attention_dropout=0.):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(attention_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, scale=64 ** -0.5, attn_mask=None):
        # q, k, v: (batch, seq_len, dim); scale defaults to 1/sqrt(64)
        attention = torch.bmm(q, k.transpose(1, 2))
        if scale:
            attention = attention * scale
        if attn_mask is not None:
            attention = attention.masked_fill_(attn_mask, -np.inf)
        attention = self.softmax(attention)
        attention = self.dropout(attention)
        context = torch.bmm(attention, v)
        return context, attention


class MultiHeadAttention(nn.Module):
    def __init__(self, model_dim=512, num_heads=8, dropout=0.):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.dim_per_head = model_dim // num_heads
        self.linear_k = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_v = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_q = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.dot_product_attention = ScaledDotProductAttention(dropout)
        self.linear_final = nn.Linear(model_dim, model_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(model_dim)

    def forward(self, query, key, value, attn_mask=None):
        residual = query
        dim_per_head = self.dim_per_head
        num_heads = self.num_heads
        batch = key.size(0)

        key = self.linear_k(key)
        value = self.linear_v(value)
        query = self.linear_q(query)

        # split heads: (batch, seq_len, model_dim) -> (batch * num_heads, seq_len, dim_per_head)
        key = key.view(batch, -1, num_heads, dim_per_head).transpose(1, 2)
        key = key.reshape(batch * num_heads, -1, dim_per_head)
        value = value.view(batch, -1, num_heads, dim_per_head).transpose(1, 2)
        value = value.reshape(batch * num_heads, -1, dim_per_head)
        query = query.view(batch, -1, num_heads, dim_per_head).transpose(1, 2)
        query = query.reshape(batch * num_heads, -1, dim_per_head)

        if attn_mask is not None:
            # one copy of the mask per head, matching the (batch * num_heads, ...) layout
            attn_mask = attn_mask.repeat_interleave(num_heads, dim=0)

        scale = dim_per_head ** -0.5  # 1/sqrt(d_k)
        context, attention = self.dot_product_attention(query, key, value, scale, attn_mask)

        # concatenate heads: (batch * num_heads, seq_len, dim_per_head) -> (batch, seq_len, model_dim)
        context = context.view(batch, num_heads, -1, dim_per_head).transpose(1, 2)
        context = context.reshape(batch, -1, dim_per_head * num_heads)

        output = self.linear_final(context)
        output = self.dropout(output)
        output = self.layer_norm(residual + output)  # residual connection + layer norm
        return output, attention
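A quick smoke test of the pieces above (the vocabulary size, sequence shapes, and pad id 0 are made up for illustration, not taken from the paper):

    src = torch.tensor([[5, 7, 9, 0, 0],
                        [3, 4, 0, 0, 0]])                   # (batch=2, seq_len=5), 0 = padding
    embed = nn.Embedding(100, 512, padding_idx=0)
    pos_enc = PositionalEncoding(d_model=512, max_len=50, device='cpu')
    x = embed(src) + pos_enc(src)                           # (2, 5, 512)

    enc_mask = pad_mask(src, src)                           # (2, 5, 5), True where the key is padding
    mha = MultiHeadAttention(model_dim=512, num_heads=8, dropout=0.1)
    out, attn = mha(x, x, x, attn_mask=enc_mask)            # self-attention: Q = K = V = x
    print(out.shape, attn.shape)                            # (2, 5, 512) and (16, 5, 5) = (batch * num_heads, len_q, len_k)

    # decoder self-attention would combine the padding mask with the subsequent mask
    tgt = torch.tensor([[1, 2, 3, 0, 0]])
    dec_mask = pad_mask(tgt, tgt) | sequence_mask(tgt)      # (1, 5, 5)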