Abstract
数据的加载对于一个检测网络来说可以说是重中之重,在ICCV的很多论文中,提到了各种数据预处理的技巧,很幸运得是,yolov5将这些内容加入到了他的代码中,同时,代码的耦合度并不高,我们可以很轻松的将其移植到其他项目中去.
Introduction
1.数据加载主要步骤
支持的数据格式
yolov5支持直接将图片或者视频的数据转换到dataloader当中去.
文件类型包括:
‘bmp’, ‘jpg’, ‘jpeg’, ‘png’, ‘tif’, ‘tiff’, ‘dng’, ‘webp’, ‘mpo’
‘mov’, ‘avi’, ‘mp4’, ‘mpg’, ‘mpeg’, ‘m4v’, ‘wmv’, ‘mkv’
快速加载数据集标签的技巧
通常情况下,目标检测数据集都会非常大,同时,yolo采用的是txt格式保存的标签数据,那么加载数据可能会很慢,对于一个10万级别的数据集,可能花费1小时左右的时间,如果多次训练此数据集,显然是一笔很大的时间开销.因此,yolov5用来一种相当巧妙的方式,只会在第一次读取数据集的时候,将标签文件处理成一个cache文件,这样,再重新读取数据集的时候,直接搜索数据集中是否存在cache文件即可.
同时,相对于原始的txt标签文件,cache还可以保存很多的额外信息,比如说,数据集的大小,未找到标签文件对应的图片数量…
当然,细心的人会发现,如果使用一个损坏的cache文件存放到数据集中,yolov5会不会读取异常? 答案是不会,因为对于每一个数据集的cache文件中,会生成一个Hash Code,这个code是通过数据集的大小计算得到的,因此在读取cache之前会检查Hash Code是否匹配
在detect实际场景中数据加载的技巧
yolov5为我们提供了一个很完善的数据加载解决方案,它蕴含这很多的技巧,同时,核心代码数量只在1000行左右,便于使用者阅读.
对于train和test数据的加载,那么显然是需要从数据集中加载得到.yolov5提供了一个create_dataloader函数,可以在其参数中指定使用哪些加载技巧.然后只需要拿到加载数据集的迭代器就可以了
在检测阶段,如果你需要检测某一个文件夹中所有的图片,或者是从摄像头里面拉取视屏流,亦或者是从http协议中拉取视频流,yolov5都提供了一整套解决方案.相比于其他检测器的代码,你会发现,yolov5对于实际场景的检测做了很多的支持,不需要我们再单独部署模型.同时如果我们想将其部署到C++中,也可以参考其加载数据的代码技巧
数据训练的技巧
yolov5提供一个切割图片并将其重组的方法,它支持两种方式:切割为9份或者切割为4份. 默认情况下会切割四张图片进行重组, 对于切割9份的方法,从实际角度出发,如果存在图片中的目标比较小,同时数量比较少的情况,可以考虑这种方法来提高样本的平衡性.
具体做法:(以切割四份为例):
首先会准备四张图片,然后在这四张图片中分别随机切割一个区域,最后,将这四个区域重新组合成一张图片.这种做法对于数据量较少的数据集来说,可以很大程度的提高数据量,同时,这种方法还会让网络更好的适应极端的环境
datasets.py代码阅读
对于这一块代码,我们先看其加载训练数据的核心部分,它的入口是create_loader函数.
它会传入很多的超参,
包括path(数据集路径),imgsz(图片resize的大小),stride(在各个特征层下的感受野步长),opt(参数),hyp(参数),augument(是否扩大),cache(是否添加cache文件技巧)…
def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
with torch_distributed_zero_first(rank):
#LoadimagesAndLabels是其核心类,它继承在torch.utils.Datasets.Dataset
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
augment=augment, # augment images
hyp=hyp, # augmentation hyperparameters
rect=rect, # rectangular training
cache_images=cache,
single_cls=opt.single_cls,
stride=int(stride),
pad=pad,
image_weights=image_weights,
prefix=prefix)
batch_size = min(batch_size, len(dataset))
#num works,通常情况下是服务器cpu的数量,然而在windows情况下,这一参数需要被设置为0,不然会出现Broken Pipe的情况
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
# Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
#将数据集变成数据加载器
dataloader = loader(dataset,
batch_size=batch_size,
num_workers=nw,
sampler=sampler,
pin_memory=True,
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
return dataloader, dataset
我们会发现LoadImagesAndLabels是其核心类,找到代码部分
class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
self.img_size = img_size
self.augment = augment
self.hyp = hyp
self.image_weights = image_weights
self.rect = False if image_weights else rect
self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
self.mosaic_border = [-img_size // 2, -img_size // 2]
self.stride = stride
self.path = path
try:
f = [] # image files
#判断数据集的标签是文件夹还是文件
for p in path if isinstance(path, list) else [path]:
p = Path(p) # os-agnostic
if p.is_dir(): # dir
f += glob.glob(str(p / '**' / '*.*'), recursive=True)
# f = list(p.rglob('**/*.*')) # pathlib
elif p.is_file(): # file
with open(p, 'r') as t:
t = t.read().strip().splitlines()
parent = str(p.parent) + os.sep
f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
# f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
else:
raise Exception(f'{prefix}{p} does not exist')
#处理得到文件名
self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats]) # pathlib
assert self.img_files, f'{prefix}No images found'
except Exception as e:
raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')
# Check cache
#根据图片名字找到对应的labels
self.label_files = img2label_paths(self.img_files) # labels
#判断数据集是否已经存在cache文件了
cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache') # cached labels
#如果存在
if cache_path.is_file():
cache, exists = torch.load(cache_path), True # load
#比较hashcode是否正确
if cache['hash'] != get_hash(self.label_files + self.img_files) or 'version' not in cache: # changed
#如果不正确,就重新生成cache文件
cache, exists = self.cache_labels(cache_path, prefix), False # re-cache
else:
#生成cache文件
cache, exists = self.cache_labels(cache_path, prefix), False # cache
# Display cache
#cache文件中包含了很多其他信息
#nf 已经根据标签寻找到的图片
#nm 只有标签但是没有找到图片
#ne 标签为空
#corrupted 损坏的图片
#total 总共的标签文件
nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupted, total
if exists:
d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
tqdm(None, desc=prefix + d, total=n, initial=n) # display cache results
assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'
#读取cache的信息,包括hash码,cache的版本号,以及cache的标签,图片的shape,标签和图片的路径等
# Read cache
cache.pop('hash') # remove hash
cache.pop('version') # remove version
labels, shapes, self.segments = zip(*cache.values())
self.labels = list(labels)
self.shapes = np.array(shapes, dtype=np.float64)
self.img_files = list(cache.keys()) # update
self.label_files = img2label_paths(cache.keys()) # update
if single_cls:
for x in self.labels:
x[:, 0] = 0
#图片数量
n = len(shapes) # number of images
#图片数量/batch size
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
nb = bi[-1] + 1 # number of batches
self.batch = bi # batch index of image
self.n = n
self.indices = range(n)
# Rectangular Training
if self.rect:#获取图片resize的shape
# Sort by aspect ratio
s = self.shapes # wh
ar = s[:, 1] / s[:, 0] # aspect ratio
irect = ar.argsort()
self.img_files = [self.img_files[i] for i in irect]
self.label_files = [self.label_files[i] for i in irect]
self.labels = [self.labels[i] for i in irect]
self.shapes = s[irect] # wh
ar = ar[irect]
# Set training image shapes
shapes = [[1, 1]] * nb
for i in range(nb):
ari = ar[bi == i]
mini, maxi = ari.min(), ari.max()
if maxi < 1:
shapes[i] = [maxi, 1]
elif mini > 1:
shapes[i] = [1, 1 / mini]#图片的尺寸
self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride#图片回归到特征网格
# Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
#将图片加载到内存中,此处使用了多线程加载图片,可以很快的提高速度
#注:在
self.imgs = [None] * n
if cache_images:
gb = 0 # Gigabytes of cached images
self.img_hw0, self.img_hw = [None] * n, [None] * n
results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n))) # 8 threads
pbar = tqdm(enumerate(results), total=n)
for i, x in pbar:
self.imgs[i], self.img_hw0[i], self.img_hw[i] = x # img, hw_original, hw_resized = load_image(self, i)
gb += self.imgs[i].nbytes
pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
pbar.close()
def cache_labels(self, path=Path('./labels.cache'), prefix=''):
# Cache dataset labels, check images and read shapes
x = {} # dict
#初始化数据集的数量信息
nm, nf, ne, nc = 0, 0, 0, 0 # number missing, found, empty, duplicate
pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
for i, (im_file, lb_file) in enumerate(pbar):
try:
# verify images
#尝试打开图片文件
im = Image.open(im_file)
im.verify() # PIL verify
#获取图片的尺寸
shape = exif_size(im) # image size
segments = [] # instance segments
assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
assert im.format.lower() in img_formats, f'invalid image format {im.format}'
# verify labels
#将label文件加载到cache中去
if os.path.isfile(lb_file):
nf += 1 # label found
with open(lb_file, 'r') as f:
#分割label.得到完整的gt信息
l = [x.split() for x in f.read().strip().splitlines()]
if any([len(x) > 8 for x in l]): # is segment
classes = np.array([x[0] for x in l], dtype=np.float32)
segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l] # (cls, xy1...)
l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
l = np.array(l, dtype=np.float32)
if len(l):
#查看label文件的格式是否符合要求
assert l.shape[1] == 5, 'labels require 5 columns each'
assert (l >= 0).all(), 'negative labels'
assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
else:
ne += 1 # label empty
l = np.zeros((0, 5), dtype=np.float32)
else:
nm += 1 # label missing
l = np.zeros((0, 5), dtype=np.float32)
x[im_file] = [l, shape, segments]
except Exception as e:
nc += 1
print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
pbar.close()
if nf == 0:
print(f'{prefix}WARNING: No labels found in {path}. See {help_url}')
x['hash'] = get_hash(self.label_files + self.img_files)
x['results'] = nf, nm, ne, nc, i + 1
x['version'] = 0.1 # cache version
torch.save(x, path) # save for next time
logging.info(f'{prefix}New cache created: {path}')
return x
def __len__(self):
return len(self.img_files)
# def __iter__(self):
# self.count = -1
# print('ran dataset iter')
# #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
# return self
def __getitem__(self, index):
index = self.indices[index] # linear, shuffled, or image_weights
hyp = self.hyp
mosaic = self.mosaic and random.random() < hyp['mosaic']
#对图片进行切割
if mosaic:
# Load mosaic
img, labels = load_mosaic(self, index)
shapes = None
# MixUp https://arxiv.org/pdf/1710.09412.pdf
if random.random() < hyp['mixup']:
img2, labels2 = load_mosaic(self, random.randint(0, self.n - 1))
r = np.random.beta(8.0, 8.0) # mixup ratio, alpha=beta=8.0
img = (img * r + img2 * (1 - r)).astype(np.uint8)
labels = np.concatenate((labels, labels2), 0)
else:
# Load image
img, (h0, w0), (h, w) = load_image(self, index)
# Letterbox
#对图片以及标签的size进行处理
shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape
img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
labels = self.labels[index].copy()
if labels.size: # normalized xywh to pixel xyxy format
labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
if self.augment:
# Augment imagespace
if not mosaic:
img, labels = random_perspective(img, labels,
degrees=hyp['degrees'],
translate=hyp['translate'],
scale=hyp['scale'],
shear=hyp['shear'],
perspective=hyp['perspective'])
# Augment colorspace
augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
# Apply cutouts
# if random.random() < 0.9:
# labels = cutout(img, labels)
nL = len(labels) # number of labels
if nL:
labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) # convert xyxy to xywh
labels[:, [2, 4]] /= img.shape[0] # normalized height 0-1
labels[:, [1, 3]] /= img.shape[1] # normalized width 0-1
if self.augment:
# flip up-down
if random.random() < hyp['flipud']:
img = np.flipud(img)
if nL:
labels[:, 2] = 1 - labels[:, 2]
# flip left-right
if random.random() < hyp['fliplr']:
img = np.fliplr(img)
if nL:
labels[:, 1] = 1 - labels[:, 1]
labels_out = torch.zeros((nL, 6))
if nL:
labels_out[:, 1:] = torch.from_numpy(labels)
# Convert
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
img = np.ascontiguousarray(img)
return torch.from_numpy(img), labels_out, self.img_files[index], shapes
@staticmethod
def collate_fn(batch):
img, label, path, shapes = zip(*batch) # transposed
for i, l in enumerate(label):
l[:, 0] = i # add target image index for build_targets()
return torch.stack(img, 0), torch.cat(label, 0), path, shapes
@staticmethod
#将图像切割成四份重新组合在一起
def collate_fn4(batch):
img, label, path, shapes = zip(*batch) # transposed
n = len(shapes) // 4
img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
ho = torch.tensor([[0., 0, 0, 1, 0, 0]])
wo = torch.tensor([[0., 0, 1, 0, 0, 0]])
s = torch.tensor([[1, 1, .5, .5, .5, .5]]) # scale
for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
i *= 4
if random.random() < 0.5:
im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2., mode='bilinear', align_corners=False)[
0].type(img[i].type())
l = label[i]
else:
im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
img4.append(im)
label4.append(l)
for i, l in enumerate(label4):
l[:, 0] = i # add target image index for build_targets()
return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4