交叉损失熵函数的表达式为:
假设一个三分类问题,该问题的标签为:
Person | Dog | Cat |
0 | 1 | 2 |
将一张狗的图片输入神经网络,得到输出[0.1,0.6,0.3]
则有:
我们编程来验证一下:
import torch
from torch.nn import CrossEntropyLoss
x = torch.tensor([0.1,0.6,0.3])
y = torch.tensor([1])
x = torch.reshape(x,(1,3))
loss_cross = CrossEntropyLoss()
result_cross = loss_cross(x,y)
print(result_cross)
工作台输出为:
tensor(0.8533)
计算器输出为: