Attention Is All You Need

https://arxiv.org/abs/1706.03762

--------------------------------------------------------

2021-06-03


encoder-decoder

attention: for the output y at a given time step, the attention (understood as weights) it places on each part of the input x

  self-attention: the sequence attends to itself, i.e. the output sequence is the input sequence

  scaled dot-product attention: weights V according to how similar Q is to K


  dividing by a scaling factor: when the key dimension d_k is large, the dot products become large in magnitude, which pushes the softmax into regions where its gradient is extremely small
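
  For reference, the formula from the paper:

      Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V

  where d_k is the dimension of the keys, so the scaling factor is 1 / sqrt(d_k).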


import numpy as np
import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding from the paper (section 3.5)."""

    def __init__(self, d_model, max_len, device):
        super(PositionalEncoding, self).__init__()
        # (max_len, d_model) table of position encodings; not a learned parameter
        self.encoding = torch.zeros(max_len, d_model, device=device)
        self.encoding.requires_grad = False

        # positions 0 .. max_len-1 as a column vector
        pos = torch.arange(0, max_len, device=device)
        pos = pos.float().unsqueeze(dim=1)

        # even dimension indices 0, 2, 4, ...
        _2i = torch.arange(0, d_model, step=2, device=device).float()

        # PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
        # PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))

    def forward(self, x):
        # x: (batch, seq_len) of token ids; return the first seq_len rows of the table
        batch, seq_len = x.size()
        return self.encoding[:seq_len, :]
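
A minimal usage sketch for the module above (the vocabulary size and hyperparameters here are illustrative, not from the post):

tokens = torch.randint(0, 1000, (2, 10))              # (batch=2, seq_len=10) token ids
pos_enc = PositionalEncoding(d_model=512, max_len=100, device=tokens.device)
embedding = nn.Embedding(1000, 512)
x = embedding(tokens) + pos_enc(tokens)               # (2, 10, 512): embeddings + positions
print(x.shape)                                        # torch.Size([2, 10, 512])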


def pad_mask(seq_q, seq_k):
    # Mask out attention to padding positions (pad token id assumed to be 0).
    # seq_q: (batch, len_q), seq_k: (batch, len_k) -> mask: (batch, len_q, len_k), True where k is padding
    len_q = seq_q.size(1)
    mask = seq_k.eq(0)
    mask = mask.unsqueeze(1).expand(-1, len_q, -1)
    return mask


def sequence_mask(seq):
    # Causal (subsequent-position) mask for the decoder: True above the diagonal,
    # so position i cannot attend to positions j > i.
    batch, seq_len = seq.size()
    mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.bool), diagonal=1)
    mask = mask.unsqueeze(0).expand(batch, -1, -1)
    return mask
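
A small sketch of how the two masks might be combined for decoder self-attention (the target tensor below is made up for illustration):

tgt = torch.tensor([[5, 7, 9, 0, 0]])                 # (batch=1, len=5), 0 = padding
dec_self_mask = pad_mask(tgt, tgt) | sequence_mask(tgt)
print(dec_self_mask.shape)                            # torch.Size([1, 5, 5])
print(dec_self_mask[0])                               # True marks positions that are masked out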


class ScaledDotProductAttention(nn.Module):
    """softmax(Q K^T / sqrt(d_k)) V, with optional masking and dropout on the weights."""

    def __init__(self, attention_dropout=0.):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(attention_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, scale=64 ** -0.5, attn_mask=None):
        # q, k, v: (batch, seq_len, d_k); default scale assumes d_k = 64 (512 / 8 heads)
        attention = torch.bmm(q, k.transpose(1, 2))
        if scale:
            attention = attention * scale
        if attn_mask is not None:
            # set masked positions to -inf so they get zero weight after softmax
            attention = attention.masked_fill_(attn_mask, -np.inf)
        attention = self.softmax(attention)
        attention = self.dropout(attention)
        context = torch.bmm(attention, v)
        return context, attention
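
A quick shape check for the module above (random tensors, purely illustrative):

attn = ScaledDotProductAttention(attention_dropout=0.1)
q = k = v = torch.rand(2, 5, 64)                      # (batch, seq_len, d_k)
context, weights = attn(q, k, v, scale=64 ** -0.5)
print(context.shape, weights.shape)                   # (2, 5, 64) and (2, 5, 5)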


class MultiHeadAttention(nn.Module):
    """Multi-head attention: project Q/K/V, split into heads, attend, then recombine."""

    def __init__(self, model_dim=512, num_heads=8, dropout=0.):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.dim_per_head = model_dim // num_heads

        self.linear_k = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_v = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_q = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.dot_product_attention = ScaledDotProductAttention(dropout)
        self.linear_final = nn.Linear(model_dim, model_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(model_dim)

    def forward(self, query, key, value, attn_mask=None):
        # query/key/value: (batch, seq_len, model_dim)
        residual = query

        dim_per_head = self.dim_per_head
        num_heads = self.num_heads
        batch = key.size(0)

        # linear projections
        key = self.linear_k(key)
        value = self.linear_v(value)
        query = self.linear_q(query)

        # split into heads: (batch, seq_len, model_dim) -> (batch * num_heads, seq_len, dim_per_head)
        key = (key.view(batch, -1, num_heads, dim_per_head)
                  .transpose(1, 2)
                  .reshape(batch * num_heads, -1, dim_per_head))
        value = (value.view(batch, -1, num_heads, dim_per_head)
                      .transpose(1, 2)
                      .reshape(batch * num_heads, -1, dim_per_head))
        query = (query.view(batch, -1, num_heads, dim_per_head)
                      .transpose(1, 2)
                      .reshape(batch * num_heads, -1, dim_per_head))

        if attn_mask is not None:
            # repeat the mask for each head: (batch, len_q, len_k) -> (batch * num_heads, len_q, len_k)
            attn_mask = attn_mask.repeat_interleave(num_heads, dim=0)

        # scale by 1 / sqrt(d_k), where d_k is the per-head dimension
        scale = key.size(-1) ** -0.5

        context, attention = self.dot_product_attention(query, key, value, scale, attn_mask)

        # concatenate heads back: (batch * num_heads, seq_len, dim_per_head) -> (batch, seq_len, model_dim)
        context = (context.view(batch, num_heads, -1, dim_per_head)
                          .transpose(1, 2)
                          .reshape(batch, -1, dim_per_head * num_heads))

        # final linear projection, dropout, residual connection and layer norm
        output = self.linear_final(context)
        output = self.dropout(output)
        output = self.layer_norm(residual + output)

        return output, attention
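
And a quick end-to-end shape check (random inputs; the token ids and sizes are only for illustration):

mha = MultiHeadAttention(model_dim=512, num_heads=8, dropout=0.1)
src = torch.tensor([[4, 8, 15, 0, 0], [16, 23, 42, 7, 0]])    # (batch=2, len=5), 0 = padding
x = nn.Embedding(100, 512)(src)                                # (2, 5, 512)
mask = pad_mask(src, src)                                      # (2, 5, 5)
out, attn_weights = mha(x, x, x, attn_mask=mask)               # self-attention
print(out.shape, attn_weights.shape)                           # (2, 5, 512) and (16, 5, 5)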

 
