collate_fn 参数
当继承Dataset类自定义类时,__getitem__方法一般返回一组类似于(image,label)的一个样本,在创建DataLoader类的对象时,collate_fn函数会将batch_size个样本整理成一个batch样本,便于批量训练。
default_collate(batch)中的参数就是这里的 [self.dataset[i] for i in indices],indices是从所有样本的索引中选取的batch_size个索引,表示本次批量获取这些样本进行训练。self.dataset[i]就是自定义Dataset子类中__getitem__返回的结果。默认的函数default_collate(batch) 只能对大小相同image的batch_size个image整理,如[(img0, label0), (img1, label1),(img2, label2), ] 整理成([img0,img1,img2,], [label0,label1,label2,]), 这里要求多个img的size相同。所以在我们的图像大小不同时,需要自定义函数callate_fn来将batch个图像整理成统一大小的,若读取的数据有(img, box, label)这种你也需要自定义,因为默认只能处理(img,label)。当然你可以提前将数据集全部整理成统一大小的。
以下是文字识别时,文本行图像长度不一,需要自定义整理。
class AlignCollate(object):
"""将数据整理成batch"""
def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False):
self.imgH = imgH
self.imgW = imgW
self.keep_ratio_with_pad = keep_ratio_with_pad
def __call__(self, batch):# 有可能__getitem__返回的图像是None, 所以需要过滤掉
batch = filter(lambda x: x is not None, batch)
images, labels = zip(*batch)
if self.keep_ratio_with_pad: # same concept with 'Rosetta' paper
resized_max_w = self.imgW
input_channel = 3 if images[0].mode == 'RGB' else 1
transform = NormalizePAD((input_channel, self.imgH, resized_max_w))
resized_images = []
for image in images:
w, h = image.size
ratio = w / float(h)
# 图片的宽度大于设定的输入ingW
if math.ceil(self.imgH * ratio) > self.imgW:
resized_w = self.imgW
else:
resized_w = math.ceil(self.imgH * ratio)
resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
resized_images.append(transform(resized_image))
# resized_image.save('./image_test/%d_test.jpg' % w)
image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)
else:
transform = ResizeNormalize((self.imgW, self.imgH))
image_tensors = [transform(image) for image in images]
image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0)
return image_tensors, labels
再看个在做目标检测时自定义collate_fn函数,给每个图像添加索引
def collate_fn(self, batch):
paths, imgs, targets = list(zip(*batch))
# Remove empty placeholder targets
# 有可能__getitem__返回的图像是None, 所以需要过滤掉
targets = [boxes for boxes in targets if boxes is not None]
# Add sample index to targets
# boxes是每张图像上的目标框,但是每个图片上目标框数量不一样呢,所以需要给这些框添加上索引,对应到是哪个图像上的框。
for i, boxes in enumerate(targets):
boxes[:, 0] = i
targets = torch.cat(targets, 0)
# Selects new image size every tenth batch
if self.multiscale and self.batch_count % 10 == 0:
self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32))
# Resize images to input shape
# 每个图像大小不同呢,所以resize到统一大小
imgs = torch.stack([resize(img, self.img_size) for img in imgs])
self.batch_count += 1
return paths, imgs, targets
其实也可以自定义collate_fn同时,结合使用默认的default_collate
from torch.utils.data.dataloader import default_collate # 导入这个函数
def collate_fn(batch):
"""
params:
batch :是一个列表,列表的长度是 batch_size
列表的每一个元素是 (x,y) 这样的元组tuple,元祖的两个元素分别是x,y
大致的格式如下 [(x1,y1),(x2,y2),(x3,y3)...(xn,yn)]
returns:
整理之后的新的batch
"""
# 这一部分是对 batch 进行重新 “校对、整理”的代码
return default_collate(batch) #返回校对之后的batch,一般就直接推荐使用default_collate进行包装,因为它里面有很多功能,比如将numpy转化成tensor等操作,这是必须的。
Reference:
一文读懂Dataset, DataLoader及collate_fn, Sampler等参数