Converting a .pth file to a .pt file in PyTorch

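import torch
import torchvision

# Build the SAME architecture that produced the checkpoint. resnet18 is only a
# placeholder: this .pth comes from an EfficientDet project, so its keys are
# unlikely to match resnet18.
model = torchvision.models.resnet18()
state_dict = torch.load("/home/xu/workspace/eff/fire_extinguisher/efficientdet-pytorch-master/logs/Epoch100-Total_Loss0.0801-Val_Loss0.1916.pth", map_location="cpu")
# strict=False silently skips missing/unexpected keys, so print them to catch a mismatch
result = model.load_state_dict(state_dict, strict=False)
print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)
model.eval()
example = torch.rand(1, 3, 320, 480)  # random dummy input with the expected shape (N, C, H, W)
with torch.no_grad():
    traced_script_module = torch.jit.trace(model, example)
traced_script_module.save('/home/xu/workspace/eff/fire_extinguisher/efficientdet-pytorch-master/mobilenet_v2.pt')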
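The conversion can be checked end to end by reloading the saved .pt with torch.jit.load and comparing its output against the eager model. The sketch below assumes the .pth holds a plain state_dict (as produced by torch.save(model.state_dict(), ...)). build_model, the small nn.Sequential network and the temporary file names are hypothetical stand-ins for your real architecture and paths. It uses strict=True so that any key mismatch fails loudly instead of being skipped.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical small network standing in for the real architecture.
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 2),
    )


def pth_to_pt(pth_path: str, pt_path: str, example: torch.Tensor) -> nn.Module:
    """Load a state_dict from pth_path, trace the model and save it to pt_path."""
    model = build_model()
    state_dict = torch.load(pth_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)  # fail loudly on key mismatch
    model.eval()
    with torch.no_grad():
        traced = torch.jit.trace(model, example)
    traced.save(pt_path)
    return model


with tempfile.TemporaryDirectory() as tmp:
    pth_path = os.path.join(tmp, "weights.pth")
    pt_path = os.path.join(tmp, "model.pt")
    # Simulate a training checkpoint by saving a fresh model's state_dict.
    torch.save(build_model().state_dict(), pth_path)

    example = torch.rand(1, 3, 320, 480)
    eager = pth_to_pt(pth_path, pt_path, example)
    scripted = torch.jit.load(pt_path, map_location="cpu")
    with torch.no_grad():
        match = torch.allclose(eager(example), scripted(example), atol=1e-6)

print("outputs match:", match)
```

Tracing records only the operations executed for this one example input, so models with data-dependent control flow may need torch.jit.script instead.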
