pytorch F.cross_entropy(x,y)理解

F.cross_entropy(x,y)

 1 x = np.array([[1, 2,3,4,5],
 2              [1, 2,3,4,5],
 3              [1, 2,3,4,5]]).astype(np.float32)
 4 y = np.array([1, 1, 0])
 5 x = torch.from_numpy(x)
 6 y = torch.from_numpy(y).long()
 7 
 8 soft_out = F.softmax(x,dim=1)
 9 log_soft_out = torch.log(soft_out)
10 loss = F.nll_loss(log_soft_out, y)
11 print(soft_out)
12 print(log_soft_out)
13 print(loss)
14   
15 loss = F.cross_entropy(x, y)
16 print(loss)

结果:

softmax:

tensor([[0.0117, 0.0317, 0.0861, 0.2341, 0.6364],
[0.0117, 0.0317, 0.0861, 0.2341, 0.6364],
[0.0117, 0.0317, 0.0861, 0.2341, 0.6364]])


tensor([[-4.4519, -3.4519, -2.4519, -1.4519, -0.4519],
[-4.4519, -3.4519, -2.4519, -1.4519, -0.4519],
[-4.4519, -3.4519, -2.4519, -1.4519, -0.4519]])


tensor(3.7852)
tensor(3.7852)

结果分析:

F.softmax(x,dim=1):一行和为1 sum([0.0117, 0.0317, 0.0861, 0.2341, 0.6364])=1
softmax函数公式
pytorch F.cross_entropy(x,y)理解

 

torch.log(soft_out):对softmax的结果进行取对数
a =pow(math.e,1)/(pow(math.e,1)+pow(math.e,2)+pow(math.e,3)+pow(math.e,4)+pow(math.e,5)) # 0.011656230956039609近似0.0117
print(math.log(0.011656230956039609)) # -4.4519
F.nll_loss(log_soft_out, y):对取对数的结果,根据y的值,(y值是索引),找到对应的值,黄色部分,各自取相反数再相加,求平均
(3.4519+3.4519+4.4519)/3 = 3.7852
所以:
cross_entropy函数:softmax->log->nll_loss
参考链接:
https://blog.csdn.net/qq_22210253/article/details/85229988
https://blog.csdn.net/wuliBob/article/details/104119616
https://blog.csdn.net/weixin_38314865/article/details/104487587?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.nonecase&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.nonecase


 

 

 

上一篇:Mysql常用sql语句(15)- cross join 交叉连接


下一篇:《xDeepFM:名副其实的 ”Deep” Factorization Machine》