学习总结
(1)loss
实际在构建计算图,每次运行完后计算图就释放了。
(2)Tensor的Grad也是一个Tensor。更新权重w.data = w.data - 0.01 * w.grad.data
的0.01乘那坨其实是在建立计算图,而我们这里要乘0.01 * grad.data
,这样是不会建立计算图的(并不希望修改权重w,后面还有求梯度)。
(3)下面的w.grad.item()
是直接把w.grad
的数值取出,变成一个标量(也是为了防止产生计算图)。总之,牢记权重更新过程中要使用data
。
文章目录
一、基础回顾
1.1 正向传递
1.2 反向传播
1.3 举栗子
现在以 f = x ⋅ ω f=x \cdot \omega f=x⋅ω 为例:
(1)正向传递
(2)反向传播
注意虽然这里的
∂
L
∂
x
=
∂
L
∂
z
⋅
∂
z
∂
x
\frac{\partial L}{\partial x}=\frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial x}
∂x∂L=∂z∂L⋅∂x∂z不求也可以,但是在pytorch是会求出来的(因为如果是多层,则需要用到该中间层求得的的梯度)。
二、计算图
2.1 线性模型的计算图
练习:
三、代码实战
# -*- coding: utf-8 -*-
"""
Created on Sun Oct 17 19:39:32 2021
@author: 86493
"""
import torch
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = torch.Tensor([1.0])
w.requires_grad = True
# 向前传递
def forward(x):
return x * w
# 这里使用SGD
def loss(x, y):
y_pred = forward(x)
return (y_pred - y) ** 2
print("predict (before training)", 4,
forward(4).item())
# 训练过程,SGD
for epoch in range(100):
for x, y in zip(x_data, y_data):
# 向前传播,计算loss
l = loss(x, y)
# 计算requires_grad为true的tensor的梯度
l.backward()
print('\tgrad:', x, y, w.grad.item())
w.data = w.data - 0.01 * w.grad.data
# 反向传播后grad会被重复计算,所以记得清零梯度
w.grad.data.zero_()
print("progress:", epoch, l.item())
print("predict (after training)", 4,
forward(4).item())
注意:
(1)loss
实际在构建计算图,每次运行完后计算图就释放了。
(2)Tensor的Grad也是一个Tensor。更新权重w.data = w.data - 0.01 * w.grad.data
的0.01乘那坨其实是在建立计算图,而我们这里要乘0.01 * grad.data
,这样是不会建立计算图的(并不希望修改权重w,后面还有求梯度)。
(3)w.grad.item()
是直接把w.grad
的数值取出,变成一个标量(也是为了防止产生计算图)。总之,牢记权重更新过程中要使用data
。
(4)如果不像上面计算一个样本的loss,想算所有样本的loss(cost),然后就加上sum += l
,注意此时sum是关于张量
l
l
l 的一个计算图,又未对sum
做backward
操作,随着l
越加越多会导致内存爆炸。
正确做法:sum += l.item()
,别把损失直接加到sum里面。
Tensor在做加法运算时会构建计算图
(5)backward
后的梯度一定要记得清零w.grad.data.zero()
。
(6)训练过程:先计算loss损失值,然后backward
反向传播,现在就有了梯度了。通过梯度下降更新参数:
四、作业
Reference
(1)PyTorch 深度学习实践 第10讲,刘二系列
(2)b站视频:https://www.bilibili.com/video/BV1Y7411d7Ys?p=10
(3)官方文档:https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#Conv2d
(4)吴恩达网易云课程:https://study.163.com/my#/smarts
(5)刘洪普老师博客:https://liuii.github.io/
(6)某同学的笔记:http://biranda.top/archives/page/2/