注意,后缀.pt和.pth似乎没什么区别
保存时即可以保存整个模型也可以只保存参数,还可以构建新字典重新保存,这也就对应了在读取时需要做不同的处理,我们在加载的时候load_state_dict函数的参数就是OrderedDict类型的参数,这里给出了四种不同保存方式及其读取获得OrderedDict的方式。
1.保存
# coding=gbk
import torch
import torch.nn as nn
class MLP_(nn.Module):
def __init__(self):
super(MLP_, self).__init__()
self.hidden = nn.Linear(3, 2)
self.act = nn.ReLU()
self.output = nn.Linear(2, 1)
def forward(self, x):
a = self.act(self.hidden(x))
return self.output(a)
net = MLP_()
#保存整个模型
torch.save(net, 'a1.pt')
all_model = {'model':net} #为模型部分添加键值,这样如果想要保存优化器参数的,可以向字典中加入新值
torch.save(all_model, 'a2.pt')
#只保存参数
torch.save(net.state_dict(),'a3.pt')
all_states = {'state_dict': net.state_dict()} #为模型参数部分添加键值,这样如果想要保存优化器参数的,可以向字典中加入新值
torch.save(all_states, 'a4.pt')
2.加载
# coding=gbk
import torch
from save import MLP_
if __name__ == "__main__":
with torch.no_grad():
a1 = 'a1.pt'
a2 = 'a2.pt'
a3 = 'a3.pt'
a4 = 'a4.pt'
a1_ = torch.load(a1)
print(a1_.state_dict())
a2_ = torch.load(a2)['model'] #通过键值选取对应值
print(a2_.state_dict())
a3_ = torch.load(a3)
print(a3_)
a4_ = torch.load(a4)['state_dict'] #通过键值选取对应值
print(a4_)
参考:https://zhuanlan.zhihu.com/p/94971100