PyTorch 7.保存和加载pytorch模型的方法
保存和加载模型
python的对象都可以通过torch.save和torch.load函数进行保存和加载
x1 = {"d":"df","dd":"ddf"}
torch.save(x1,'a1.pt')
x2 = torch.load('a1.pt')
下面来谈模型的state_dict(),该函数返回模型的所有参数
class MLP(nn.Module):
def __init__(self):
super(MLP,self).__init__()
self.hidden = nn.Linear(3,2)
self.act = nn.ReLU()
self.output = nn.Linear(2,1)
def forward(self,x):
a = self.act(self.hidden(x))
return self.output(a)
net = MLP()
net.state_dict()
输出
OrderedDict([('hidden.weight',
tensor([[-0.4195, 0.2609, 0.4325],
[-0.4031, 0.2078, 0.2077]])),
('hidden.bias', tensor([ 0.0755, -0.1408])),
('output.weight', tensor([[0.2473, 0.6614]])),
('output.bias', tensor([0.6191]))])