用torch进行矩阵运算
下面我主要区别以下几个函数:
其中,torch.mm()、torch.matmul()、torch.mul() 是我们比较常用的,但是用torch.einsum() 可以实现上述三个函数的任何一个的功能,下面我一一介绍这些函数的用法和适用情况。
这个函数就是实现两个张量之间元素对元素的运算,也就是对应元素相乘,这个函数可以使用broadcast广播操作,因此a与b的维度可以不一致。
实现矩阵乘法运算,必须满足矩阵乘法运算中对a,b的维度要求,即没有broadcast广播操作。
实现矩阵乘法运算,有广播操作,所以不需要满足矩阵乘法操作中维度的严格要求。
# 用torch.esinum()实现torch.mul(a,b)
b1 = 3 * torch.ones(5, 4, 2)
b = 2 * torch.ones(5, 4, 2)
c = torch.einsum('kji,kji->kji', b, b1)
print((c == (torch.mul(b, b1))))
print(c, '---------c------------')
# 用torch.esinum()实现矩阵乘法torch.matmul()
b1 = 3 * torch.ones(3, 4)
b = 2 * torch.ones(5, 4, 2)
c = torch.einsum('kji,mj->kmi', b, b1)
# 用torch.esinum()实现矩阵乘法torch.mm()
b1 = 3 * torch.ones(3, 4)
b = 2 * torch.ones(4, 2)
c = torch.einsum('ji,mj->mi', b, b1)
具体怎么用的感觉不知道应该怎么具体描述,多用几次torch.esinum()就会了,这个函数的使用范围很广,除了替代矩阵之间的运算还有其他的一些用法。