import torch as t
x = t.ones(1, requires_grad=True)
w = t.rand(1, requires_grad=True)
y = x * w
# y depends on w, and w.requires_grad = True
x.requires_grad, w.requires_grad, y.requires_grad
(True, True, True)
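requires_grad propagates from inputs to results automatically; tracking can be suppressed locally with t.no_grad(). A minimal sketch reusing the x and w above (y2 is just an illustrative name):
with t.no_grad():
    y2 = x * w
y2.requires_grad
False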
x.grad_fn, y.grad_fn
(None, <MulBackward0 at 0x2f93ba176d8>)
x.is_leaf, w.is_leaf, y.is_leaf
(True, True, False)
y.grad_fn.next_functions
((<AccumulateGrad at 0x2f93ba174a8>, 0),
(<AccumulateGrad at 0x2f93ba17c50>, 0))
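The two AccumulateGrad entries correspond to the leaf tensors x and w: when backward() runs, each one accumulates the incoming gradient into the matching .grad attribute. A minimal sketch (gradients add up across calls, so they must be cleared between iterations):
y.backward()
w.grad                # dy/dw = x
tensor([1.])
y = x * w             # rebuild the graph (it was freed by backward)
y.backward()
w.grad                # accumulated: 1 + 1
tensor([2.])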
a = t.ones(3, 4, requires_grad=True)
b = t.ones(3, 4, requires_grad=True)
c = a * b
a.data  # still a tensor, but detached from the computation graph
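a.data shares storage with a but has requires_grad=False, so modifications made through it bypass autograd entirely; a.detach() is the safer modern alternative, since it raises an error if autograd later needs the overwritten values. A quick sketch:
d = a.data
d.requires_grad
False
d[0, 0] = 100        # shares storage: silently modifies a, invisibly to autograd
a[0, 0].item()
100.0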
# Method 1: use torch.autograd.grad to get the gradient of an intermediate variable
x = t.ones(3, requires_grad=True)
w = t.rand(3, requires_grad=True)
y = x * w
z = y.sum()
# gradient of z w.r.t. y; this implicitly runs the backward pass
t.autograd.grad(z, y)
(tensor([1., 1., 1.]),)
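t.autograd.grad also accepts several inputs at once and, unlike backward(), returns the gradients instead of accumulating them into .grad. A small sketch:
x = t.ones(3, requires_grad=True)
w = t.rand(3, requires_grad=True)
z = (x * w).sum()
dz_dx, dz_dw = t.autograd.grad(z, [x, w])
dz_dw                 # dz/dw = x
tensor([1., 1., 1.])
x.grad is None        # .grad was never populated
True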
# Method 2: use a hook
# A hook is a function that takes the gradient as input; it returns either a replacement gradient or nothing (as here)
def variable_hook(grad):
    print('gradient of y:', grad)
x = t.ones(3, requires_grad=True)
w = t.rand(3, requires_grad=True)
y = x * w
# register the hook
hook_handle = y.register_hook(variable_hook)
z = y.sum()
z.backward()
# unless you need the hook on every backward pass, remember to remove it after use
hook_handle.remove()
gradient of y: tensor([1., 1., 1.])
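If the goal is only to read a non-leaf gradient rather than run arbitrary code, retain_grad() is a simpler alternative to a hook; a sketch:
x = t.ones(3, requires_grad=True)
w = t.rand(3, requires_grad=True)
y = x * w
y.retain_grad()       # ask autograd to keep y's gradient after backward
z = y.sum()
z.backward()
y.grad
tensor([1., 1., 1.])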
Example: implementing linear regression with autograd
import torch as t
from matplotlib import pyplot as plt
from IPython import display
import numpy as np
# set the random seed so that the output below is reproducible across machines
t.manual_seed(1000)
def get_fake_data(batch_size=8):
    '''Generate random data: y = x*2 + 3, with some noise added'''
    x = t.rand(batch_size, 1) * 5
    y = x * 2 + 3 + t.randn(batch_size, 1)
    return x, y
# take a look at the x-y distribution of the generated data
x, y = get_fake_data()
plt.scatter(x.squeeze().numpy(), y.squeeze().numpy())
<matplotlib.collections.PathCollection at 0x2f93f8eeb70>
# randomly initialize the parameters
w = t.rand(1, 1, requires_grad=True)
b = t.zeros(1, 1, requires_grad=True)
losses = np.zeros(500)
lr = 0.005
for ii in range(500):
    x, y = get_fake_data(batch_size=32)

    # forward: compute the loss
    y_pred = x.mm(w) + b.expand_as(y)
    loss = 0.5 * (y_pred - y) ** 2
    loss = loss.sum()
    losses[ii] = loss.item()

    # backward: autograd computes the gradients
    loss.backward()

    # update the parameters
    w.data.sub_(lr * w.grad.data)
    b.data.sub_(lr * b.grad.data)

    # zero the gradients (they would accumulate otherwise)
    w.grad.data.zero_()
    b.grad.data.zero_()

    if ii % 50 == 0:
        # plot the current fit against freshly sampled data
        display.clear_output(wait=True)
        x = t.arange(0, 6).view(-1, 1).float()
        y = x.mm(w.data) + b.data.expand_as(x)
        plt.plot(x.numpy(), y.numpy())          # predicted line
        x2, y2 = get_fake_data(batch_size=20)
        plt.scatter(x2.numpy(), y2.numpy())     # true data
        plt.xlim(0, 5)
        plt.ylim(0, 13)
        plt.show()
        plt.pause(0.5)

print(w.item(), b.item())
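For comparison, here is the same training loop with the manual parameter updates replaced by torch.optim.SGD; a minimal sketch assuming the get_fake_data helper defined above:
w = t.rand(1, 1, requires_grad=True)
b = t.zeros(1, 1, requires_grad=True)
optimizer = t.optim.SGD([w, b], lr=0.005)
for ii in range(500):
    x, y = get_fake_data(batch_size=32)
    y_pred = x.mm(w) + b.expand_as(y)
    loss = (0.5 * (y_pred - y) ** 2).sum()
    optimizer.zero_grad()   # replaces the manual grad.data.zero_() calls
    loss.backward()
    optimizer.step()        # replaces the manual w.data.sub_(...) updates
print(w.item(), b.item())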