每天讲解一点PyTorch 【15】model.load_state_dict torch.load

今天我们讲解:

state_dict = torch.load('checkpoint.pt')
#或者
state_dict = torch.load('checkpoint.pth') #torch.load加载**模型参数**
model.load_state_dict(state_dict) #把模型参数加载到模型中

model.cuda()
model.eval() #model.eval()关闭Batch Normalization和Dropout层
#加载模型结构和模型参数
model = torch.load(path)
output = model(x)
上一篇:Pytorch基础操作 —— 3.保存和加载Torch模型和参数


下一篇:介绍禁止输入三种不让input输入中文的方法