其他坑一些别人踩过的坑
1.Broadcast function not implemented for CPU tensors
这是因为model不在gpu上所致。model.to(device)。DataParallel会对模型参数所在的gpu位置进行检查,见源码
DataParallel是每次forward时对模型进行broadcast,当模型不在第一个GPU上时,就会出现错误
https://github.com/pytorch/pytorch/issues/17065
2.all tensors must be on devices[0]
这是因为model不在DataParallel设置的ids中的第一个上。输入的变量可以随便放在一个GPU上,而模型必须在你设置DataParallel的ids中的第一个
3. 多GPU模型转换到cpu上
通过DataParallel包装的model会再加一层module。所以state_dict会多一个module前缀。假设net1 是通过DataParallel包装的模型Net的实例,我们要把它装换到cpu上。方法就是重新建一个对象,把参数迁移过去
state_dict = net.module.state_dict()
net = Net()
net.load_state_dict(state_dict)
4.使用DataParallel包装模型时,如果gpu>1且模型是多输出的,会出现梯度为None的错误
参数的梯度永远都是None,这个是pytorch 1.0 的一个bug 或见FloWaveNet issues,pytorch issues 15716
据说是因为引用计数的问题出的bug。所以这里的一个解决方案是上面链接提供的方法,我将其修改为可供多次输入。将forward过程分发的模型和output保留下来。backward之后再清除掉(丢除引用)。这里用一个list保存是因为可能一个模型要经过多次输入计算loss。如果仅仅一次输入,那么就只需要保存一次。之后要记得调用reset将保留的引用清空,不然的话,全都存着,gpu内存暴涨。
class DataParallelFix(nn.DataParallel):
"""
Temporary workaround for https://github.com/pytorch/pytorch/issues/15716.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._replicas = []
self._outputs = []
self.src_device_obj = torch.device("cuda:{}".format(self.device_ids[0]))
def reset(self):
self._replicas = []
self._outputs = []
def forward(self, *inputs, **kwargs):
if not self.device_ids:
return self.module(*inputs, **kwargs)
for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj:
raise RuntimeError(
"module must have its parameters and buffers "
"on device {} (device_ids[0]) but found one of "
"them on device: {}".format(self.src_device_obj,
t.device))
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
return self.module(*inputs[0], **kwargs[0])
_replicas = self.replicate(self.module,
self.device_ids[:len(inputs)])
_outputs = self.parallel_apply(_replicas, inputs, kwargs)
self._replicas.append(_replicas)
self._outputs.append(_outputs)
return self.gather(_outputs, self.output_device)
5. 或者training loss忽高忽低或者不下降。loss变为nan
一个原因是学习率太高
6. ByteTensor和LongTensor不会自动转换成FloatTensor
所以一个LongTensor除以一个数会只保留整数部分
比如
((out == label).sum() / float(batch_size)).item()
#结果是0
应该改成
((out == label).sum().item() / batch_size)
或者
((out == label).sum().float() / batch_size).item()
待补充