y[np.arange(batch_size), t]的详细解析

   mini-batch版交叉熵误差

     《深度学习入门》一书中mini-batch版交叉熵误差的代码实现中,当监督数据是标签形式时的代码实现如下:


     def cross_entropy_error(y, t):

           if y.ndim == 1:

               t = t.reshape(1, t.size)

               y = y.reshape(1, y.size)

           batch_size = y.shape[0]

           return -np.sum(np.log(y[np.arange(batch_size), t] + 1e-7)) / batch_size

       相信有一部份人肯定对最后一行代码(-np.sum(np.log(y[np.arange(batch_size), t] + 1e-7)) / batch_size)有所疑惑,下面我来以我的想法详细讲解一下:

       首先,在代码 np.log(y[np.arange(batch_size), t] + 1e-7) 中,np.arnage(batch_size)会生成一个由0到batch_size-1的数组,而这个数组中的每一个值将会提供y[ ]矩阵的相对应的行坐标,

又因为t数组为非one-hot形式,所以t数组中的每一个元素就形如[2,4,7,1]这样的数组,其中的元素可以相对应于one-hot形式时的索引下标值形如[[0,0,1,0,0,0,0,0,0,0],[0,0,0,0,1,0,0,0,0,0],[0,0,0,0,0,0,0,1,0,0],[0,1,0,0,0,0,0,0,0,0]]

所以t数组中的所有元素都可以对应于y[ ]中的一个列下标,从而正确地取出输出数组y[ ]中对应于t[ ]中的一个解,又因为交叉熵误差的计算只计算t[ ]中不为零的值的总和,所以其他为零的值可以直接舍去不算,直接取不为零的值即可。

       然后,在代码np.log(y[np.arange(batch_size), t] + 1e-7)中,加上一个1e-7,是为了解决log(0)算出来的负无穷大的错误值。

       最后,-np.sum(np.log(y[np.arange(batch_size), t] + 1e-7)) / batch_size再将所有log有效值相加取负值,再除以y数组元素的个数,求得平均损失函数值,继而求得正确的交叉熵误差。

 

上一篇:Linux openssl1.0.2k升级openssl1.1.1e版本教程


下一篇:浮点数二分