以下代码经常在Transformer的算法中见到:
q, k, v = qkv[0], qkv[1], qkv[2] # query, key, value tensor
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
其中涉及到a @ b操作和transpose操作
a = torch.Tensor([[1,2],[3,4]])
print(a)
b = torch.Tensor([[0.5,2],[0.5,0.5]])
print(b)
print(a@b)
输出:
tensor([[1., 2.],
[3., 4.]])
tensor([[0.5000, 2.0000],
[0.5000, 0.5000]])
tensor([[1.5000, 3.0000],
[3.5000, 8.0000]])
import torch
x=torch.randn(12,3,10,20)
y=torch.randn(20,30)
z=x@y
print(z.shape)
输出结果:
torch.Size([12, 3, 10, 30])
从以上结果可以发现,默认以最后两维进行矩阵乘法运算
transpose(-2, -1) 表示将 k 的最后两维进行转置(交换位置)
import torch
q = torch.randn(125,2,343,16)
k = torch.randn(125,2,343,16)
attn = q @ k.transpose(-2, -1)
print(attn.shape)
输出:
torch.Size([125, 2, 343, 343])