pytorch 之 保存不同形式的预训练模型

注意,后缀.pt和.pth似乎没什么区别

保存时即可以保存整个模型也可以只保存参数,还可以构建新字典重新保存,这也就对应了在读取时需要做不同的处理,我们在加载的时候load_state_dict函数的参数就是OrderedDict类型的参数,这里给出了四种不同保存方式及其读取获得OrderedDict的方式。
1.保存

# coding=gbk
import torch 
import torch.nn as nn
class MLP_(nn.Module):
    def __init__(self):
        super(MLP_, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)

net = MLP_()

#保存整个模型

torch.save(net, 'a1.pt')

all_model = {'model':net} #为模型部分添加键值,这样如果想要保存优化器参数的,可以向字典中加入新值
torch.save(all_model, 'a2.pt')
#只保存参数

torch.save(net.state_dict(),'a3.pt')
all_states = {'state_dict': net.state_dict()} #为模型参数部分添加键值,这样如果想要保存优化器参数的,可以向字典中加入新值
torch.save(all_states, 'a4.pt')

2.加载

# coding=gbk
import torch
from save import MLP_
if __name__ == "__main__":
    with torch.no_grad():
   
        a1 = 'a1.pt'
        a2 = 'a2.pt'
        a3 = 'a3.pt'
        a4 = 'a4.pt'
        a1_ = torch.load(a1)
        print(a1_.state_dict())
        a2_ = torch.load(a2)['model'] #通过键值选取对应值
        print(a2_.state_dict())
        a3_ = torch.load(a3)
        print(a3_)
        a4_ = torch.load(a4)['state_dict'] #通过键值选取对应值
        print(a4_)

参考:https://zhuanlan.zhihu.com/p/94971100

上一篇:#pt#课堂笔记_5_pt实现线性回归


下一篇:想学 javaJDBC ? 来这里我手把手教你