【Natural Language Processing 5】The Transformer Model

1. The Transformer model

Transformer is a classic NLP model proposed by a Google team in 2017; BERT, which is very popular today, is also built on the Transformer. Instead of the sequential structure of an RNN, the Transformer uses the Self-Attention mechanism, so training can be parallelized and every position has access to global information about the sequence.
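
The key building block is scaled dot-product attention: every position attends to every other position through a few batched matrix multiplications, so the whole sequence is processed at once rather than step by step. Below is a minimal sketch of that computation (not part of the original code); it assumes the mask convention "nonzero = valid position".

import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: [B, T, H] -- all time steps are handled by one batched matmul,
    # so there is no sequential dependency as in an RNN
    scores = torch.matmul(q, k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = torch.softmax(scores, dim=-1)  # attention weights over positions
    return torch.matmul(weights, v)          # weighted sum of the value vectors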


2. Encoder implementation (PyTorch)

import torch
import torch.nn as nn

# MultiHeadAttention, PositionwiseFeedForward and positional_encoding_table
# are not defined in this post; a rough sketch is given after the Encoder.


class EncoderLayer(nn.Module):
    def __init__(self, hidden_size, filter_size, n_head, pre_lnorm, device, dropout):
        super(EncoderLayer, self).__init__()
        # self-attention part
        self.self_attn = MultiHeadAttention(hidden_size, n_head, device)
        self.self_attn_norm = nn.LayerNorm(hidden_size)

        # feed forward network part
        self.pff = PositionwiseFeedForward(hidden_size, filter_size, dropout)
        self.pff_norm = nn.LayerNorm(hidden_size)

        self.pre_lnorm = pre_lnorm

    def forward(self, src, src_mask):
        if self.pre_lnorm:
            pre = self.self_attn_norm(src)
            # residual connection
            src = src + self.self_attn(pre, pre, pre, src_mask)

            pre = self.pff_norm(src)
            src = src + self.pff(pre)  # residual connection
        else:
            # residual connection + layerNorm
            src = self.self_attn_norm(
                src + self.self_attn(src, src, src, src_mask))
            # residual connection + layerNorm
            src = self.pff_norm(src + self.pff(src))

        return src


class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, filter_size, n_head, dropout, n_layers, pre_lnorm, device):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.embed_scale = hidden_size ** 0.5
        self.wte = nn.Embedding(input_size, hidden_size)  # token embeddings
        # self.wpe = PositionalEmbedding(hidden_size) # positional embeddings
        # self.wpe = nn.Embedding(1000, hidden_size)
        # self.wpe = PositionalEncoding(hidden_size)
        max_len = 1000
        self.wpe = nn.Embedding.from_pretrained(positional_encoding_table(
            max_len+1, hidden_size, padding_idx=ZH.vocab.stoi['<pad>']), freeze=True)
        self.embed_dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList([EncoderLayer(hidden_size, filter_size, n_head, pre_lnorm, device, dropout)
                                     for _ in range(n_layers)])
        self.pre_lnorm = pre_lnorm
        self.last_norm = nn.LayerNorm(hidden_size)
        self.device = device

    def forward(self, src, src_mask):
        # token embedding + positional encoding
        # pos = torch.arange(src.shape[1], dtype=torch.float32).to(self.device)
        pos = torch.arange(0, src.shape[1]).unsqueeze(
            0).repeat(src.shape[0], 1).to(self.device)
        src = self.wte(src) * self.embed_scale + self.wpe(pos)  # [B, T, H]
        src = self.embed_dropout(src)

        for layer in self.layers:
            src = layer(src, src_mask)

        if self.pre_lnorm:
            src = self.last_norm(src)

        return src
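
The Encoder above (and the Decoder below) relies on several pieces that are not defined in this post: positional_encoding_table, MultiHeadAttention, PositionwiseFeedForward, and the ZH / ENG vocabularies (presumably torchtext fields) that provide the pad index. The following is only a rough sketch of what such helpers could look like, assuming sinusoidal positional encodings and the standard multi-head attention layout; the original implementations may differ.

import math
import torch
import torch.nn as nn

def positional_encoding_table(max_len, hidden_size, padding_idx=None):
    # sinusoidal position encodings as in "Attention Is All You Need"
    pe = torch.zeros(max_len, hidden_size)
    pos = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
    div = torch.exp(torch.arange(0, hidden_size, 2, dtype=torch.float32)
                    * (-math.log(10000.0) / hidden_size))
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    if padding_idx is not None:
        pe[padding_idx] = 0.0  # no positional signal for the pad position
    return pe

class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, n_head, device, dropout=0.1):
        super().__init__()
        assert hidden_size % n_head == 0
        self.d_head = hidden_size // n_head
        self.n_head = n_head
        self.w_q = nn.Linear(hidden_size, hidden_size)
        self.w_k = nn.Linear(hidden_size, hidden_size)
        self.w_v = nn.Linear(hidden_size, hidden_size)
        self.w_o = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        # device is kept only to match the constructor calls above; unused here

    def forward(self, q, k, v, mask=None):
        B = q.size(0)
        # project and split into heads: [B, n_head, T, d_head]
        q = self.w_q(q).view(B, -1, self.n_head, self.d_head).transpose(1, 2)
        k = self.w_k(k).view(B, -1, self.n_head, self.d_head).transpose(1, 2)
        v = self.w_v(v).view(B, -1, self.n_head, self.d_head).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = self.dropout(torch.softmax(scores, dim=-1))
        out = torch.matmul(attn, v).transpose(1, 2).contiguous()
        return self.w_o(out.view(B, -1, self.n_head * self.d_head))

class PositionwiseFeedForward(nn.Module):
    def __init__(self, hidden_size, filter_size, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, filter_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(filter_size, hidden_size),
        )

    def forward(self, x):
        return self.net(x)

In this sketch the mask is assumed to broadcast against the [B, n_head, T_q, T_k] score tensor, e.g. a padding mask of shape [B, 1, 1, T_k].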

3. Decoder implementation (PyTorch)

class DecoderLayer(nn.Module):
    def __init__(self, hidden_size, filter_size, n_head, pre_lnorm, device, dropout):
        super(DecoderLayer, self).__init__()
        # self-attention part
        self.self_attn = MultiHeadAttention(hidden_size, n_head, device)
        self.self_attn_norm = nn.LayerNorm(hidden_size)

        # encoder-decoder (cross-) attention part
        self.ed_self_attn = MultiHeadAttention(hidden_size, n_head, device)
        self.ed_self_attn_norm = nn.LayerNorm(hidden_size)

        # feed forward network part
        self.pff = PositionwiseFeedForward(hidden_size, filter_size, dropout)
        self.pff_norm = nn.LayerNorm(hidden_size)

        self.pre_lnorm = pre_lnorm

    def forward(self, enc_out, enc_out_mask, trg, trg_mask):
        if self.pre_lnorm:
            ris = self.self_attn_norm(trg)
            trg = trg + self.self_attn(ris, ris, ris, trg_mask)

            ris = self.ed_self_attn_norm(trg)
            trg = trg + self.ed_self_attn(ris, enc_out, enc_out, enc_out_mask)

            ris = self.pff_norm(trg)
            trg = trg + self.pff(ris)
        else:
            trg = self.self_attn_norm(
                trg + self.self_attn(trg, trg, trg, trg_mask))

            trg = self.ed_self_attn_norm(
                trg + self.ed_self_attn(trg, enc_out, enc_out, enc_out_mask))
            trg = self.pff_norm(trg + self.pff(trg))

        return trg

class Decoder(nn.Module):
    def __init__(self, input_size, hidden_size, filter_size, n_head, dropout, n_layers, pre_lnorm, device):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.embed_scale = hidden_size ** 0.5
        self.wte = nn.Embedding(input_size, hidden_size)  # token embeddings
        # self.wpe = PositionalEmbedding(hidden_size) # positional embeddings
        # self.wpe = nn.Embedding(1000, hidden_size)
        # self.wpe = PositionalEncoding(hidden_size)
        max_len = 1000
        self.wpe = nn.Embedding.from_pretrained(positional_encoding_table(
            max_len+1, hidden_size, padding_idx=ENG.vocab.stoi['<pad>']), freeze=True)
        self.embed_dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList([DecoderLayer(hidden_size, filter_size, n_head, pre_lnorm, device, dropout)
                                     for _ in range(n_layers)])
        self.pre_lnorm = pre_lnorm
        self.last_norm = nn.LayerNorm(hidden_size)
        self.device = device

    def forward(self, enc_out, enc_out_mask, trg, trg_mask):
        # token embedding + positional encoding
        # pos = torch.arange(trg.shape[1], dtype=torch.float32).to(self.device)
        pos = torch.arange(0, trg.shape[1]).unsqueeze(
            0).repeat(trg.shape[0], 1).to(self.device)
        trg = self.wte(trg) * self.embed_scale + self.wpe(pos)  # [B, T, H]
        trg = self.embed_dropout(trg)

        #trg [B, T, H]
        for layer in self.layers:
            trg = layer(enc_out, enc_out_mask, trg, trg_mask)

        if self.pre_lnorm:
            trg = self.last_norm(trg)
        return trg
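
With both halves in place, a thin wrapper ties them together. The sketch below is not from the original post: the names Transformer, src_pad_idx and trg_pad_idx are illustrative, and it assumes the usual mask setup (a padding mask for the source, a padding plus look-ahead mask for the target) that matches the broadcastable masks used above.

class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_pad_idx, trg_pad_idx, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        # [B, 1, 1, T_src]: True for real tokens, False for padding
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def make_trg_mask(self, trg):
        # padding mask [B, 1, 1, T_trg] combined with a lower-triangular
        # look-ahead mask [T_trg, T_trg] so position i only attends to <= i
        pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        T = trg.shape[1]
        sub_mask = torch.tril(torch.ones(T, T, device=self.device)).bool()
        return pad_mask & sub_mask

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_out = self.encoder(src, src_mask)
        return self.decoder(enc_out, src_mask, trg, trg_mask)

In practice, a final nn.Linear(hidden_size, trg_vocab_size) would project the decoder output to vocabulary logits before computing the loss.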
        