PyTorch 7.保存和加载pytorch模型的方法

PyTorch 7.保存和加载pytorch模型的方法

保存和加载模型

python的对象都可以通过torch.save和torch.load函数进行保存和加载

x1 = {"d":"df","dd":"ddf"}
torch.save(x1,'a1.pt')
x2 = torch.load('a1.pt')

下面来谈模型的state_dict(),该函数返回模型的所有参数

class MLP(nn.Module):
	def __init__(self):
		super(MLP,self).__init__()
		self.hidden = nn.Linear(3,2)
		self.act = nn.ReLU()
		self.output = nn.Linear(2,1)
	def forward(self,x):
		a = self.act(self.hidden(x))
		return self.output(a)
net = MLP()
net.state_dict()

输出

OrderedDict([('hidden.weight',
              tensor([[-0.4195,  0.2609,  0.4325],
                      [-0.4031,  0.2078,  0.2077]])),
             ('hidden.bias', tensor([ 0.0755, -0.1408])),
             ('output.weight', tensor([[0.2473, 0.6614]])),
             ('output.bias', tensor([0.6191]))])
上一篇:2021最新的NVIDIA显卡排行榜前十


下一篇:Failed to initialize NVML: Driver/library version mismatch(已解决)