小黑算法成长日记23:selfAttention与multiHeadAttention

SelfAttention操作

从单个字的角度:

$q_i = h_i W_Q,\quad k_j = h_j W_K,\quad v_j = h_j W_V$

$e_{ij} = q_i k_j^T$

$\alpha_i = \mathrm{Softmax}([e_{i,1}, \dots, e_{i,T}])$

$h'_i = \left(\sum_{j=1}^{T} \alpha_{i,j} v_j\right) W_O$


矩阵的形式:

$Q = H W_Q,\quad K = H W_K,\quad V = H W_V$

$E = Q K^T$

$E' = \mathrm{Softmax}(E)$

$H' = E' V$

单头selfAttention

import math
import torch
import torch.nn as nn
class SelfAttention(nn.Module):
    """Single-head self-attention over a batch of token sequences.

    Each token is projected to a query, a key and a value of width
    ``d_head``. Attention weights are the softmax of the raw (unscaled)
    query-key dot products, and the weighted sum of values is projected
    back to ``d_model`` by the output layer.

    Args:
        d_model: Feature size of the input and of the returned tensor.
        d_head: Feature size of the query/key/value projections.
    """

    def __init__(self, d_model, d_head):
        super().__init__()
        # Projections are created in q, k, v, o order so that seeded
        # parameter initialisation stays reproducible.
        self.w_q = nn.Linear(d_model, d_head)
        self.w_k = nn.Linear(d_model, d_head)
        self.w_v = nn.Linear(d_model, d_head)
        self.w_o = nn.Linear(d_head, d_model)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: Tensor of shape [batch_size, max_len, d_model].

        Returns:
            Tensor of shape [batch_size, max_len, d_model].
        """
        # Each projection has shape [batch_size, max_len, d_head].
        queries, keys, values = self.w_q(x), self.w_k(x), self.w_v(x)

        # Swap the last two axes of the keys (a transpose, not a reshape)
        # to obtain pairwise scores of shape [batch_size, max_len, max_len].
        scores = torch.bmm(queries, keys.transpose(1, 2))
        weights = scores.softmax(dim=-1)

        # Weighted sum of the values: [batch_size, max_len, d_head].
        context = torch.bmm(weights, values)
        return self.w_o(context)
# Quick shape check: a batch of 3 sequences, 9 tokens each, 100 features.
x = torch.randn(3, 9, 100)
model = SelfAttention(d_model=100, d_head=80)
# Bare expression: the resulting shape is only displayed in a notebook/REPL.
model(x).shape

多头selfAttention

# Multi-head self-attention
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention with an optional key-padding mask.

    The input is projected to queries, keys and values of width
    ``d_model``, split into ``d_model // d_head`` heads of width
    ``d_head``, attended per head using raw (unscaled) dot-product scores,
    then merged and passed through the output projection.

    Args:
        d_model: Feature size of the input and output; must be divisible
            by ``d_head``.
        d_head: Feature size of each attention head.
    """

    def __init__(self, d_model=768, d_head=64):
        super(MultiHeadSelfAttention, self).__init__()
        assert d_model % d_head == 0, "d_model must be divisible by d_head"
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        self.n_heads = int(d_model // d_head)

        self.d_model = d_model
        self.d_head = d_head

    def _split_heads(self, t, batch_size, max_len):
        """Reshape [batch, max_len, d_model] to [batch, n_heads, max_len, d_head]."""
        return t.view(batch_size, max_len, self.n_heads, self.d_head).permute(0, 2, 1, 3)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention.

        Args:
            x: Tensor of shape [batch_size, max_len, d_model].
            mask: Optional tensor of shape [batch_size, max_len]; non-zero
                (or True) marks a real token and zero (or False) marks a
                padding position that no query may attend to.

        Returns:
            Tensor of shape [batch_size, max_len, d_model].
        """
        batch_size = x.shape[0]
        max_len = x.shape[1]

        q = self._split_heads(self.w_q(x), batch_size, max_len)
        k = self._split_heads(self.w_k(x), batch_size, max_len)
        v = self._split_heads(self.w_v(x), batch_size, max_len)

        # [batch_size, n_heads, max_len, max_len]; last axis indexes the keys.
        attn_score = torch.matmul(q, k.permute(0, 1, 3, 2))

        if mask is not None:
            # Broadcast over heads and query positions so that the *key*
            # axis is masked: [batch_size, 1, 1, max_len].
            key_mask = mask.unsqueeze(1).unsqueeze(1)
            # Use the most negative finite value so masked keys get zero
            # weight after softmax. A finite fill (rather than -inf) keeps
            # a fully padded sequence from producing NaN.
            fill_value = torch.finfo(attn_score.dtype).min
            attn_score = attn_score.masked_fill(key_mask == 0, fill_value)

        attn_score = torch.softmax(attn_score, -1)    # [batch_size, n_heads, max_len, max_len]

        # Merge the heads back: [batch_size, max_len, n_heads * d_head].
        out = torch.matmul(attn_score, v).permute(0, 2, 1, 3)
        out = out.contiguous().view(batch_size, max_len, -1)
        return self.w_o(out)
if __name__ == "__main__":
    # Demo: two sequences of 9 tokens with 768 features each.
    x = torch.randn(2, 9, 768)

    # Padding mask: True marks a real token, False marks padding.
    # The first sequence has 3 real tokens, the second has 4.
    mask = torch.tensor(
        [
            [1] * 3 + [0] * 6,
            [1] * 4 + [0] * 5,
        ]
    ).bool()

    model = MultiHeadSelfAttention()
    out = model(x, mask)
    print(out.shape)
上一篇:第一安装oracle数据库后,需要创建一个用户,给用户解锁并赋予权限


下一篇:GoF 的 23 种设计模式的分类和功能