YOLOv5 Series (4) - The dataloader Module

Abstract

Data loading is arguably one of the most important parts of a detection network. Many ICCV papers describe all kinds of data pre-processing tricks, and fortunately YOLOv5 has folded these into its code. At the same time, the code is loosely coupled, so it can easily be ported to other projects.

Introduction

1. Main steps of data loading

Supported data formats
YOLOv5 can feed images or videos directly into its dataloader.
Supported file types include:
Images: 'bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo'
Videos: 'mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv'
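In datasets.py these extensions are kept in two lists, and the loaders filter files by suffix; roughly, following the v4.0-era code:

img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']  # image suffixes
vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv']           # video suffixes

# e.g. LoadImages splits a file list into images and videos by suffix:
images = [x for x in files if x.split('.')[-1].lower() in img_formats]
videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]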

A trick for loading dataset labels quickly
Object-detection datasets are usually very large, and YOLO stores its labels as plain txt files, so loading them can be slow: for a dataset on the order of 100,000 images it may take about an hour, which is a large overhead if you train on the same dataset repeatedly. YOLOv5 therefore uses a rather clever scheme: only the first time a dataset is read are the label files processed into a cache file; on every later read it simply checks whether the cache file already exists alongside the dataset.
Besides the raw txt labels, the cache also stores a lot of extra information, such as the dataset size and the number of images with no matching label file.
A careful reader may wonder: if a corrupted cache file is placed in the dataset, will YOLOv5 read it blindly? It will not. Each cache file stores a hash code computed from the sizes of the dataset's image and label files, and this hash is checked before the cache is used; if it no longer matches, the cache is rebuilt.
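A minimal sketch of that check, following the v4.0-era helpers (get_hash really is just the sum of the file sizes):

def get_hash(files):
    # returns a single hash value for a list of files: the sum of their sizes in bytes
    return sum(os.path.getsize(f) for f in files if os.path.isfile(f))

# before trusting a cache, LoadImagesAndLabels compares the stored hash with a fresh one:
# if cache['hash'] != get_hash(self.label_files + self.img_files): the cache is rebuilt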

Data-loading tricks for real detection scenarios
YOLOv5 ships with a fairly complete data-loading solution. It contains many tricks, yet the core code is only about 1,000 lines, which makes it easy to read.
Training and test data obviously have to be loaded from a dataset. YOLOv5 provides a create_dataloader function whose arguments let you choose which loading tricks to enable; after that you only need the iterator it returns over the dataset.
At detection time, whether you want to run on every image in a folder, pull a video stream from a camera, or pull a stream over HTTP, YOLOv5 offers a ready-made solution for each case. Compared with other detectors' code, YOLOv5 has much better support for real deployment scenarios, so we do not need to write extra loading code ourselves; and if we later want to port the model to C++, its data-loading code is a useful reference.
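A rough sketch of how detect.py picks a loader (the names follow the v4.0-era script and are meant as an illustration, not the exact code):

# the source string decides which loader is used
webcam = source.isnumeric() or source.lower().startswith(('rtsp://', 'rtmp://', 'http'))
if webcam:
    dataset = LoadStreams(source, img_size=imgsz, stride=stride)   # webcam / network stream
else:
    dataset = LoadImages(source, img_size=imgsz, stride=stride)    # folder, single image or video file

for path, img, im0s, vid_cap in dataset:   # img is the letterboxed CHW array, im0s the original frame
    ...  # run inference on img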

Training-data tricks
YOLOv5 provides a mosaic method that crops several images and recombines them into one. It supports two variants, 4-image and 9-image; by default four images are combined. The 9-image variant is worth considering when the objects in your images are small and few, since it can improve sample balance.
Concretely (taking the 4-image case as an example):
Four images are picked, a region is randomly cropped from each, and the four regions are stitched back together into a single new image. For small datasets this greatly increases the effective amount of data, and it also helps the network cope better with extreme scenes.
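The idea can be sketched in a few lines of NumPy (a toy illustration only; the real load_mosaic in datasets.py also remaps the labels and applies random_perspective afterwards):

import random
import numpy as np

def toy_mosaic(imgs, s=640):
    # imgs: a list of four HxWx3 uint8 images; returns one (2s x 2s) mosaic
    canvas = np.full((2 * s, 2 * s, 3), 114, dtype=np.uint8)          # gray background
    xc, yc = (random.randint(s // 2, 3 * s // 2) for _ in range(2))   # random mosaic center
    quadrants = [(0, 0, xc, yc), (xc, 0, 2 * s, yc), (0, yc, xc, 2 * s), (xc, yc, 2 * s, 2 * s)]
    for img, (x1, y1, x2, y2) in zip(imgs, quadrants):
        h = min(y2 - y1, img.shape[0])
        w = min(x2 - x1, img.shape[1])
        canvas[y1:y1 + h, x1:x1 + w] = img[:h, :w]                    # paste what fits into each quadrant
    return canvas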

Reading the datasets.py code

For this file, let us first look at the core part that loads training data; its entry point is the create_dataloader function.
It takes quite a few arguments,
including path (the dataset path), imgsz (the size images are resized to), stride (the downsampling stride of the detection feature maps), opt (command-line options), hyp (hyperparameters), augment (whether to apply data augmentation), cache (whether to cache images), and so on.

def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
                      rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
    # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
    with torch_distributed_zero_first(rank):
        # LoadImagesAndLabels is the core class; it inherits from torch.utils.data.Dataset
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,  # augmentation hyperparameters
                                      rect=rect,  # rectangular training
                                      cache_images=cache,
                                      single_cls=opt.single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      image_weights=image_weights,
                                      prefix=prefix)
    batch_size = min(batch_size, len(dataset))
    # num_workers: usually the number of CPU cores on the machine; on Windows it should be
    # set to 0, otherwise you may run into "Broken pipe" errors
    nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
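    # e.g. 16 CPU cores, world_size=1, batch_size=16, workers=8 -> nw = min(16, 16, 8) = 8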
    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
    loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
    # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
    # wrap the dataset in a dataloader
    dataloader = loader(dataset,
                        batch_size=batch_size,
                        num_workers=nw,
                        sampler=sampler,
                        pin_memory=True,
                        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
    return dataloader, dataset
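A typical call from train.py looks roughly like this (the variable names train_path, gs, opt, hyp and rank mirror train.py and are placeholders here; gs is the model's maximum stride):

dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                        hyp=hyp, augment=True, cache=opt.cache_images,
                                        rect=opt.rect, rank=rank, world_size=opt.world_size,
                                        workers=opt.workers, image_weights=opt.image_weights,
                                        quad=opt.quad, prefix='train: ')
# each batch: imgs is a uint8 BCHW tensor, targets has shape (num_labels, 6)
for imgs, targets, paths, shapes in dataloader:
    imgs = imgs.float() / 255.0   # uint8 0-255 -> float 0.0-1.0
    ...  # forward pass, loss, backward pass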

We can see that LoadImagesAndLabels is the core class, so let us look at its code.

class LoadImagesAndLabels(Dataset):  # for training/testing
    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride
        self.path = path

        try:
            f = []  # image files
            # the dataset path may be a directory or a file (or a list of either)
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # f = list(p.rglob('**/*.*'))  # pathlib
                elif p.is_file():  # file
                    with open(p, 'r') as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                        # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                else:
                    raise Exception(f'{prefix}{p} does not exist')
            # keep only the files whose extension is a supported image format
            self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats])  # pathlib
            assert self.img_files, f'{prefix}No images found'
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')

        # Check cache
        # derive the label file paths from the image paths
        self.label_files = img2label_paths(self.img_files)  # labels
        # check whether a cache file already exists for this dataset
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')  # cached labels
        # if it exists
        if cache_path.is_file():
            cache, exists = torch.load(cache_path), True  # load
            # check whether the stored hash code still matches the current dataset
            if cache['hash'] != get_hash(self.label_files + self.img_files) or 'version' not in cache:  # changed
                # if it does not, rebuild the cache file
                cache, exists = self.cache_labels(cache_path, prefix), False  # re-cache
        else:
            # build the cache file
            cache, exists = self.cache_labels(cache_path, prefix), False  # cache

        # Display cache
        # the cache also stores some summary statistics:
        # nf: images whose label file was found
        # nm: images whose label file is missing
        # ne: images whose label file is empty
        # nc: corrupted images/labels
        # n : total number of images scanned
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupted, total
        if exists:
            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
            tqdm(None, desc=prefix + d, total=n, initial=n)  # display cache results
        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'
		
        # Read cache: hash code, cache version, labels, image shapes, and image/label paths
        cache.pop('hash')  # remove hash
        cache.pop('version')  # remove version
        labels, shapes, self.segments = zip(*cache.values())
        self.labels = list(labels)
        self.shapes = np.array(shapes, dtype=np.float64)
        self.img_files = list(cache.keys())  # update
        self.label_files = img2label_paths(cache.keys())  # update
        if single_cls:
            for x in self.labels:
                x[:, 0] = 0
        n = len(shapes)  # number of images
        # assign each image a batch index: floor(image_index / batch_size)
        bi = np.floor(np.arange(n) / batch_size).astype(np.int)  # batch index
        nb = bi[-1] + 1  # number of batches
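        # e.g. n=10, batch_size=4 -> bi = [0,0,0,0,1,1,1,1,2,2] and nb = 3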
        self.batch = bi  # batch index of image
        self.n = n
        self.indices = range(n)

        # Rectangular Training
        if self.rect:  # compute a per-batch resize shape instead of a fixed square
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]  # relative (h, w) shape for this batch

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride  # round each batch shape up to a multiple of the stride
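            # e.g. img_size=640, stride=32, pad=0: a batch whose largest h/w ratio is 0.6
            # gets batch shape ceil([0.6, 1] * 640 / 32) * 32 = [384, 640] instead of [640, 640]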

        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        # load the images into RAM with a multi-threaded pool, which speeds up training considerably
        self.imgs = [None] * n
        if cache_images:
            gb = 0  # Gigabytes of cached images
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))  # 8 threads
            pbar = tqdm(enumerate(results), total=n)
            for i, x in pbar:
                self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # img, hw_original, hw_resized = load_image(self, i)
                gb += self.imgs[i].nbytes
                pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
            pbar.close()

    def cache_labels(self, path=Path('./labels.cache'), prefix=''):
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        # counters for the dataset statistics
        nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, corrupted
        pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
        for i, (im_file, lb_file) in enumerate(pbar):
            try:
                # verify images
                # try to open the image file
                im = Image.open(im_file)
                im.verify()  # PIL verify
                # get the image size, corrected for EXIF orientation
                shape = exif_size(im)  # image size
                segments = []  # instance segments
                assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
                assert im.format.lower() in img_formats, f'invalid image format {im.format}'

                # verify labels
                # load the label file into the cache
                if os.path.isfile(lb_file):
                    nf += 1  # label found
                    with open(lb_file, 'r') as f:
                        # split each line to get the full ground-truth info
                        l = [x.split() for x in f.read().strip().splitlines()]
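                        # each line of a YOLO txt label file is: class x_center y_center width height
                        # (all values normalized to 0-1), e.g. "0 0.481 0.634 0.122 0.250"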
                        if any([len(x) > 8 for x in l]):  # is segment
                            classes = np.array([x[0] for x in l], dtype=np.float32)
                            segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l]  # (cls, xy1...)
                            l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                        l = np.array(l, dtype=np.float32)
                    if len(l):
                        # check that the label file has the expected format
                        assert l.shape[1] == 5, 'labels require 5 columns each'
                        assert (l >= 0).all(), 'negative labels'
                        assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
                        assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
                    else:
                        ne += 1  # label empty
                        l = np.zeros((0, 5), dtype=np.float32)
                else:
                    nm += 1  # label missing
                    l = np.zeros((0, 5), dtype=np.float32)
                x[im_file] = [l, shape, segments]
            except Exception as e:
                nc += 1
                print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')

            pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
                        f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
        pbar.close()

        if nf == 0:
            print(f'{prefix}WARNING: No labels found in {path}. See {help_url}')

        x['hash'] = get_hash(self.label_files + self.img_files)
        x['results'] = nf, nm, ne, nc, i + 1
        x['version'] = 0.1  # cache version
        
        torch.save(x, path)  # save for next time
        logging.info(f'{prefix}New cache created: {path}')
        return x

    def __len__(self):
        return len(self.img_files)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    def __getitem__(self, index):
        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp['mosaic']
        # mosaic: combine four images (and their labels) into one
        if mosaic:
            # Load mosaic
            img, labels = load_mosaic(self, index)
            shapes = None

            # MixUp https://arxiv.org/pdf/1710.09412.pdf
            if random.random() < hyp['mixup']:
                img2, labels2 = load_mosaic(self, random.randint(0, self.n - 1))
                r = np.random.beta(8.0, 8.0)  # mixup ratio, alpha=beta=8.0
                img = (img * r + img2 * (1 - r)).astype(np.uint8)
                labels = np.concatenate((labels, labels2), 0)

        else:
            # Load image
            img, (h0, w0), (h, w) = load_image(self, index)

            # Letterbox
            # letterbox: resize with unchanged aspect ratio and pad to the target shape; the labels are adjusted below
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[index].copy()
            if labels.size:  # normalized xywh to pixel xyxy format
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

        if self.augment:
            # Augment imagespace
            if not mosaic:
                img, labels = random_perspective(img, labels,
                                                 degrees=hyp['degrees'],
                                                 translate=hyp['translate'],
                                                 scale=hyp['scale'],
                                                 shear=hyp['shear'],
                                                 perspective=hyp['perspective'])

            # Augment colorspace
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Apply cutouts
            # if random.random() < 0.9:
            #     labels = cutout(img, labels)

        nL = len(labels)  # number of labels
        
        if nL:
            labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])  # convert xyxy to xywh
            labels[:, [2, 4]] /= img.shape[0]  # normalized height 0-1
            labels[:, [1, 3]] /= img.shape[1]  # normalized width 0-1
       

        if self.augment:
       
            # flip up-down
            if random.random() < hyp['flipud']:
                img = np.flipud(img)
                if nL:
                    labels[:, 2] = 1 - labels[:, 2]

            # flip left-right
            if random.random() < hyp['fliplr']:
                img = np.fliplr(img)
                if nL:
                    labels[:, 1] = 1 - labels[:, 1]

        labels_out = torch.zeros((nL, 6))
        

        if nL:
            labels_out[:, 1:] = torch.from_numpy(labels)
      
        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.img_files[index], shapes

    @staticmethod
    def collate_fn(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes

    @staticmethod
    # combine every four images of the batch into a single sample (used with the --quad option)
    def collate_fn4(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        n = len(shapes) // 4
        img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]

        ho = torch.tensor([[0., 0, 0, 1, 0, 0]])
        wo = torch.tensor([[0., 0, 1, 0, 0, 0]])
        s = torch.tensor([[1, 1, .5, .5, .5, .5]])  # scale
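        # ho/wo shift the normalized y/x centers down/right by one image, and s halves all
        # xywh values, so labels remain correct after four images are tiled into a 2x2 grid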
        for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
            i *= 4
            if random.random() < 0.5:
                im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2., mode='bilinear', align_corners=False)[
                    0].type(img[i].type())
                l = label[i]
            else:
                im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
                l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
            img4.append(im)
            label4.append(l)

        for i, l in enumerate(label4):
            l[:, 0] = i  # add target image index for build_targets()

        return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4
