import torch
import matplotlib.pyplot as plt
import os

# Work around "duplicate OpenMP runtime" crashes seen on some Windows setups
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# Toy dataset: y = x^2 plus uniform noise, shaped (100, 1)
# (plain tensors carry autograd since PyTorch 0.4, so the deprecated Variable wrapper is not needed)
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())
# Build, train, and save a network
def save():
    net = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    optimizer = torch.optim.SGD(net.parameters(), lr=0.5)  # plain SGD
    loss_func = torch.nn.MSELoss()                         # mean squared error
    for t in range(500):                    # train for 500 steps
        prediction = net(x)                 # forward pass
        loss = loss_func(prediction, y)     # compute the loss
        optimizer.zero_grad()               # clear the previous step's gradients (SGD only needs the current ones)
        loss.backward()                     # backpropagate to compute gradients
        optimizer.step()                    # take one descent step
    torch.save(net, r'C:\Users\Y_ch\Desktop\torch_test\net.pkl')                # entire net
    torch.save(net.state_dict(), r'C:\Users\Y_ch\Desktop\torch_test\net1.pkl')  # only parameters
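    # Sketch (assumption, not in the original): when the goal is to resume
    # training rather than just run inference, a common pattern is to bundle
    # model and optimizer state in one checkpoint dict; 'checkpoint.pth' is
    # an illustrative filename.
    torch.save({
        'model': net.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, r'C:\Users\Y_ch\Desktop\torch_test\checkpoint.pth')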
# Restore the entire network (class structure + parameters)
def restore_net_entire():
    # Note: on PyTorch >= 2.6, torch.load() defaults to weights_only=True, so
    # loading a fully pickled module requires torch.load(path, weights_only=False)
    re_net = torch.load(r'C:\Users\Y_ch\Desktop\torch_test\net.pkl')
    prediction = re_net(x)
    plt.scatter(x.numpy(), y.numpy())
    plt.plot(x.numpy(), prediction.detach().numpy(), 'r-', lw=5)
    plt.show()
# Restore only the parameters into a freshly built net of identical shape
def restore_net_para():
    re_net = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    re_net.load_state_dict(torch.load(r'C:\Users\Y_ch\Desktop\torch_test\net1.pkl'))
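    # Sketch (assumption): mirror the plot in restore_net_entire() to confirm
    # that the loaded parameters reproduce the same fit.
    prediction = re_net(x)
    plt.scatter(x.numpy(), y.numpy())
    plt.plot(x.numpy(), prediction.detach().numpy(), 'r-', lw=5)
    plt.show()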
save()
restore_net_entire()
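restore_net_para()  # exercise the parameters-only path as well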