Softmax与Cross-entropy的求导

引言

在多分类问题中,一般会把输出结果传入到softmax函数中,得到最终结果。并且用交叉熵作为损失函数。本来就来分析下以交叉熵为损失函数的情况下,softmax如何求导。

对softmax求导

softmax函数为:

yi=ezik=1Kezk y_i = \frac{e^{z_i}}{\sum_{k=1}^K e^{z_k}} yi​=∑k=1K​ezk​ezi​​

这里KKK是类别的总数,接下来求yiy_iyi​对某个输出zjz_jzj​的导数,
yizj=ezik=1Kezkzj \frac{\partial y_i}{\partial z_j} = \frac{\partial \frac{e^{z_i}}{\sum_{k=1}^K e^{z_k}}}{\partial z_j} ∂zj​∂yi​​=∂zj​∂∑k=1K​ezk​ezi​​​

这里要分两种情况,分别是i=ji=ji=j与iji \neq ji​=j。当i=ji=ji=j时,ezie^{z_i}ezi​对zjz_jzj​的导数为ezie^{z_i}ezi​,否则当iji \neq ji​=j时,导数为000。

i=ji = ji=j,
yizj=ezik=1Kezkeziezj(k=1mezk)2=ezik=1mezkezik=1mezkezjk=1mezk=yiyi2=yi(1yi) \frac{\partial y_i}{\partial z_j} = \frac{e^{z_i}\cdot \sum_{k=1}^K e^{z_k} - e^{z_i} \cdot e^{z_j} }{(\sum_{k=1}^m e^{z_k})^2} \\ = \frac{e^{z_i}}{\sum_{k=1}^m e^{z_k}} - \frac{e^{z_i}}{\sum_{k=1}^m e^{z_k}} \cdot \frac{e^{z_j}}{\sum_{k=1}^m e^{z_k}} \\ = y_i - y_i^2 = y_i(1 - y_i) ∂zj​∂yi​​=(∑k=1m​ezk​)2ezi​⋅∑k=1K​ezk​−ezi​⋅ezj​​=∑k=1m​ezk​ezi​​−∑k=1m​ezk​ezi​​⋅∑k=1m​ezk​ezj​​=yi​−yi2​=yi​(1−yi​)

iji \neq ji​=j,
yizj=0k=1Kezkeziezj(k=1mezk)2=ezik=1mezkezjk=1mezk=yiyj \frac{\partial y_i}{\partial z_j} = \frac{0 \cdot \sum_{k=1}^K e^{z_k} - e^{z_i} \cdot e^{z_j}}{(\sum_{k=1}^m e^{z_k})^2} \\ = - \frac{e^{z_i}}{\sum_{k=1}^m e^{z_k}} \cdot \frac{e^{z_j}}{\sum_{k=1}^m e^{z_k}} \\ = - y_i y_j ∂zj​∂yi​​=(∑k=1m​ezk​)20⋅∑k=1K​ezk​−ezi​⋅ezj​​=−∑k=1m​ezk​ezi​​⋅∑k=1m​ezk​ezj​​=−yi​yj​

对cross-entropy求导

损失函数LLL为:

L=ky^klogyk L = -\sum_k \hat y_k \log y_k L=−k∑​y^​k​logyk​

其中y^k\hat y_ky^​k​是真实类别,相当于一个常数,接下来求LLL对zjz_jzj​的导数

Lzj=(ky^klogyk)zj=(ky^klogyk)ykykzj=ky^k1ykykzj=(y^kyk(1yk)1yk)k=jkjy^k1yk(ykyj)=y^j(1yj)kjy^k(yj)=y^j+y^jyj+kjy^k(yj)=y^j+ky^k(yj)=y^j+yj=yjy^j \frac{\partial L}{\partial z_j} = \frac{\partial -(\sum_k \hat y_k \log y_k)}{z_j} = \frac{\partial -(\sum_k \hat y_k \log y_k)}{\partial y_k} \frac{\partial y_k}{\partial z_j} \\ = -\sum_k \hat y_k \frac{1}{y_k} \frac{\partial y_k}{z_j} \\ = \left(-\hat y_k \cdot y_k(1 - y_k) \frac{1}{y_k} \right)_{k=j} - \sum_{k \neq j} \hat y_k \frac{1}{y_k} (-y_ky_j) \\ = -\hat y_j (1 - y_j) - \sum_{k \neq j} \hat y_k (-y_j) \\ = -\hat y_j + \hat y_j y_j + \sum_{k \neq j} \hat y_k (y_j) \\ = -\hat y_j + \sum_{k} \hat y_k (y_j) \\ = -\hat y_j + y_j \\ = y_j -\hat y_j ∂zj​∂L​=zj​∂−(∑k​y^​k​logyk​)​=∂yk​∂−(∑k​y^​k​logyk​)​∂zj​∂yk​​=−k∑​y^​k​yk​1​zj​∂yk​​=(−y^​k​⋅yk​(1−yk​)yk​1​)k=j​−k​=j∑​y^​k​yk​1​(−yk​yj​)=−y^​j​(1−yj​)−k​=j∑​y^​k​(−yj​)=−y^​j​+y^​j​yj​+k​=j∑​y^​k​(yj​)=−y^​j​+k∑​y^​k​(yj​)=−y^​j​+yj​=yj​−y^​j​

这里用到了ky^k=1\sum_{k} \hat y_k = 1∑k​y^​k​=1

可以看到,求导结果非常简单,如果不推倒都不敢信。

上一篇:一起学习朴素贝叶斯


下一篇:linux 混杂设备驱动之adc驱动