简介
上一节实现了加载配置,加载配置文件可以方便的进行参数的修改,这一节实现加载数据。
DataLoader
我使用的数据是MLT2017的数据集,在其中把法语的分割出来了,数据集下载地址:法语OCR识别数据集
其中解压后包含训练集图片文件夹、测试集图片文件夹、训练集标签文件和测试集标签文件以及字典文件。
数据可以放置在工程的data文件夹下或者你喜欢的位置,加载数据的代码自然就放在data文件夹下,命名dataset.py:
import torch.utils.data as data # 加载torch的数据加载器
import numpy as np
import time
import cv2
import sys
import os
sys.path.append(os.getcwd())
# 实现模板类
class OCRDataset(data.Dataset):
def __init__(self,config,is_train=True):
self.root = config.DATASET.ROOT
self.is_train = is_train
self.inp_h = config.MODEL.IMAGE_SIZE.H
self.inp_w = config.MODEL.IMAGE_SIZE.W
self.dataset_name = config.DATASET.DATASET
self.mean = np.array(config.DATASET.MEAN, dtype=np.float32)
self.std = np.array(config.DATASET.STD, dtype=np.float32)
char_file = config.DATASET.CHAR_FILE
txt_file = config.DATASET.JSON_FILE['train'] if is_train else config.DATASET.JSON_FILE['val']
txt_file = os.path.join(self.root,txt_file)
# convert name:indices to name:string
self.labels = []
with open(txt_file, 'r', encoding='utf-8') as file:
contents = file.readlines()
for c in contents:
imgname = c.split('\t')[0]
string = c.split('\t')[1].replace("\n","")
self.labels.append({imgname: string})
print("load {} images!".format(self.__len__()))
def __len__(self):
# 实现模板方法
return len(self.labels)
def __getitem__(self,idx):
img_name = list(self.labels[idx].keys())[0]
img = cv2.imread(os.path.join(self.root, img_name))
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img_h, img_w = img.shape
img = cv2.resize(img, (0,0), fx=self.inp_w / img_w, fy=self.inp_h / img_h, interpolation=cv2.INTER_CUBIC)
img = np.reshape(img, (self.inp_h, self.inp_w, 1))
img = img.astype(np.float32)
img = (img/255. - self.mean) / self.std
img = img.transpose([2, 0, 1])
return img, idx
这段代码看着很复杂其实很简单:
- 首先
class OCRDataset(data.Dataset):
这句话继承了pytorch中的模板类Dataset,该模板要求重载__getitem__方法和__len__方法。 - 该类实现了三个函数__init__,__getitem__方法和__len__方法,第一个函数就是初始化,把配置文件中关于数据的部分全载入了。
- __getitem__方法实现了对载入的数据进行预处理。
- __len__方法实现了获取数据的长度。
在__init__函数中最后拿到了self.labels = []
他的数据形式就是:
self.labels = [{“img.png”:“abcd”},{“img2.png”:“abcdrffff”}…]
就是把路径和标签存在了字典里,字典用列表包着。
测试
在train.py中加入测试代码:
import os
sys.path.append(os.getcwd())
import argparse
import model.model as crnn
import torch
import torch.optim as optim
from utils.utils import load_yml
from data.dataset import OCRDataset
def parse_arg():
parser = argparse.ArgumentParser(description="train crnn")
parser.add_argument('--cfg', help='experiment configuration filename', required=True, type=str)
args = parser.parse_args()
config = load_yml(args.cfg)
return config
if __name__ == "__main__":
config = parse_arg()
print(config)
train_dataset = OCRDataset(config)
train_loader = data.DataLoader(
dataset=train_dataset,
batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
shuffle=config.TRAIN.SHUFFLE,
num_workers=config.WORKERS,
pin_memory=config.PIN_MEMORY,
)
# get device
if torch.cuda.is_available():
device = torch.device("cuda:{}".format(config.GPUID))
else:
device = torch.device("cpu:0")
for i, (inp, idx) in enumerate(train_loader):
inp = inp.to(device)
print("inp",inp[0].cpu().detach().numpy(),inp[0].cpu().detach().numpy().shape)
exit(-1)# 这里就测试打印一个batch然后退出程序
输出结果:
数据加载完成接着就是搭建模型了~