训练网络的保存和提取

import torch
import torch.nn.functional as  F
from torch.autograd import Variable
import matplotlib.pyplot as plt
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'




x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())
x, y = Variable(x), Variable(y)

#保存构建好的网络
def save():
    net = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )

    optimizer = torch.optim.SGD(net.parameters(), lr=0.5)  # 选择SGD梯度下降算法
    loss_func = torch.nn.MSELoss()  # 使用均方误差

    for t in range(500):  # 训练500次
        prediction = net(x)  # 获得prediction
        loss = loss_func(prediction, y)  # 获得误差

        optimizer.zero_grad()  # 清除上一步的梯度(因为使用的位SGD固每次下降只需要当前的梯度)
        loss.backward()  # 反向计算梯度
        optimizer.step()  # 下降一步

    torch.save(net,r'C:\Users\Y_ch\Desktop\torch_test\net.pkl')               #entire net
    torch.save(net.state_dict(),r'C:\Users\Y_ch\Desktop\torch_test\net1.pkl') #only parameters


#完整的提取网络
def restore_net_entire(): #load entire net
    re_net = torch.load(r'C:\Users\Y_ch\Desktop\torch_test\net.pkl')
    prediction = re_net(x)
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.show()

#提取网络的参数
def restore_net_para():
    re_net = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    re_net.load_state_dict(torch.load(r'C:\Users\Y_ch\Desktop\torch_test\net1.pkl'))


save()
restore_net_entire()

 

上一篇:CVPR2021 MotionRNN: A Flexible Model for Video Prediction with Spacetime-Varying Motions


下一篇:人工智能教程 - 专业选修课程4.3.11 - 复杂结构数据挖掘 5.PageRank技术