torch.utils.data.DataLoader
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
官方文档的链接
它是PyTorch中数据读取的一个重要接口,该接口的目的:将自定义的Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练。
DataLoader的所有参数如上所示,接下来依次对每个参数介绍
Dataset Types
DataLoader构造函数最重要的参数是dataset,它表示要从中加载数据的dataset对象。PyTorch支持两种不同类型的数据集:
- map-style datasets
这种类型的数据集可以实现__getitem__() and __len__()
方法,并且表示为从indices/keys到数据样本的映射。例如,当使用dataset[idx]访问这样的数据集时,可以从磁盘上的文件夹中读取idx-th图像及其对应的标签。 - iterable-style datasets.
该类数据集是IterableDataset子类的一个实例,可以实现__iter__()
方法,表示对数据样本进行迭代。这种类型的数据集特别适合这样的情况,即随机读取非常昂贵,甚至是不可能的,并且batch size取决于获取的数据。例如,当iter(dataset)
时,可以返回从数据库、远程服务器甚至实时生成的日志读取的数据流。
batch_size (python:int, optional)
每批加载多少个样本(默认值:1)。
shuffle (bool, optional)
设置为True,以便在每个epoch重新洗牌数据(默认为False)。
sampler (Sampler, optional)
定义从数据集提取样本的策略。如果指定,则shuffle必须为False。
batch_sampler (Sampler, optional)
类似于sampler,但一次返回一批索引。与batch_size, shuffle, sampler, and drop_last.相互排斥
num_workers (python:int, optional)
要使用多少子进程来加载数据。0表示将在主进程中加载数据。(默认值:0)。
collate_fn (callable, optional)
将一个样本列表合并成一个张量的小批量。当使用批量加载从map样式的数据集时使用。
pin_memory (bool, optional)
如果是,数据加载器将把张量复制到CUDA固定内存中,然后再返回它们。
drop_last (bool, optional)
如果数据集大小不能被批处理大小整除,则将其设置为True以删除最后一个未完成的批处理。如果为False且数据集的大小不能被批处理大小整除,则最后一批数据将更小。(默认值:False)
timeout (numeric, optional)
超时,默认为0。是用来设置数据读取的超时时间的,超过这个时间还没读取到数据的话就会报错。 所以,数值必须大于等于0。