【PyTorch】自定义损失函数

  • 通常损失函数可以直接使用 P y T o r c h \rm PyTorch PyTorch 封装好的损失函数,例如 C r o s s E n t r p y L o s s ( ) , n l l _ l o s s ( ) \rm CrossEntrpyLoss(),nll\_loss() CrossEntrpyLoss(),nll_loss() 等,倘如自定义损失函数,需要注意梯度清零问题
  • 以线性模型 y = w x + b y=wx+b y=wx+b 为例,参数为 w , b w,b w,b,自定义平凡损失函数如下:
# In[Loss]
def loss_fn(y,y_pred):
    loss = (y_pred-y).pow(2).sum()
    for param in [w,b]:
        if not param.grad is None: 
            param.grad.data.zero_()
    loss.backward()
    return loss.data
  • 因为损失函数 l o s s _ f n ( ) \rm loss\_fn() loss_fn() 会不止一次地被调用,因此需要通过 g r a d . d a t a . z e r o ( ) \rm grad.data.zero_() grad.data.zero(​) 方法清理上一次的梯度值。第一次调用 l o s s _ f n \rm loss\_fn loss_fn 时,梯度尚且为初始状态 N o n e \rm None None,因此 i f \rm if if 条件如上。
上一篇:线性回归——pytorch实现


下一篇:线性回归pytorch实现笔记