保存和加载模型
只保存模型的参数
保存
torch.save(model.state_dict(),'xxx.pth')
加载
model = net() #首先要先定义网络模型
state_dict = torch.load('xxx.pth') # 读取pth文件中的参数
model.load_state_dict(state_dict['model']) #将参数导入模型
这种方法操作比较麻烦,但是比较节省内存。
official example
class MyModule(torch.nn.Module):
m = MyModule()
m.state_dict()
torch.save(m.state_dict(), 'mymodule.pt')
m_state_dict = torch.load('mymodule.pt')
new_m = MyModule()
new_m.load_state_dict(m_state_dict)
保存整个模型
保存
torch.save(net, 'xxx.pt')
加载
test = torch.load('xxx.pt') #注意其中pt文件的路径
这种方式是将整个的网络模型进行保存,使用不便,但是加载方便,适合于简单测试。