文章目录
题目
'''
Description: attention注意力机制
Autor: 365JHWZGo
Date: 2021-12-14 17:06:11
LastEditors: 365JHWZGo
LastEditTime: 2021-12-14 22:23:54
'''
注意力机制三步式+分步代码讲解
导入库
import torch
import torch.nn as nn
import torch.nn.functional as F
Attn类
class Attn(nn.Module):
    """Attention layer: score [Query | Key] with a linear layer, then fuse.

    The attention weights are computed only from the concatenation of
    ``q[0]`` and ``k[0]`` (a linear layer followed by softmax). ``v`` is then
    reduced by a batched matrix product with those weights, and the result is
    concatenated with the query and projected by a second linear layer.

    Args:
        query_size: Size of the last dimension of ``q``.
        key_size: Size of the last dimension of ``k``.
        value_size1: Second dimension of ``v``; also the number of attention
            weights produced per query.
        value_size2: Third (feature) dimension of ``v``.
        output_size: Size of the last dimension of the returned output.
    """

    def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
        super(Attn, self).__init__()
        self.query_size = query_size
        self.key_size = key_size
        self.value_size1 = value_size1
        self.value_size2 = value_size2
        self.output_size = output_size
        # Maps [Query | Key] to one score per value position.
        self.attn = nn.Linear(self.query_size + self.key_size, value_size1)
        # Maps [Query | attention value] to the output size.
        self.attn_combine = nn.Linear(self.query_size + value_size2, output_size)

    def forward(self, q, k, v):
        """Attend over ``v`` and combine the attended value with the query.

        Args:
            q: Tensor whose ``q[0]`` is (batch, query_size); the demo passes
                shape (1, 1, query_size).
            k: Tensor whose ``k[0]`` is (batch, key_size), with the same batch
                size as ``q[0]``.
            v: Tensor of shape (1, value_size1, value_size2); ``torch.bmm``
                requires its leading dimension to be 1 here.

        Returns:
            Tuple ``(output, attn_weights)``: ``output`` has shape
            (1, batch, output_size); ``attn_weights`` has shape
            (batch, value_size1) and sums to 1 along dim 1.
        """
        # torch.cat is used rather than its PyTorch>=1.10 alias torch.concat
        # so the layer also runs on older PyTorch releases.
        # Scores for each value position, softmax-normalised:
        # (batch, value_size1), e.g. (1, 32) in the demo.
        attn_weights = F.softmax(self.attn(torch.cat((q[0], k[0]), 1)), dim=1)
        # Weighted sum of value rows:
        # (1, batch, value_size1) @ (1, value_size1, value_size2)
        # -> (1, batch, value_size2), e.g. (1, 1, 64).
        attn_applied = torch.bmm(attn_weights.unsqueeze(0), v)
        # Fuse query and attention value: (batch, query_size + value_size2),
        # e.g. (1, 96).
        output = torch.cat((q[0], attn_applied[0]), 1)
        # Project and restore the leading dim: (1, batch, output_size).
        output = self.attn_combine(output).unsqueeze(0)
        return output, attn_weights
attn 是一个线性层:先将 Query 与 Key 按列(dim=1)拼接成【Query|Key】,再经线性变换得到打分,最后经 softmax 归一化为注意力权重。
attn_combine 也是一个线性层,作用于按列拼接的【Query|attn_applied】并输出最终结果;其中 attn_applied 是 Query 在 Source 上按注意力权重对 Value 加权求和的结果(注意力分布本身是 attn_weights)。
attn_weights 的结果即注意力权重 $a_1, a_2, a_3, \dots$。
attn_applied 即 Attention Value:bmm(批量矩阵乘法)计算的是 $a_1 \cdot value_1 + a_2 \cdot value_2 + \dots$,即按权重对各个 Value 加权求和。
第二个 $W$ 矩阵是训练得到的参数,维度为 $d_2 \times d_1$,其中 $d_2$ 是 $s$ 的隐藏状态(hidden state)输出维数,$d_1$ 是 $h_i$ 的隐藏状态维数。
key=h
query=s
if __name__ == "__main__":
    # Demo: a single query/key pair attending over a (1, 32, 64) value tensor.
    query_size = 32
    key_size = 32
    # Second dimension of the value tensor (number of attention weights).
    value_size1 = 32
    # Third dimension of the value tensor (feature size of each value row).
    value_size2 = 64
    # Output dimension.
    output_size = 64
    attn = Attn(query_size, key_size, value_size1, value_size2, output_size)
    # Build the inputs from the size variables above (instead of repeating
    # literals) so changing a size keeps the inputs consistent with the layer.
    Q = torch.randn(1, 1, query_size)
    K = torch.randn(1, 1, key_size)
    V = torch.randn(1, value_size1, value_size2)
    out = attn(Q, K, V)
    print(out[0])  # attended output, shape (1, 1, output_size)
    print(out[1])  # attention weights, shape (1, value_size1)
运行结果
tensor([[[ 0.2658, 0.0392, 0.2432, -0.6333, -0.2197, -0.0189, -0.2440,
0.2307, 0.3793, 0.1152, 0.3247, -0.0377, 0.5529, -0.2616,
-0.1077, -0.2078, -0.2510, -0.4814, -0.2096, -0.1568, -0.0288,
0.0595, -0.2944, 0.1996, -0.2253, -0.1753, 0.3036, 0.4191,
0.0869, -0.4587, 0.0630, -0.0472, 0.1013, 0.2068, 0.0144,
-0.5463, -0.0487, 0.2278, -0.2225, -0.2994, -0.2592, -0.0371,
0.0615, 0.3353, -0.2891, -0.1839, 0.3867, 0.2469, 0.1036,
0.2699, 0.1983, 0.0683, -0.3410, -0.1992, 0.5660, 0.0794,
-0.2826, 0.0421, 0.0635, 0.1220, 0.1333, -0.2451, -0.4481,
-0.1631]]], grad_fn=<UnsqueezeBackward0>)
tensor([[0.0151, 0.0451, 0.0093, 0.0251, 0.0379, 0.0177, 0.0277, 0.0301, 0.0200,
0.0415, 0.0309, 0.0440, 0.0248, 0.0419, 0.0191, 0.0287, 0.0564, 0.0132,
0.0442, 0.0473, 0.0359, 0.0154, 0.0195, 0.0652, 0.0255, 0.0178, 0.0287,
0.0291, 0.0411, 0.0548, 0.0190, 0.0280]], grad_fn=<SoftmaxBackward0>)