填充掩码就是用来指示哪些数据是真实的,哪些是填充的。在模型处理这些数据时,掩码会用来避免在计算损失或者梯度时考虑填充的部分,确保模型的学习只关注于有效的数据。在使用诸如Transformer这样的模型时,填充掩码特别重要,因为它们可以帮助模型在进行自注意力计算时忽略掉填充的位置。
importtorch
defcreate_padding_mask(seq, pad_token=0):
mask= (seq==pad_token).unsqueeze(1).unsqueeze(2)
returnmask # (batch_size, 1, 1, seq_len)
# Example usage
seq=torch.tensor([[7, 6, 0, 0], [1, 2, 3, 0]])
padding_mask=create_padding_mask(seq)
print(padding_mask)