本文以 PASCAL VOC2012 数据集为例子进行说明。(下载地址:PASCAL VOC2012)
Pytorch 自定义数据集见文档:TorchVision Object Detection Finetuning Tutorial
本文将以PASCAL VOC为基础自定义一个数据集VOCDataset
,并随机选取五张图片给将其对应的标注转化为矩形框画在图片上。
生成自定义数据集
一些需要导入的基本库
import os
import torch
import json
from torch.utils.data import Dataset
from PIL import Image
from os import path
from lxml import etree
# 类别数据
class_dict = {
"aeroplane": 1,
"bicycle": 2,
"bird": 3,
"boat": 4,
"bottle": 5,
"bus": 6,
"car": 7,
"cat": 8,
"chair": 9,
"cow": 10,
"diningtable": 11,
"dog": 12,
"horse": 13,
"motorbike": 14,
"person": 15,
"pottedplant": 16,
"sheep": 17,
"sofa": 18,
"train": 19,
"tvmonitor": 20
}
按照文档要求,在VOCDataset
中实现三个方法__len__
、__getitem__
、以及get_height_and_width
。
初始化 VOCDataset 类
构造函数定义如下
'''
voc_root: voc 数据集的根目录
year: 哪一个年份的数据集
transforms: 数据预处理
text_name: train.txt or val.txt 该txt文件在数据集的 VOCdevkit\VOC2012\ImageSets\Main 文件夹下
'''
def __init__(self, voc_root, year='2012', transforms=None, text_name='train.txt'):
在构造函数中,我们主要完成以下三个功能
- 设置图片路径
image_root
和标注路径anno_root
- 设置此次要训练的样本所有标注文件路径列表
xml_list
- 设置要检测的目标类别信息
class_dict
设置图片路径image_root
和标注路径anno_root
# 设置数据集、图片、标注的根目录
self.root = path.join(voc_root, 'VOCdevkit', f'VOC{year}')
self.image_root = path.join(self.root, 'JPEGImages')
self.anno_root = path.join(self.root, 'Annotations')
设置此次要训练的样本所有标注文件路径列表xml_list
# 根据 text_name 拿到对应的标注xml文件路径
text_path = path.join(self.root, 'ImageSets','Main', text_name)
# 读取txt文件的每一行并生成xml标注文件路径存放在xml_list中
with open(text_path) as file_reader:
self.xml_list = [
path.join(self.anno_root, f'{line.strip()}.xml')
for line in file_reader.readlines() if len(line.strip()) > 0
]
设置要检测的目标类别信息class_dict
self.class_dict = class_dict
一般使用 0 来表示当前类别是背景
获取所有样例条数
def __len__(self):
return len(self.xml_list)
样本的条数即标注文件列表长度
根据索引获取指定样本
函数定义如下
def __getitem__(self, idx):
传入的即为样本的索引值,其取值范围为 0 ~ len(xml_list)
获取指定样本需要分为如下两大步
- 获取图片
- 获取图片信息(标注信息、索引、区域面积等)
获取图片
首先我们需要根据索引拿到对应标注信息,并将其转化为json
格式
定义一个获取json
格式的annotation
的方法
def get_annotation(self, idx):
xml_path = self.xml_list[idx]
assert path.exists(xml_path), f'file {xml_path} not found'
xml_reader = open(xml_path)
xml_text = xml_reader.read()
xml = etree.fromstring(xml_text)
annotation = parse_xml_to_dict(xml)['annotation']
xml格式转化为json格式函数如下
def parse_xml_to_dict(xml):
if len(xml) == 0:
return {xml.tag: xml.text}
result = {}
for child in xml:
child_result = parse_xml_to_dict(child)
if child.tag != 'object':
result[child.tag] = child_result[child.tag]
else: # 一张图片中可能标注有多个 object
if child.tag not in result:
result[child.tag] = []
result[child.tag].append(child_result[child.tag])
return {xml.tag: result}
获取annotation
annotation = self.get_annotation(idx)
然后我们就可以从annotation
中拿到文件名称并获取到文件
image_path = path.join(self.image_root, annotation['filename'])
image = Image.open(image_path)
获取图片信息
声明需要获取的所有信息
# 生成 target
target = {
'boxes': [], # 标注的左上、右下坐标(xmin, ymin, xmax, ymax)
'labels': [],# 标注类别
'image_id': [], # 图片索引
'area': [], # 含有目标区域的面积 (xmax-xmin) * (ymax-ymin)
'iscrowd': [], # 是不是一堆密集的东西在一起
}
便利所有的object
for obj in annotation['object']:
bndbox = obj['bndbox']
xmin = float(bndbox['xmin'])
ymin = float(bndbox['ymin'])
xmax = float(bndbox['xmax'])
ymax = float(bndbox['ymax'])
target['boxes'].append([xmin, ymin, xmax, ymax]) # 设置有目标的坐标信息
target['labels'].append(self.class_dict[obj['name']]) # 获取对应的标签
target['area'].append((xmax - xmin) * (ymax - ymin)) # 计算面积
# 使用 difficult(当前目标是否难以识别) 字段来设置 iscrowd
if 'difficult' in obj:
target['iscrowd'].append(int(obj['difficult']))
else:
target['iscrowd'].append(0)
将所有信息转化为Tensor
# Convert to tensor
target['boxes'] = torch.as_tensor(target['boxes'])
target['labels'] = torch.as_tensor(target['labels'])
target['iscrowd'] = torch.as_tensor(target['iscrowd'])
target['area'] = torch.as_tensor(target['area'])
target['image_id'] = torch.tensor([idx])
如果有设置数据预处理器,则在返回数据前调用
if self.transforms is not None:
image = self.transforms(image)
返回图片以及对应的信息
return image, target
根据索引获取当前图片的宽高
在标注信息里面含有图片宽高信息,所以可以很容易获取到
def get_height_and_width(self, idx):
annotation = annotation = self.get_annotation(idx)
# 从 annotation 中取出宽高并返回
width = int(annotation['size']['width'])
height = int(annotation['size']['height'])
return height, width
以上我们就完成了数据集的定义,下面我们将使用实例代码来使用这个数据集
使用自定义数据集并画上标注框
导入一些基本库
import random
import matplotlib.pyplot as plt
import torchvision.transforms as ts
from draw_box_utils import draw_box
生成类别数据,将 k
、v
互换,便于查询
category_index = {}
category_index = {
v: k
for k, v in class_dict.items()
}
定义transformer
,将数据转化为Tensor
data_transform = ts.Compose([ts.ToTensor()])
由于ToTensor
会将数据标准化,为了代码简洁,这里不使用
拿到数据集并将目标框以及类别画出来
train_data_set = VOCDataset(os.getcwd(), '2012', None, 'train.txt')
for index in random.sample(range(0, len(train_data_set)), k=5):
image, target = train_data_set[index]
image = draw_bounding_boxes(
np.array(image),
target['boxes'],
target['labels'],
)
plt.imshow(image)
plt.show()
画目标框draw_bounding_boxes
代码如下(参考代码: vision/utils.py at main · pytorch/vision (github.com))
def draw_bounding_boxes(
image,
boxes: torch.Tensor,
labels: Optional[List[str]] = None
) -> torch.Tensor:
img_to_draw = Image.fromarray(image)
img_boxes = boxes.to(torch.int64).tolist()
draw = ImageDraw.Draw(img_to_draw)
for i, bbox in enumerate(img_boxes):
draw.rectangle(bbox, width=2, outline='red')
margin = 2
draw.text((bbox[0] + margin, bbox[1] + margin), category_index[labels[i] - 1], fill='red')
return np.array(img_to_draw)
这样就完成了整个流程了!
运行与测试
可见运行结果正确!