CV(一)之自定义数据集

本文以 PASCAL VOC2012 数据集为例子进行说明。(下载地址:PASCAL VOC2012)

Pytorch 自定义数据集见文档:TorchVision Object Detection Finetuning Tutorial

本文将以PASCAL VOC为基础自定义一个数据集VOCDataset,并随机选取五张图片给将其对应的标注转化为矩形框画在图片上。

CV(一)之自定义数据集

生成自定义数据集

一些需要导入的基本库

import os
import torch
import json
from torch.utils.data import Dataset
from PIL import Image
from os import path
from lxml import etree

# 类别数据
class_dict = {
    "aeroplane": 1,
    "bicycle": 2,
    "bird": 3,
    "boat": 4,
    "bottle": 5,
    "bus": 6,
    "car": 7,
    "cat": 8,
    "chair": 9,
    "cow": 10,
    "diningtable": 11,
    "dog": 12,
    "horse": 13,
    "motorbike": 14,
    "person": 15,
    "pottedplant": 16,
    "sheep": 17,
    "sofa": 18,
    "train": 19,
    "tvmonitor": 20
}

按照文档要求,在VOCDataset中实现三个方法__len____getitem__、以及get_height_and_width

初始化 VOCDataset 类

构造函数定义如下

'''
voc_root: voc 数据集的根目录
year: 哪一个年份的数据集
transforms: 数据预处理
text_name: train.txt or val.txt 该txt文件在数据集的 VOCdevkit\VOC2012\ImageSets\Main 文件夹下
'''
def __init__(self, voc_root, year='2012', transforms=None, text_name='train.txt'):

在构造函数中,我们主要完成以下三个功能

  1. 设置图片路径image_root和标注路径anno_root
  2. 设置此次要训练的样本所有标注文件路径列表xml_list
  3. 设置要检测的目标类别信息class_dict
设置图片路径image_root和标注路径anno_root
        # 设置数据集、图片、标注的根目录
        self.root = path.join(voc_root, 'VOCdevkit', f'VOC{year}')
        self.image_root = path.join(self.root, 'JPEGImages')
        self.anno_root = path.join(self.root, 'Annotations')
设置此次要训练的样本所有标注文件路径列表xml_list
        # 根据 text_name 拿到对应的标注xml文件路径
        text_path = path.join(self.root, 'ImageSets','Main', text_name)
        # 读取txt文件的每一行并生成xml标注文件路径存放在xml_list中
        with open(text_path) as file_reader:
            self.xml_list = [
                path.join(self.anno_root, f'{line.strip()}.xml')
                for line in file_reader.readlines() if len(line.strip()) > 0
            ]
设置要检测的目标类别信息class_dict
        self.class_dict = class_dict

一般使用 0 来表示当前类别是背景

获取所有样例条数

    def __len__(self):
        return len(self.xml_list)

样本的条数即标注文件列表长度

根据索引获取指定样本

函数定义如下

    def __getitem__(self, idx):

传入的即为样本的索引值,其取值范围为 0 ~ len(xml_list)

获取指定样本需要分为如下两大步

  1. 获取图片
  2. 获取图片信息(标注信息、索引、区域面积等)
获取图片

首先我们需要根据索引拿到对应标注信息,并将其转化为json格式
定义一个获取json格式的annotation的方法

    def get_annotation(self, idx):
        xml_path = self.xml_list[idx]
        assert path.exists(xml_path), f'file {xml_path} not found'

        xml_reader = open(xml_path)
        xml_text = xml_reader.read()
        xml = etree.fromstring(xml_text)
        annotation = parse_xml_to_dict(xml)['annotation']

xml格式转化为json格式函数如下

def parse_xml_to_dict(xml):
    if len(xml) == 0:
        return {xml.tag: xml.text}
    
    result = {}
    for child in xml:
        child_result = parse_xml_to_dict(child)
        if child.tag != 'object':
            result[child.tag] = child_result[child.tag]
        else: # 一张图片中可能标注有多个 object
            if child.tag not in result:
                result[child.tag] = []
            result[child.tag].append(child_result[child.tag])
    
    return {xml.tag: result}

获取annotation

        annotation = self.get_annotation(idx)

然后我们就可以从annotation中拿到文件名称并获取到文件

        image_path = path.join(self.image_root, annotation['filename'])
        image = Image.open(image_path)
获取图片信息

声明需要获取的所有信息

        # 生成 target
        target = {
            'boxes': [], # 标注的左上、右下坐标(xmin, ymin, xmax, ymax)
            'labels': [],# 标注类别
            'image_id': [], # 图片索引
            'area': [], # 含有目标区域的面积 (xmax-xmin) * (ymax-ymin)
            'iscrowd': [], # 是不是一堆密集的东西在一起
        }

便利所有的object


        for obj in annotation['object']:
            bndbox = obj['bndbox']
            xmin = float(bndbox['xmin'])
            ymin = float(bndbox['ymin'])
            xmax = float(bndbox['xmax'])
            ymax = float(bndbox['ymax'])
            target['boxes'].append([xmin, ymin, xmax, ymax]) # 设置有目标的坐标信息
            target['labels'].append(self.class_dict[obj['name']]) # 获取对应的标签
            target['area'].append((xmax - xmin) * (ymax - ymin)) # 计算面积

            # 使用 difficult(当前目标是否难以识别) 字段来设置 iscrowd
            if 'difficult' in obj:
                target['iscrowd'].append(int(obj['difficult']))
            else:
                target['iscrowd'].append(0)

将所有信息转化为Tensor

        # Convert to tensor
        target['boxes'] = torch.as_tensor(target['boxes'])
        target['labels'] = torch.as_tensor(target['labels'])
        target['iscrowd'] = torch.as_tensor(target['iscrowd'])
        target['area'] = torch.as_tensor(target['area'])
        target['image_id'] = torch.tensor([idx])

如果有设置数据预处理器,则在返回数据前调用

        if self.transforms is not None:
            image = self.transforms(image)

返回图片以及对应的信息

        return image, target

根据索引获取当前图片的宽高

在标注信息里面含有图片宽高信息,所以可以很容易获取到

    def get_height_and_width(self, idx):
        annotation = annotation = self.get_annotation(idx)
        # 从 annotation 中取出宽高并返回
        width = int(annotation['size']['width'])
        height = int(annotation['size']['height'])

        return height, width

以上我们就完成了数据集的定义,下面我们将使用实例代码来使用这个数据集

使用自定义数据集并画上标注框

导入一些基本库

import random
import matplotlib.pyplot as plt
import torchvision.transforms as ts
from draw_box_utils import draw_box

生成类别数据,将 kv 互换,便于查询

category_index = {}

category_index = {
    v: k
    for k, v in class_dict.items()
}

定义transformer,将数据转化为Tensor

data_transform = ts.Compose([ts.ToTensor()])

由于ToTensor会将数据标准化,为了代码简洁,这里不使用

拿到数据集并将目标框以及类别画出来

train_data_set = VOCDataset(os.getcwd(), '2012', None, 'train.txt')

for index in random.sample(range(0, len(train_data_set)), k=5):
    image, target = train_data_set[index]
    image = draw_bounding_boxes(
        np.array(image),
        target['boxes'],
        target['labels'],
    )
    plt.imshow(image)
    plt.show()

画目标框draw_bounding_boxes代码如下(参考代码: vision/utils.py at main · pytorch/vision (github.com))


def draw_bounding_boxes(
    image,
    boxes: torch.Tensor,
    labels: Optional[List[str]] = None
) -> torch.Tensor:
    img_to_draw = Image.fromarray(image)
    img_boxes = boxes.to(torch.int64).tolist()
    draw = ImageDraw.Draw(img_to_draw)

    for i, bbox in enumerate(img_boxes):
        draw.rectangle(bbox, width=2, outline='red')
        margin = 2
        draw.text((bbox[0] + margin, bbox[1] + margin),  category_index[labels[i] - 1], fill='red')


    return np.array(img_to_draw)

这样就完成了整个流程了!

运行与测试

CV(一)之自定义数据集

可见运行结果正确!

上一篇:springboot之web项目自定义拦截器


下一篇:SpringBoot注解之@Configuration、@Bean、@Component