python tips003 ——DataLoader的collate_fn参数使用详解

背景

最近在看sentences-transformers的源码,在有一个模块发现了dataloader.collate_fn,当时没搞懂是什么意思,后来查了一下,感觉还是很有意思的,因此来分享一下。

dataloader

dataloader肯定都是知道的,就是为数据提供一个迭代器。

基本工作机制:

在dataloader按照batch进行取数据的时候, 是取出大小等同于batch size的index列表,然后将列表列表中的index输入到dataset的getitem()函数中,取出该index对应的数据,最后, 对每个index对应的数据进行堆叠,就形成了一个batch的数据。

完整参数列表

DataLoader完整的参数表如下:

class torch.utils.data.DataLoader(
 dataset,
 batch_size=1,
 shuffle=False,
 sampler=None,
 batch_sampler=None,
 num_workers=0,
 collate_fn=<function default_collate>,
 pin_memory=False,
 drop_last=False,
 timeout=0,
 worker_init_fn=None)
  1. shuffle:设置为True的时候,每个世代都会打乱数据集。
  2. collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能。
  3. drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留。

collate_fn作用

在最后一步堆叠的时候可能会出现问题: 如果一条数据中所含有的每个数据元的长度不同, 那么将无法进行堆叠. 如: multi-hot类型的数据, 序列数据。在使用这些数据时, 通常需要先进行长度上的补齐, 再进行堆叠. 以现在的流程, 是没有办法加入该操作的。此外, 某些优化方法是要对一个batch的数据进行操作。

collate_fn函数就是手动将抽取出的样本堆叠起来的函数。

案例说明

import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

test = np.arange(11)
input = torch.tensor(np.array([test[i:(i + 3)] for i in range(10 - 1)]))
target = torch.tensor(np.array([test[i:(i + 1)] for i in range(10 - 1)]))

torch_dataset = TensorDataset(input, target)
batch = 3


#> input data shape: torch.Size([9, 3])
#> target data shape: torch.Size([9, 1])

需要注意的是上面的input数据shape为(9, 3);target数据shape为(9,1)。我们设置每一次的batch为3

1. 不设置collate_fn参数

my_dataloader = DataLoader(
    dataset=torch_dataset,
    batch_size=batch
)
for (i, j) in my_dataloader:
    print('*' * 30)
    print(i)
    print(j)

查看上面的结果就可以看到每一批都返回两个结果,一个是input的样本,一个是target的样本。
python tips003 ——DataLoader的collate_fn参数使用详解
input样本、target样本的维度和原始保持一致,但是大小尺寸全部为batch。

2. 设置collate_fn参数为lambda x: x


my_dataloader = DataLoader(
    dataset=torch_dataset,
    batch_size=4,
    collate_fn=lambda x: x
)
for i in my_dataloader:
    print('*' * 30)
    print(i)

python tips003 ——DataLoader的collate_fn参数使用详解

这个时候每一批都是返回了一个列表,这个列表的大小为3,列表里面的每一个对象就是一个成对的input和target。

如果我们继续想把上面的列表解析成第一个的情况,我们可以这么做:

a = i
list((torch.cat([a[i][j].unsqueeze(0) for i in range(len(a))]).unsqueeze(0) for j in range(len(a[0]))))

python tips003 ——DataLoader的collate_fn参数使用详解

上面其实是很哇塞的,他其实是什么意思,就是把输出的长度为batch的列表转换为一个矩阵了。看着是挺复杂的,其实就是对list做了数据抽取和合并。非常简单。大概可以有这么个拆解路线:

my_dataloader = DataLoader(
    dataset=torch_dataset,
    batch_size=4,
    collate_fn=lambda x: x,
    drop_last=True
)
for i in my_dataloader:
    print('*' * 30)
    print(i)
  
a = i
a

然后,查看视频:

视频演示

3. 自定义collate_fn参数

现在结合上面的步骤,我们自定义自己的参数,然后实现默认的效果。大概代码如下:

my_dataloader = DataLoader(
    dataset=torch_dataset,
    batch_size=batch,
    collate_fn=lambda x:(
        torch.cat([x[i][j].unsqueeze(0) for i in range(len(x))],dim=0) for j in range(len(x[0]))
    )
)
for i,j in my_dataloader:
    print('*' * 30)
    print(i)
    print(j)

python tips003 ——DataLoader的collate_fn参数使用详解

最后

  1. 后面会逐渐关于python更加冷门的东西,也会写pytorch的更多小细节。主要是用来记录自己的学习过程。将中间的一些比较复杂的东西给简单化。

参考链接

  1. https://blog.csdn.net/weixin_42028364/article/details/81675021
  2. https://zhuanlan.zhihu.com/p/361830892
  3. https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py
上一篇:Pytorch_DataLoader涉及内容


下一篇:mysql表名字段名批量改小写