Attention Mechanism: Code Walkthrough Matched to the Diagrams

Topic

'''
Description: attention mechanism
Author: 365JHWZGo
Date: 2021-12-14 17:06:11
LastEditors: 365JHWZGo
LastEditTime: 2021-12-14 22:23:54
'''

The attention mechanism in three steps, with a step-by-step code walkthrough

Imports

import torch 
import torch.nn as nn
import torch.nn.functional as F

Attn

class Attn(nn.Module):
    def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
        super().__init__()
        self.query_size = query_size
        self.key_size = key_size
        self.value_size1 = value_size1
        self.value_size2 = value_size2
        self.output_size = output_size

        # Maps [Query | Key] to one score per value position
        self.attn = nn.Linear(self.query_size + self.key_size, value_size1)
        # Maps [Query | attn_applied] to the final output
        self.attn_combine = nn.Linear(self.query_size + value_size2, output_size)

    def forward(self, q, k, v):
        # Shape comments assume query/key size 32 and v of shape (1,32,64)
        # attn_weights=(1,32): softmax over the scores of the 32 value positions
        attn_weights = F.softmax(self.attn(torch.cat((q[0], k[0]), 1)), dim=1)
        # attn_weights.unsqueeze(0)=(1,1,32)
        # v=(1,32,64)
        # attn_applied=(1,1,64): weighted sum of the value vectors
        attn_applied = torch.bmm(attn_weights.unsqueeze(0), v)
        # q[0]=(1,32)
        # attn_applied[0]=(1,64)
        # output=(1,96)
        output = torch.cat((q[0], attn_applied[0]), 1)
        # output=(1,1,64)
        output = self.attn_combine(output).unsqueeze(0)
        return output, attn_weights

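The attn layer takes [Query | Key], that is, the query and key concatenated along the column (feature) dimension, and maps it to one score per value position.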
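
To make the [Query | Key] step concrete, here is a minimal dependency-free sketch of what a linear layer does to the concatenated vector. The helper `linear`, the toy sizes, and the weight values are all made up for illustration; in the real module the weights are learned by `nn.Linear`.

```python
def linear(x, weight, bias):
    """Compute y = W x + b for one input vector x (one row of W per output)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weight, bias)]

query = [1.0, 2.0]    # toy query, query_size = 2
key = [0.5, -1.0]     # toy key, key_size = 2
qk = query + key      # [Query | Key]: concatenation, length 4

# Hypothetical weights: 4 inputs -> 3 scores (one per value position)
weight = [
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 1.0],
    [0.5, 0.5, 0.5, 0.5],
]
bias = [0.0, 0.0, 1.0]

scores = linear(qk, weight, bias)
print(scores)
```

The number of rows in `weight` plays the role of value_size1: the layer has to emit exactly one score for each value vector that will later be weighted.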

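The attn_combine layer takes [Query | attn_applied] and maps it to the output. attn_applied is the attention value: the result of applying the Query's attention distribution over Source to the values.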

attn_weights corresponds to the weights a1, a2, a3, … in the diagram.
attn_applied computes the Attention Value. bmm is equivalent to a1*value1 + a2*value2 + … (a matrix multiplication).
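
The softmax and the weighted sum can be sketched without PyTorch. This is only an illustration under toy assumptions: `softmax`, `weighted_sum`, and the sample numbers are hypothetical, and the real code does the same arithmetic with `F.softmax` and `torch.bmm`.

```python
import math

def softmax(scores):
    """Turn raw scores into weights a1, a2, ... that sum to 1."""
    m = max(scores)                       # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def weighted_sum(weights, values):
    """Compute a1*value1 + a2*value2 + ... component by component."""
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]

scores = [1.0, 2.0, 3.0]                  # toy scores, one per value vector
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

weights = softmax(scores)
attention_value = weighted_sum(weights, values)
print(weights, attention_value)
```

A (1, n) row of weights multiplied by an (n, d) matrix of values gives a (1, d) row, which is why the walkthrough can express the whole sum as a single matrix product.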

The second W matrix is a learned parameter of shape d2 x d1, where d2 is the dimension of the hidden state s and d1 is the dimension of the hidden state hi.
key = h
query = s
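
A d2 x d1 matrix sitting between s and hi fits a bilinear score of the form score(s, hi) = sᵀ W hi. That form is an assumption here, since the diagram is not reproduced in the text, and it differs from the concatenation-based scoring used in the Attn class. A rough sketch with a hypothetical helper `bilinear_score` and made-up numbers:

```python
def bilinear_score(s, W, h):
    """Assumed form: score = s^T W h, with W of shape d2 x d1."""
    # W h: project h (length d1) into the space of s (length d2)
    Wh = [sum(w * hj for w, hj in zip(row, h)) for row in W]
    # s^T (W h): dot product with the query
    return sum(si * x for si, x in zip(s, Wh))

s = [1.0, 2.0]             # query, d2 = 2
h = [1.0, 0.0, -1.0]       # key, d1 = 3
W = [[1.0, 2.0, 3.0],      # d2 x d1 = 2 x 3, illustrative values only
     [0.0, 1.0, 0.0]]

score = bilinear_score(s, W, h)
print(score)
```

Under this assumption W is what lets s and hi have different sizes: it has one row per component of s and one column per component of hi.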

if __name__ == "__main__":
    query_size = 32
    key_size = 32
    # second dimension of value (number of value vectors)
    value_size1 = 32
    # third dimension of value (size of each value vector)
    value_size2 = 64
    # output dimension
    output_size = 64
    attn = Attn(query_size, key_size, value_size1, value_size2, output_size)
    Q = torch.randn(1, 1, 32)
    K = torch.randn(1, 1, 32)
    V = torch.randn(1, 32, 64)
    out = attn(Q, K, V)
    print(out[0])   # output, shape (1,1,64)
    print(out[1])   # attention weights, shape (1,32)

Output

tensor([[[ 0.2658,  0.0392,  0.2432, -0.6333, -0.2197, -0.0189, -0.2440,
           0.2307,  0.3793,  0.1152,  0.3247, -0.0377,  0.5529, -0.2616,
          -0.1077, -0.2078, -0.2510, -0.4814, -0.2096, -0.1568, -0.0288,
           0.0595, -0.2944,  0.1996, -0.2253, -0.1753,  0.3036,  0.4191,
           0.0869, -0.4587,  0.0630, -0.0472,  0.1013,  0.2068,  0.0144,
          -0.5463, -0.0487,  0.2278, -0.2225, -0.2994, -0.2592, -0.0371,
           0.0615,  0.3353, -0.2891, -0.1839,  0.3867,  0.2469,  0.1036,
           0.2699,  0.1983,  0.0683, -0.3410, -0.1992,  0.5660,  0.0794,
          -0.2826,  0.0421,  0.0635,  0.1220,  0.1333, -0.2451, -0.4481,
          -0.1631]]], grad_fn=<UnsqueezeBackward0>)
tensor([[0.0151, 0.0451, 0.0093, 0.0251, 0.0379, 0.0177, 0.0277, 0.0301, 0.0200,
         0.0415, 0.0309, 0.0440, 0.0248, 0.0419, 0.0191, 0.0287, 0.0564, 0.0132,
         0.0442, 0.0473, 0.0359, 0.0154, 0.0195, 0.0652, 0.0255, 0.0178, 0.0287,
         0.0291, 0.0411, 0.0548, 0.0190, 0.0280]], grad_fn=<SoftmaxBackward0>)
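
As a quick sanity check on the second tensor, the attention weights come out of a softmax, so they should sum to 1 up to the four-decimal rounding of the printout. The list below is copied from the printed result above; the variable name `printed_weights` is just for this sketch.

```python
printed_weights = [
    0.0151, 0.0451, 0.0093, 0.0251, 0.0379, 0.0177, 0.0277, 0.0301, 0.0200,
    0.0415, 0.0309, 0.0440, 0.0248, 0.0419, 0.0191, 0.0287, 0.0564, 0.0132,
    0.0442, 0.0473, 0.0359, 0.0154, 0.0195, 0.0652, 0.0255, 0.0178, 0.0287,
    0.0291, 0.0411, 0.0548, 0.0190, 0.0280,
]

total = sum(printed_weights)
print(len(printed_weights), round(total, 4))
```

There are 32 weights, one for each of the 32 value vectors in V.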