nn.TransformerDecoderLayer

import torch
import torch.nn as nn

decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)  # d_model is the number of expected features in the input; nhead is the number of heads in the multi-head attention
memory = torch.ones(10, 32, 512)  # the sequence from the last layer of the encoder; by default the layout is (sequence_length, batch_size, hidden_size)
tgt = torch.zeros(20, 32, 512)  # the target sequence fed to the decoder layer; its batch size (32) must match memory's
out = decoder_layer(tgt, memory)
print(out.shape)  # torch.Size([20, 32, 512]), same shape as tgt
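In practice a decoder layer is usually called with a causal mask so that each target position can only attend to earlier positions. A minimal sketch, continuing from the variables above (building the additive mask by hand is one common way; `tgt_mask` is the documented keyword argument):

tgt_len = tgt.size(0)
# -inf above the diagonal blocks attention to future positions; 0 elsewhere leaves scores unchanged
tgt_mask = torch.triu(torch.full((tgt_len, tgt_len), float('-inf')), diagonal=1)
out = decoder_layer(tgt, memory, tgt_mask=tgt_mask)  # output shape is still (20, 32, 512)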

Details: TransformerDecoderLayer — PyTorch 1.10.0 documentation

For example, the following network uses RoBERTa as the encoder and a 6-layer Transformer as the decoder.

encoder = model_class.from_pretrained(args.model_name_or_path, config=config)  # use RobertaModel as the encoder; the loaded pretrained model is roberta
decoder_layer = nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads)  # d_model = 768, nhead = 12 (the number of heads in the multi-head attention)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
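The snippet above only builds the two modules. A minimal sketch of wiring them together might look like the following (the checkpoint name "roberta-base", the input text, and the placeholder target embeddings are assumptions for illustration; note that RobertaModel returns (batch, seq, hidden) while nn.TransformerDecoder defaults to (seq, batch, hidden), hence the transpose):

import torch
import torch.nn as nn
from transformers import RobertaModel, RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
encoder = RobertaModel.from_pretrained("roberta-base")
decoder_layer = nn.TransformerDecoderLayer(d_model=encoder.config.hidden_size,
                                           nhead=encoder.config.num_attention_heads)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)

inputs = tokenizer("hello world", return_tensors="pt")
memory = encoder(**inputs).last_hidden_state  # (batch, src_len, 768)
memory = memory.transpose(0, 1)               # (src_len, batch, 768) for the default batch_first=False
tgt = torch.zeros(20, memory.size(1), encoder.config.hidden_size)  # hypothetical target embeddings
out = decoder(tgt, memory)                    # (20, batch, 768)

In a real model the target embeddings would come from an embedding layer over the shifted target tokens, and a causal tgt_mask would be passed as in the earlier example.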

Details: TransformerDecoder — PyTorch 1.10.0 documentation
