import torch as t
import torch.nn as nn
input = t.randint(1,5,(2,3,5)) # (2,3,5)的各个数字的介绍:2为batch_id,3为max_length,5是每个单词的维度
print(input)
print(input.dtype) # 查看input 整个的类型
dim1 = nn.Softmax(2)
res = dim1(input)
print(res)
执行结果如下:
应该是 tensor的 数据类型不对(不应该使用int类型),导致无法使用softmax() 函数。
3.解决办法不用使用 t.randint()
来生成tensor,因为其得到的数据类型是int,应该转用t.randn()
这样得到的是一个float 类型的数据,就可以计算softmax了。