The error is as follows:
Traceback (most recent call last):
File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/tqdm/std.py", line 1178, in __iter__
for obj in iterable:
File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 819, in __next__
return self._process_data(data)
File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 846, in _process_data
data.reraise()
File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/torch/_utils.py", line 369, in reraise
raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
data = fetcher.fetch(index)
File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
return self.collate_fn(data)
File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 75, in default_collate
return {key: default_collate([d[key] for d in batch]) for key in elem}
File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 75, in <dictcomp>
return {key: default_collate([d[key] for d in batch]) for key in elem}
File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 65, in default_collate
return default_collate([torch.as_tensor(b) for b in batch])
File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 56, in default_collate
return torch.stack(batch, 0, out=out)
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 8 and 16 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689
The __getitem__ function does return data correctly, so the problem lies in torch.utils.data.DataLoader.
Analysis
There are actually two errors here:
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 8 and 16 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689
This indicates that the tensor dimensions are not consistent. Jumping to the frame File "/home/jiang/miniconda3/envs/Net/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 56, in default_collate (return torch.stack(batch, 0, out=out)), the source reads:
if isinstance(elem, torch.Tensor):
    out = None
    if torch.utils.data.get_worker_info() is not None:
        # If we're in a background process, concatenate directly into a
        # shared memory tensor to avoid an extra copy
        numel = sum([x.numel() for x in batch])
        storage = elem.storage()._new_shared(numel)
        out = elem.new(storage)
    return torch.stack(batch, 0, out=out)
You can see that the DataLoader has to merge samples at the end: when batch_size is set, this is where the individual samples are stacked into a batch, and if their dimensions do not match, torch.stack fails and raises the error above.
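For example, a hypothetical dataset like the one below (not the code from this post; the sizes 8 and 16 are chosen only to mimic the numbers in the message) reproduces the failure, because two samples in the same batch differ in their first dimension. The exact wording of the stack error varies across PyTorch versions.

import torch
from torch.utils.data import Dataset, DataLoader

class VariableSizeDataset(Dataset):
    """Hypothetical dataset whose samples have different lengths (8 vs. 16 rows)."""
    def __init__(self):
        self.lengths = [8, 16, 8, 16]

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, idx):
        # __getitem__ itself works fine; each sample is a valid dict of tensors.
        return {"points": torch.zeros(self.lengths[idx], 3)}

loader = DataLoader(VariableSizeDataset(), batch_size=2, num_workers=1)
for batch in loader:
    # default_collate calls torch.stack on samples of shape (8, 3) and (16, 3),
    # which raises the size-mismatch RuntimeError; with num_workers != 0 it is
    # re-raised as "Caught RuntimeError in DataLoader worker process 0".
    pass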
As for the other error, it appears because worker processes are enabled (num_workers != 0) and it tells you which worker failed: the batch merge hit mismatched dimensions, so the first worker (worker process 0) crashed, and the main process reports RuntimeError: Caught RuntimeError in DataLoader worker process 0.
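A quick debugging suggestion (not from the original post): running the same loader with num_workers=0 makes collation happen in the main process, so the underlying size-mismatch error is raised directly, without the worker-process wrapper.

# With no worker processes, the original stack error surfaces in the main process.
loader = DataLoader(VariableSizeDataset(), batch_size=2, num_workers=0)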
Solution
Since the dimensions are inconsistent, the fix is simply to make them consistent: pre-allocate an array or tensor that is large enough for every sample, mark the positions that are not filled, and when the data is read, use the mark to tell which entries are valid.
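Below is a minimal sketch of that idea, assuming a hypothetical field name "points" and a maximum length of 16 (this is not the author's actual code): each sample is copied into a fixed-size tensor and a boolean mask records which rows hold real data, so every sample has the same shape and default_collate can stack them.

import torch
from torch.utils.data import Dataset

MAX_LEN = 16  # assumed upper bound on the variable dimension

class PaddedDataset(Dataset):
    """Hypothetical dataset that pads every sample to a fixed size."""
    def __init__(self, raw_samples):
        # raw_samples: list of tensors of shape (n_i, 3) with n_i <= MAX_LEN
        self.raw_samples = raw_samples

    def __len__(self):
        return len(self.raw_samples)

    def __getitem__(self, idx):
        raw = self.raw_samples[idx]
        n = raw.shape[0]
        padded = torch.zeros(MAX_LEN, 3, dtype=raw.dtype)
        padded[:n] = raw                                # copy the valid data
        mask = torch.zeros(MAX_LEN, dtype=torch.bool)   # use torch.uint8 on very old PyTorch
        mask[:n] = True                                 # mark which rows are real
        return {"points": padded, "mask": mask}

Downstream code can then select the valid rows with batch["points"][batch["mask"]], or apply the mask in the loss, so the padded entries never affect the result.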