一、前言
1、保存训练的模型以备将来在各种环境中使用
2、当运行一个耗时较长的训练过程时,最佳的做法是定期保存中间结果(检查点),以确保在服务器电源被不小心断掉时不会损失几天的计算结果
二、加载和保存张量
1、对于单个张量,我们可以直接调用load和save函数分别读写它们
import torch from torch import nn from torch.nn import functional as F x = torch.arange(4) print(x) # 将张量通过save函数保存在文件x-file中 torch.save(x, 'x-file') #输出结果 tensor([0, 1, 2, 3])
2、将存储在文件中的数据读回内存
# 使用load()函数将张量从文件中读取出来 x2 = torch.load('x-file') x2 #print(torch.load('x-file')) #输出结果 tensor([0, 1, 2, 3])
3、存储一个张量列表,然后把它们读回内存
# 存储和读取多个张量 y = torch.zeros(4) print(x) print(y) # 存储列表 torch.save([x, y], 'x-files') # 先返回X,再返回Y x2, y2 = torch.load('x-files') (x2, y2) #输出结果 tensor([0, 1, 2, 3]) tensor([0., 0., 0., 0.]) (tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))
4、写入或读取从字符串映射到张量的字典
print(x) print(y) # 从字典中读取张量 mydict = {'x': x, 'y': y} torch.save(mydict, 'mydict') mydict2 = torch.load('mydict') mydict2 #输出结果 tensor([0, 1, 2, 3]) tensor([0., 0., 0., 0.]) {'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}
三、加载和保存模型参数
1、模型无法保存,需要用代码生成结构。参数可以保存,从磁盘加载参数就行。
# 深度学习框架提供了内置函数来保存和加载整个网络-只是保存了参数 # 为了恢复模型,我们需要用代码生成结构,然后从磁盘加载参数 class MLP(nn.Module): def __init__(self): super().__init__() self.hidden = nn.Linear(20, 256) self.output = nn.Linear(256, 10) def forward(self, x): return self.output(F.relu(self.hidden(x))) net = MLP() print(net) X = torch.randn(size=(2, 20)) print(X) Y = net(X) print(Y) #输出结构 MLP( (hidden): Linear(in_features=20, out_features=256, bias=True) (output): Linear(in_features=256, out_features=10, bias=True) ) tensor([[-0.1273, 1.4412, 0.0323, 1.3365, -0.6067, -0.0327, -0.9841, -0.2114, -0.0667, -0.5265, 0.6789, -0.5793, 0.3214, -0.7539, 0.1319, -0.2952, 0.4497, -0.4045, 0.2847, 1.7031], [-0.2389, 0.3846, -1.1254, -0.3895, -1.5898, -0.2039, -0.6001, -0.3358, -1.2053, 0.3487, -1.5540, -0.6101, -0.9526, 1.0642, -0.4290, -1.4622, 0.4534, -1.7637, -0.3076, 1.5279]]) tensor([[-0.1711, -0.0050, 0.0361, 0.0926, 0.0374, -0.1191, -0.2689, -0.0940, 0.2355, 0.1106], [-0.6878, 0.0452, -0.1850, 0.2128, -0.2549, 0.3149, -0.3111, 0.1489, -0.0063, 0.2223]], grad_fn=<AddmmBackward>)
2、将模型的参数存储在一个叫做“mlp.params”的文件
# 模型的参数——权重和偏差 torch.save(net.state_dict(), 'mlp.params')
3、为了恢复模型,我们实例化了原始多层感知机模型的一个备份。我们没有随机初始化模型参数,而是直接读取文件中存储的参数
# 恢复模型 # 实例化一个模型(没有传入参数-初始化) clone = MLP() # 读取文件中存储的参数 clone.load_state_dict(torch.load('mlp.params')) clone.eval() #输出结果 MLP( (hidden): Linear(in_features=20, out_features=256, bias=True) (output): Linear(in_features=256, out_features=10, bias=True) )
4、由于两个实例具有相同的模型参数,在输入相同的X
时,两个实例的计算结果应该相同。
Y_clone = clone(X) Y_clone == Y # Y=net(X) #输出结果 tensor([[True, True, True, True, True, True, True, True, True, True], [True, True, True, True, True, True, True, True, True, True]])
四、小结
1、save和load函数可用于张量对象的文件读写。
2、我们可以通过参数字典保存和加载网络的全部参数。
3、保存结构必须在代码中完成,而不是在参数中完成