Torch.spmm只支持 sparse 在前,dense 在后的矩阵乘法,两个sparse相乘或者dense在前的乘法不支持,当然两个dense矩阵相乘是支持的。
import torch
if __name__ == '__main__':
indices = torch.tensor([[0,1],
[0,1]])
values = torch.tensor([2,3])
shape = torch.Size((2,2))
s = torch.sparse.FloatTensor(indices,values,shape)
print(s)
d = torch.tensor([[1,2],
[3,4]])
e = torch.tensor([[1, 2],
[3, 4]])
# print(d)
#
print(torch.spmm(s,d))
错误的情况会报错:
RuntimeError: sparse tensors do not have strides