前言
我们训练好的网络,怎么保存和提取呢?
总不可以一直不关闭电脑吧,训练到一半,想结束到明天再来训练,这就需要进行网络的保存和提取了。
本文以前面博客3-pytorch搭建一个简单的前馈全连接层网络(回归问题)的网络进行网络的保存和提取,建议先看完上面博客再来看本博客。
一、生成训练数据
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
# 生成数据(fake data)
x = torch.linspace(-1,1,100).reshape(-1,1)
# 加上点噪声
y = x.pow(2) + 0.2*torch.rand(x.shape)
# 可视化一下数据
plt.scatter(x.data.numpy(),y.data.numpy())
plt.show()
输出:
二、网络保存
def save():
net1 = torch.nn.Sequential(
torch.nn.Linear(1,10),
torch.nn.ReLU(),
torch.nn.Linear(10,1)
)
optimizer = torch.optim.SGD(net1.parameters(),lr=0.5)
loss_func = torch.nn.MSELoss()
for t in range(100):
prediction = net1(x)
loss = loss_func(prediction,y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 下面介绍两种不同的保存方法,方法二可能运行速度要快点
# 保存整个网络的所有
torch.save(net1, 'net.pkl')
# 保存好网络的参数
torch.save(net1.state_dict(),'net_params.pkl')
# plot result
plt.figure(1,figsize=(10,3))
plt.subplot(131)
plt.title('Net1')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)
【注】:保存整个网络还是保存网络参数,个人建议仅保存参数,这个速度更快。
三、网络提取
def restore_net():
net2 = torch.load('net.pkl')
prediction = net2(x)
# plot result
plt.subplot(132)
plt.title('Net1')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)
def restore_params():
# 如果只是保留参数的情况,提取时需要再次定义相同网络才行
net3 = torch.nn.Sequential(
torch.nn.Linear(1,10),
torch.nn.ReLU(),
torch.nn.Linear(10,1)
)
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x)
# plot result
plt.subplot(133)
plt.title('Net1')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)
四、对保存网络提取进行结果展示
save()
restore_net()
restore_params()