【PyTorch】Caught RuntimeError in DataLoader worker process 0 and invalid argument 0: Sizes of tensors mus

The error is as follows:

Traceback (most recent call last):
  File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/tqdm/std.py", line 1178, in __iter__
    for obj in iterable:
  File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 819, in __next__
    return self._process_data(data)
  File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 846, in _process_data
    data.reraise()
  File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/torch/_utils.py", line 369, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 75, in default_collate
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 75, in <dictcomp>
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 65, in default_collate
    return default_collate([torch.as_tensor(b) for b in batch])
  File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 56, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 8 and 16 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689

The __getitem__ function does return data, so the problem lies in torch.utils.data.DataLoader.

Analysis

These are actually two errors:

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 8 and 16 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689

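This says that the shapes of the samples are inconsistent. The traceback points to File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 56, in default_collate return torch.stack(batch, 0, out=out). The source at that location reads: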

if isinstance(elem, torch.Tensor):
    out = None
    if torch.utils.data.get_worker_info() is not None:
        # If we're in a background process, concatenate directly into a
        # shared memory tensor to avoid an extra copy
        numel = sum([x.numel() for x in batch])
        storage = elem.storage()._new_shared(numel)
        out = elem.new(storage)
    return torch.stack(batch, 0, out=out)

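As this shows, the DataLoader has to merge the samples as its last step. When a batch size is set, this is where the individual samples are stacked into one batch, and if their shapes are not all identical, torch.stack raises the error.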
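The stacking failure can be reproduced outside the DataLoader. The following is a minimal sketch, assuming two 1-D tensors of length 8 and 16; these sizes are chosen only to echo the numbers in the error message, and the real sample shapes in the original dataset are not known. The wording of the exception differs between PyTorch versions, so the sketch only records whether stacking succeeded.

```python
import torch

# Two samples whose shapes differ (sizes are illustrative assumptions).
a = torch.zeros(8)
b = torch.zeros(16)

try:
    torch.stack([a, b], 0)  # the same call default_collate makes
    stacked = True
except RuntimeError as exc:
    stacked = False
    error_message = str(exc)  # exact text depends on the PyTorch version

# Samples with identical shapes stack into a (batch, ...) tensor.
same = torch.stack([torch.zeros(8), torch.zeros(8)], 0)

print(stacked)
print(tuple(same.shape))
```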

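The other error appears because multi-process loading is enabled (num_workers != 0); it reports which worker failed. DataLoader workers are processes, not threads. Because the batch merge fails on the mismatched shapes, the first worker (worker process 0) hits the exception, which is then re-raised in the main process as RuntimeError: Caught RuntimeError in DataLoader worker process 0.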
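Before changing anything, it can help to confirm which samples disagree in shape by indexing the dataset directly in the main process, without any workers. The helper below is a hypothetical sketch, not part of PyTorch: count_shapes and the key "x" are names made up for illustration, and the list of dicts stands in for a real Dataset that returns dict samples, as the dictcomp frames in the traceback suggest.

```python
from collections import Counter

import torch


def count_shapes(dataset, key):
    """Count how often each shape occurs for dataset[i][key]."""
    counts = Counter()
    for i in range(len(dataset)):
        shape = tuple(torch.as_tensor(dataset[i][key]).shape)
        counts[shape] += 1
    return counts


# Stand-in for a real dataset (hypothetical data and key name).
samples = [{"x": torch.zeros(8)}, {"x": torch.zeros(16)}, {"x": torch.zeros(8)}]
shape_counts = count_shapes(samples, "x")
print(shape_counts)
```

More than one entry in the result means default_collate cannot stack that field as it stands.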

Solution

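Since the shapes are inconsistent, the fix is to make them the same. One way is to preallocate a sufficiently large array or tensor, mark the unfilled part with a flag, and use that flag when reading the data to tell which entries are valid.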
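One way to realize this padding idea is sketched below, under stated assumptions: every sample is a 1-D sequence, MAX_LEN is an upper bound chosen by hand, and the "flag" is a separate boolean mask rather than a sentinel value inside the data. The class name PaddedDataset and the keys "data" and "mask" are hypothetical, not taken from the original code.

```python
import torch
from torch.utils.data import DataLoader, Dataset

MAX_LEN = 16  # assumed upper bound on sequence length


class PaddedDataset(Dataset):
    """Pads every 1-D sample to max_len and returns a validity mask."""

    def __init__(self, sequences, max_len=MAX_LEN):
        self.sequences = sequences
        self.max_len = max_len

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, index):
        seq = torch.as_tensor(self.sequences[index], dtype=torch.float32)
        length = seq.shape[0]
        if length > self.max_len:
            raise ValueError("sample is longer than max_len")
        data = torch.zeros(self.max_len, dtype=torch.float32)
        data[:length] = seq  # real values first, zeros afterwards
        mask = torch.zeros(self.max_len, dtype=torch.bool)
        mask[:length] = True  # True marks the valid entries
        return {"data": data, "mask": mask}


# Two samples of length 8 and 16 (illustrative sizes).
sequences = [list(range(8)), list(range(16))]
loader = DataLoader(PaddedDataset(sequences), batch_size=2, num_workers=0)
batch = next(iter(loader))

print(tuple(batch["data"].shape))
print(batch["mask"].sum(dim=1).tolist())
```

Every sample now has the same shape, so the default collate function can stack them, and the mask recovers the valid part of each row, for example batch["data"][0][batch["mask"][0]].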
