pytorch学习001- -如何保存模型

保存和加载模型

只保存模型的参数

保存

torch.save(model.state_dict(),'xxx.pth')

加载

model = net()	#首先要先定义网络模型
state_dict = torch.load('xxx.pth')	# 读取pth文件中的参数
model.load_state_dict(state_dict['model'])	#将参数导入模型

这种方法操作比较麻烦,但是比较节省内存。

official example

class MyModule(torch.nn.Module):


m = MyModule()
m.state_dict()

torch.save(m.state_dict(), 'mymodule.pt')
m_state_dict = torch.load('mymodule.pt')
new_m = MyModule()
new_m.load_state_dict(m_state_dict)

保存整个模型

保存

torch.save(net, 'xxx.pt')

加载

test = torch.load('xxx.pt')	#注意其中pt文件的路径

这种方式是将整个的网络模型进行保存,使用不便,但是加载方便,适合于简单测试。

官方文档

上一篇:#我上了个假大学之电路#001 绪论,模型,定律


下一篇:学生表如下