pytorch使用多GPU

直接修改dict的key当然也是可以的,不会影响模型。

但是逻辑上,事实上DataParallel也是一个Pytorch的nn.Module,只是这个类其中有一个module的变量用来保存传入的实际模型。

nn.DataParallel(m)

这句返回的已经不是原始的m了,而是一个DataParallel,原始的m保存在DataParallel的module变量里面。

 

所以,逻辑上有两个方法:

  1. 保存的时候直接取出原始的m:
torch.save(m.module.state_dict(), path)

2. 或者载入的时候用一个DataParallel载入,再取出原始模型:

m=nn.DataParallel(Resnet18(), device_ids=[0,1,2])
m.load_state_dict(torch.load(path))
m=m.module

这样逻辑上更好看一点。



pytorch使用多GPU

上一篇:根据浏览器访问页面获取Ip


下一篇:ModalResult是指一?模式窗体(form.showmodal)的返回值