torch.utils.data.DataLoader之简易理解(小白进)

官方解释:Dataloader 组合了 dataset & sampler,提供在数据上的 iterable

主要参数:

1、dataset:这个dataset一定要是torch.utils.data.Dataset本身或继承自它的类

里面最主要的方法是 __getitem__(self, index) 用于根据index索引来取数据的

2、batch_size:每个batch批次要返回几条数据

3、shuffle:是否打乱数据,默认False

4、sampler:sample strategy,数据选取策略,有它就不用shuffle了,因为sample本身就是一种无序。这个sampler貌似也一定要是torch.utils.data.sampler.Sampler本身或继承自它的类。

里面最主要的方法是__iter__(self) 方法,每次调用 iter 只能获取 batchsize 个数据,也就是一个批次的数据。

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

5、…… 后面就不说了

这里先贴一段我的代码:

trainloader = DataLoader(
            ImageDataset(self.dataset.train, transform=self.transform_train),
            # 为传入的数据中的每个id选择config.k个样本
            sampler=ClassUniformlySampler(self.dataset.train, class_position=1, k=config.k),  # 传入的数据中第2维是类别,所以class_position=1
            batch_size=config.p * config.k, num_workers=config.workers,
            # shuffle=True, # 有了ClassUniformlySampler就不用shuffle了
            pin_memory=pin_memory, drop_last=False
        )

for batch_idx, (imgs, pids, _) in enumerate(trainloader):
    print("batch_idx: ", batch_idx)
    for i in range(len(pids)):
        print(pids[i], imgs[i].shape)

一开始我并不真的明白它的内部原理,为什么执行 enumerate 代码,就可以源源不断地返回所需数据,后来我跟了一下整个代码才明白(如果你也在这个地方范迷糊,可以继续往下看,如果没有,则可离开

       在执行 trainloader = DataLoader()语句的时候,DataLoder,ImageDataset,ClassUniformlySampler 并没有什么特殊的操作,都仅仅是init初始化了一下。

        这里所使用的 ClassUniformlySampler ,是Sampler类的一种,作用是对数据中的所有id仅保留k条数据。因此它在初始化时,生成了一个字典,key为类别,value为属于该类别的所有数据的索引。(这里仅作讲解使用,无需深入学习

代码取自他处,已难寻根,在此标注一下**

class ClassUniformlySampler(Sampler):
    '''
    random sample according to class label
    Arguments:
        data_source (Dataset): data_loader to sample from
        class_position (int): which one is used as class
        k (int): sample k images of each class
    '''

    def __init__(self, data_source, class_position, k):

        self.class_position = class_position
        self.k = k

        self.samples = data_source
        self.class_dict = self._tuple2dict(self.samples)  # 返回一个字典,key为类别,value为属于该类别的所有数据的索引

    def __iter__(self):
        self.sample_list = self._generate_list(self.class_dict)
        return iter(self.sample_list)

    def __len__(self):
        return len(self.sample_list)

    def _tuple2dict(self, inputs):
        '''

        :param inputs: list with tuple elemnts, [(image_path1, class_index_1), (imagespath_2, class_index_2), ...]
        :return: dict, {class_index_i: [samples_index1, samples_index2, ...]}
        '''
        dict = {}
        for index, each_input in enumerate(inputs):
            class_index = each_input[self.class_position]
            if class_index not in list(dict.keys()):
                dict[class_index] = [index]
            else:
                dict[class_index].append(index)
        return dict

    def _generate_list(self, dict):
        '''
        :param dict: dict, whose values are list
        :return:
        '''

        sample_list = []

        dict_copy = dict.copy()
        keys = list(dict_copy.keys())
        random.shuffle(keys)
        for key in keys:
            value = dict_copy[key]
            if len(value) >= self.k:
                random.shuffle(value)
                sample_list.extend(value[0: self.k])
            else:
                value = value * self.k
                random.shuffle(value)
                sample_list.extend(value[0: self.k])

        return sample_list

        在第一次执行 for batch_idx, (imgs, pids, _) in enumerate(trainloader) 时,首先调用的是sampler.__iter__() 方法,对所有数据进行采样后返回一个存储了所采样的数据的索引列表,并用iter(sampler_list) 作为返回。iter方法在一开始已经提及,每次调用只能返回 batchsize 条数据。

        随后,Dataset就上场了,它只需根据 sampler_list 中的索引挨个取数据即可,取到第 batchsize 条数据的时候,iter 就不会再让它取了。

        这之后,每一次执行 for batch_idx, (imgs, pids, _) in enumerate(trainloader) 时,Dataset 都会从上一次iter中断的数据索引处继续取 batchsize 个数据,直到取完所有数据。

注:因为在采样时,已经打乱了原有的数据顺序,对于采样后返回的sample_list,即使按顺序取,也不是真的有序,而且这样还可以防止重复抽取到相同数据,数据取完就可以结束一个epoch

上一篇:简单七步搭建Docker环境(CentOS 8)


下一篇:前端打包工具-webpack