Pytorch加载模型时报错:
此时的加载部分代码为:
model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
model.load_state_dict(torch.load(args.weights))
model.cuda()
修改第二行代码为:
model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
model.load_state_dict(torch.load(args.weights)['state_dict'])
model.cuda()
然后模型就可以正常加载了!