1. Transformer model
Transformer is a classic NLP model proposed by a team at Google in 2017; the currently popular BERT is also built on top of it. The Transformer relies on the Self-Attention mechanism instead of an RNN's sequential structure, so training can be parallelized and every position can attend to global information about the sequence.
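To make the Self-Attention idea concrete, here is a minimal sketch of scaled dot-product attention (illustrative only, not taken from the model code below): every position attends to every other position through a single matrix multiplication, which is why the whole sequence can be processed in parallel instead of step by step as in an RNN.

import math
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: [batch, seq_len, d_model]; all positions are processed at once,
    # so there is no sequential dependency as in an RNN
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)  # block masked positions
    weights = torch.softmax(scores, dim=-1)   # each row sums to 1 over all positions
    return torch.matmul(weights, v)           # weighted sum: global context per position

x = torch.randn(2, 5, 64)                     # a toy batch of two 5-token sequences
out = scaled_dot_product_attention(x, x, x)   # self-attention: q = k = v = x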
2. Encoder implementation (PyTorch)
import torch
import torch.nn as nn

class EncoderLayer(nn.Module):
    def __init__(self, hidden_size, filter_size, n_head, pre_lnorm, device, dropout):
        super(EncoderLayer, self).__init__()
        # self-attention part
        self.self_attn = MultiHeadAttention(hidden_size, n_head, device)
        self.self_attn_norm = nn.LayerNorm(hidden_size)
        # feed-forward network part
        self.pff = PositionwiseFeedForward(hidden_size, filter_size, dropout)
        self.pff_norm = nn.LayerNorm(hidden_size)
        self.pre_lnorm = pre_lnorm

    def forward(self, src, src_mask):
        if self.pre_lnorm:
            # pre-LN: LayerNorm before each sub-layer, residual added afterwards
            pre = self.self_attn_norm(src)
            src = src + self.self_attn(pre, pre, pre, src_mask)  # residual connection
            pre = self.pff_norm(src)
            src = src + self.pff(pre)  # residual connection
        else:
            # post-LN (original Transformer): residual connection + LayerNorm
            src = self.self_attn_norm(src + self.self_attn(src, src, src, src_mask))
            src = self.pff_norm(src + self.pff(src))
        return src


class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, filter_size, n_head, dropout, n_layers, pre_lnorm, device):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.embed_scale = hidden_size ** 0.5
        self.wte = nn.Embedding(input_size, hidden_size)  # token embeddings
        # positional embeddings: a fixed (frozen) sinusoidal table;
        # ZH is the source-language torchtext field defined elsewhere in the post
        max_len = 1000
        self.wpe = nn.Embedding.from_pretrained(
            positional_encoding_table(max_len + 1, hidden_size,
                                      padding_idx=ZH.vocab.stoi['<pad>']),
            freeze=True)
        self.embed_dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList([
            EncoderLayer(hidden_size, filter_size, n_head, pre_lnorm, device, dropout)
            for _ in range(n_layers)])
        self.pre_lnorm = pre_lnorm
        self.last_norm = nn.LayerNorm(hidden_size)
        self.device = device

    def forward(self, src, src_mask):
        # token embedding (scaled) + positional encoding
        pos = torch.arange(0, src.shape[1]).unsqueeze(0).repeat(src.shape[0], 1).to(self.device)
        src = self.wte(src) * self.embed_scale + self.wpe(pos)  # [B, T, H]
        src = self.embed_dropout(src)
        for layer in self.layers:
            src = layer(src, src_mask)
        if self.pre_lnorm:
            # pre-LN variants need a final LayerNorm after the last layer
            src = self.last_norm(src)
        return src
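The encoder above (and the decoder below) calls MultiHeadAttention, PositionwiseFeedForward and positional_encoding_table, and reads the pad index from the torchtext fields ZH/ENG; these are defined elsewhere in the full post and are not shown here. The sketch below is an assumed, compatible implementation of the three helpers, using standard multi-head scaled dot-product attention and the fixed sinusoidal encodings from the Transformer paper; the signatures match the calls above, but the bodies are assumptions rather than the original code.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    # Multi-head scaled dot-product attention; `device` is kept only to match
    # the constructor signature used above.
    def __init__(self, hidden_size, n_head, device, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert hidden_size % n_head == 0, "hidden_size must be divisible by n_head"
        self.n_head = n_head
        self.d_k = hidden_size // n_head
        self.w_q = nn.Linear(hidden_size, hidden_size)
        self.w_k = nn.Linear(hidden_size, hidden_size)
        self.w_v = nn.Linear(hidden_size, hidden_size)
        self.w_o = nn.Linear(hidden_size, hidden_size)
        self.attn_dropout = nn.Dropout(dropout)
        self.device = device

    def forward(self, q, k, v, mask=None):
        B = q.shape[0]
        # project and split into heads: [B, n_head, T, d_k]
        q = self.w_q(q).view(B, -1, self.n_head, self.d_k).transpose(1, 2)
        k = self.w_k(k).view(B, -1, self.n_head, self.d_k).transpose(1, 2)
        v = self.w_v(v).view(B, -1, self.n_head, self.d_k).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # mask is assumed broadcastable to [B, n_head, T_q, T_k], 0 = masked out
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = self.attn_dropout(torch.softmax(scores, dim=-1))
        out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(B, -1, self.n_head * self.d_k)
        return self.w_o(out)  # [B, T_q, H]

class PositionwiseFeedForward(nn.Module):
    # Two-layer feed-forward network applied independently to every position.
    def __init__(self, hidden_size, filter_size, dropout):
        super(PositionwiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(hidden_size, filter_size)
        self.fc2 = nn.Linear(filter_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.fc2(self.dropout(F.relu(self.fc1(x))))

def positional_encoding_table(n_position, hidden_size, padding_idx=None):
    # Fixed sinusoidal table of shape [n_position, hidden_size] (even hidden_size assumed).
    position = torch.arange(0, n_position, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, hidden_size, 2, dtype=torch.float32)
                         * (-math.log(10000.0) / hidden_size))
    table = torch.zeros(n_position, hidden_size)
    table[:, 0::2] = torch.sin(position * div_term)
    table[:, 1::2] = torch.cos(position * div_term)
    if padding_idx is not None:
        table[padding_idx] = 0.0  # zero vector for the padding position
    return table

In a single script these definitions only need to appear before Encoder/Decoder are instantiated, since the names are resolved at call time.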
3. Decoder implementation
class DecoderLayer(nn.Module):
    def __init__(self, hidden_size, filter_size, n_head, pre_lnorm, device, dropout):
        super(DecoderLayer, self).__init__()
        # masked self-attention part
        self.self_attn = MultiHeadAttention(hidden_size, n_head, device)
        self.self_attn_norm = nn.LayerNorm(hidden_size)
        # encoder-decoder (cross) attention part
        self.ed_self_attn = MultiHeadAttention(hidden_size, n_head, device)
        self.ed_self_attn_norm = nn.LayerNorm(hidden_size)
        # feed-forward network part
        self.pff = PositionwiseFeedForward(hidden_size, filter_size, dropout)
        self.pff_norm = nn.LayerNorm(hidden_size)
        self.pre_lnorm = pre_lnorm

    def forward(self, enc_out, enc_out_mask, trg, trg_mask):
        if self.pre_lnorm:
            # pre-LN: LayerNorm before each sub-layer, residual added afterwards
            ris = self.self_attn_norm(trg)
            trg = trg + self.self_attn(ris, ris, ris, trg_mask)
            ris = self.ed_self_attn_norm(trg)
            # queries come from the decoder, keys/values from the encoder output
            trg = trg + self.ed_self_attn(ris, enc_out, enc_out, enc_out_mask)
            ris = self.pff_norm(trg)
            trg = trg + self.pff(ris)
        else:
            # post-LN (original Transformer): residual connection + LayerNorm
            trg = self.self_attn_norm(trg + self.self_attn(trg, trg, trg, trg_mask))
            trg = self.ed_self_attn_norm(trg + self.ed_self_attn(trg, enc_out, enc_out, enc_out_mask))
            trg = self.pff_norm(trg + self.pff(trg))
        return trg


class Decoder(nn.Module):
    def __init__(self, input_size, hidden_size, filter_size, n_head, dropout, n_layers, pre_lnorm, device):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.embed_scale = hidden_size ** 0.5
        self.wte = nn.Embedding(input_size, hidden_size)  # token embeddings
        # positional embeddings: a fixed (frozen) sinusoidal table;
        # ENG is the target-language torchtext field defined elsewhere in the post
        max_len = 1000
        self.wpe = nn.Embedding.from_pretrained(
            positional_encoding_table(max_len + 1, hidden_size,
                                      padding_idx=ENG.vocab.stoi['<pad>']),
            freeze=True)
        self.embed_dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList([
            DecoderLayer(hidden_size, filter_size, n_head, pre_lnorm, device, dropout)
            for _ in range(n_layers)])
        self.pre_lnorm = pre_lnorm
        self.last_norm = nn.LayerNorm(hidden_size)
        self.device = device

    def forward(self, enc_out, enc_out_mask, trg, trg_mask):
        # token embedding (scaled) + positional encoding
        pos = torch.arange(0, trg.shape[1]).unsqueeze(0).repeat(trg.shape[0], 1).to(self.device)
        trg = self.wte(trg) * self.embed_scale + self.wpe(pos)  # [B, T, H]
        trg = self.embed_dropout(trg)
        for layer in self.layers:
            trg = layer(enc_out, enc_out_mask, trg, trg_mask)
        if self.pre_lnorm:
            # pre-LN variants need a final LayerNorm after the last layer
            trg = self.last_norm(trg)
        return trg
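Finally, a hedged end-to-end sketch of how the two modules could be wired together. The hyper-parameters, the make_src_mask/make_trg_mask helpers and the ZH/ENG stand-ins below are illustrative assumptions chosen to match the attention sketch above (masks broadcastable to [B, 1, T_q, T_k], with 0 meaning "masked out"); they are not values or helpers from the original post.

from types import SimpleNamespace
import torch

# Stand-ins for the torchtext fields built earlier in the post (hypothetical:
# only the '<pad>' index is needed by Encoder/Decoder at construction time).
ZH = SimpleNamespace(vocab=SimpleNamespace(stoi={'<pad>': 1}))
ENG = SimpleNamespace(vocab=SimpleNamespace(stoi={'<pad>': 1}))

def make_src_mask(src, pad_idx):
    # [B, 1, 1, T_src]: True for real tokens, False for padding
    return (src != pad_idx).unsqueeze(1).unsqueeze(2)

def make_trg_mask(trg, pad_idx):
    # padding mask combined with a lower-triangular look-ahead mask -> [B, 1, T_trg, T_trg]
    pad_mask = (trg != pad_idx).unsqueeze(1).unsqueeze(2)
    T = trg.shape[1]
    look_ahead = torch.tril(torch.ones(T, T, dtype=torch.bool, device=trg.device))
    return pad_mask & look_ahead

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = Encoder(input_size=8000, hidden_size=256, filter_size=1024, n_head=8,
                  dropout=0.1, n_layers=6, pre_lnorm=True, device=device).to(device)
decoder = Decoder(input_size=6000, hidden_size=256, filter_size=1024, n_head=8,
                  dropout=0.1, n_layers=6, pre_lnorm=True, device=device).to(device)

src = torch.randint(2, 8000, (4, 20), device=device)   # [B, T_src] source token ids
trg = torch.randint(2, 6000, (4, 18), device=device)   # [B, T_trg] target token ids
src_mask = make_src_mask(src, ZH.vocab.stoi['<pad>'])
trg_mask = make_trg_mask(trg, ENG.vocab.stoi['<pad>'])

enc_out = encoder(src, src_mask)                        # [B, T_src, H]
dec_out = decoder(enc_out, src_mask, trg, trg_mask)     # [B, T_trg, H]

For an actual translation model, dec_out would still be projected to the target vocabulary (for example with a Linear layer tied to the decoder embeddings) before computing the cross-entropy loss; that part is outside the scope of this section.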