需求
无论是tensorflow,还是keras,抑或pytorch的torchvision提供的datasets库,都无法提供足够灵活、足够简洁的Dataset类。
因此,我打算自己写一个简单的易于扩展的单节点数据集工具类。
分析
- 我想要的接口就是传入一个整理好的图片的层级目录所在的路径,以及我想要训练的总epoch数,我就能像使用迭代器一样从这个数据集实例中连续不断地获取样本。
- 我可以指定每次迭代器返回的样本的批大小(batch size);
- 我可以在创建数据集实例的时候通过flag控制生成的实例的随机性;
- 当dataset实例迭代完所有epoch的所有batch后,应当发出信号告知我训练结束了;
好,开干!
代码
import glob
try:
    from collections.abc import Iterator
except ImportError:  # interpreters that predate collections.abc
    from collections import Iterator

import cv2
import numpy as np
import torch
import torchvision.datasets as D
import torchvision.transforms as T
class PairedGraySet(Iterator):
    """In-memory batch iterator over paired single-channel images.

    The constructor globs ``root_path`` for ``*/*/*_siam_main.png`` and
    ``*/*/*_siam_aux.png``, sorts both lists and pairs them by position
    (only the file *counts* are checked to match), resizes every image to
    ``max_image_size`` and keeps all of them in memory.

    Each ``next()`` call returns ``(epoch, iteration, main, aux)`` where
    ``main`` and ``aux`` are ``float32`` arrays of shape
    ``(batch_size, 1, h, w)``.

    Notes:
        * The iterator never raises ``StopIteration``.  Once ``num_epoc``
          epochs have been served, every call prints a message and returns
          ``(epoch, iteration, None, None)``; callers must test for ``None``
          and break out of their loop themselves.
        * The returned arrays are preallocated buffers that are overwritten
          by the following ``next()`` call; copy them if they must outlive it.
        * Samples beyond ``(num_samples // batch_size) * batch_size`` in the
          current index sequence are not served in that epoch.
    """

    def __init__(self, root_path, num_epoc, batch_size, shuffle=True,
                 max_image_size=(640, 480)):
        """Load every image pair under ``root_path`` into memory.

        Args:
            root_path: Directory that holds the ``*/*/*_siam_{main,aux}.png``
                hierarchy.
            num_epoc: Number of epochs to serve before returning ``None``
                batches.
            batch_size: Number of pairs per batch; the dataset must contain
                strictly more pairs than this.
            shuffle: When true, a fresh random permutation of the samples is
                drawn for every epoch; otherwise samples come in sorted order.
            max_image_size: Target size handed to ``cv2.resize`` as ``dsize``
                (i.e. ``(width, height)``).  Every image is resized to
                exactly this size.

        Raises:
            AssertionError: If the main/aux file counts differ, there are not
                more pairs than ``batch_size``, an image is not 2-D, or a
                pair has mismatching shapes.
            IOError: If an image file cannot be read or decoded.
        """
        self.batch_size = batch_size
        self.num_epoc = num_epoc
        self.shuffle = shuffle
        self.epoc = 0
        self.iter = 0
        self.im1 = None
        self.im2 = None

        # Collect the two file lists; sorting makes the positional pairing
        # deterministic.
        files_main = sorted(glob.glob(root_path + '/*/*/*_siam_main.png'))
        files_aux = sorted(glob.glob(root_path + '/*/*/*_siam_aux.png'))
        assert len(files_aux) == len(files_main), (
            'found %d main images but %d aux images'
            % (len(files_main), len(files_aux)))
        assert len(files_main) > batch_size, (
            'need more than batch_size=%d image pairs, found %d'
            % (batch_size, len(files_main)))

        # Load everything into memory up front.
        self.images_main = []
        self.images_aux = []
        for path_main, path_aux in zip(files_main, files_aux):
            main_ = self._read_gray(path_main)
            aux_ = self._read_gray(path_aux)
            assert main_.shape == aux_.shape, (
                'shape mismatch between %s and %s' % (path_main, path_aux))
            # The interpolation flag must be passed by keyword: the third
            # positional parameter of cv2.resize is ``dst``.
            self.images_main.append(
                cv2.resize(main_, max_image_size,
                           interpolation=cv2.INTER_LINEAR))
            self.images_aux.append(
                cv2.resize(aux_, max_image_size,
                           interpolation=cv2.INTER_LINEAR))

        # Number of full batches that make up one epoch.
        self.num_samples = len(files_main)
        self.max_iter = self.num_samples // self.batch_size

        # Index sequence for the first epoch.
        self.ind_seq = self._new_index_sequence()

        # Preallocate the batch buffers once and reuse them for speed.
        h, w = self.images_main[0].shape
        self.im1 = np.zeros([self.batch_size, 1, h, w], np.float32)
        self.im2 = np.zeros([self.batch_size, 1, h, w], np.float32)

    @staticmethod
    def _read_gray(path):
        """Read ``path`` unchanged and check that it decodes to a 2-D array."""
        image = cv2.imread(path, -1)  # -1: keep the stored depth/channels
        if image is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise IOError('failed to read image: %s' % path)
        assert len(image.shape) == 2, 'not a single-channel image: %s' % path
        return image

    def _new_index_sequence(self):
        """Return the sample order for one epoch (shuffled if requested)."""
        indices = np.arange(start=0, stop=self.num_samples)
        if self.shuffle:
            return np.random.permutation(indices)
        return indices

    def __next__(self):
        """Return ``(epoch, iteration, main_batch, aux_batch)``.

        ``epoch`` and ``iteration`` describe the batch being returned.  After
        the last epoch both batches are ``None``.
        """
        cur_epoc = self.epoc
        cur_iter = self.iter
        if cur_epoc >= self.num_epoc:
            print('training is over, return None!')
            return cur_epoc, cur_iter, None, None

        # Fill the reusable buffers with the current slice of the sequence.
        beg_ = int(cur_iter * self.batch_size)
        end_ = beg_ + self.batch_size
        for slot, sample_id in enumerate(self.ind_seq[beg_:end_]):
            self.im1[slot, 0, :, :] = self.images_main[sample_id]
            self.im2[slot, 0, :, :] = self.images_aux[sample_id]

        # Advance the state; roll over to the next epoch when this one is
        # exhausted and draw a new sample order for it.
        self.iter += 1
        if self.iter == self.max_iter:
            self.epoc += 1
            self.iter = 0
            if self.shuffle:
                self.ind_seq = self._new_index_sequence()
        return cur_epoc, cur_iter, self.im1, self.im2
使用示例:
# Resolve the package root path: the directory one level above the folder
# that contains this script.
import os
parent = os.path.dirname(os.path.abspath(__file__))
package_root = os.path.dirname(parent)
# Restrict CUDA to device index 3. This is set before torch is imported
# further down in this script.
# NOTE(review): the index is hard-coded and overrides any value already
# present in the environment - confirm this is intended.
os.environ['CUDA_VISIBLE_DEVICES']='3'
import sys
# Append the package root to the import search path; presumably this is what
# makes the `model` and `dataset` modules imported below resolvable.
sys.path.append(package_root)
from model import StereoSiamNet
from dataset import PairedGraySet
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
import torch.optim
def train_ssn(model_path, model_name, data_path, *,
              resume_from='ssn_bremen-cross_decoding-150.pth',
              num_epoch=2000, batch_size=8, learning_rate=1e-3,
              print_freq=100, save_freq=50):
    """Train the siamese network to reconstruct both images of each pair.

    Args:
        model_path: Directory that receives the checkpoint files.  It is
            created when missing; an empty string means the current
            directory.
        model_name: File-name stem; checkpoints are written as
            ``<model_name>-<epoch>.pth``.
        data_path: Root directory handed to ``PairedGraySet``.
        resume_from: Path of a whole-model checkpoint to continue from, or
            ``None`` to start from a freshly constructed ``StereoSiamNet()``.
        num_epoch: Number of epochs to train.
        batch_size: Image pairs per optimisation step.
        learning_rate: Learning rate for the Adam optimiser.
        print_freq: Log the loss every ``print_freq`` optimisation steps.
        save_freq: Write a checkpoint at the start of every ``save_freq``-th
            epoch (epoch 0 excluded).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    if resume_from is None:
        net_ = StereoSiamNet()
    else:
        # torch.load unpickles a complete model object: only load
        # checkpoints from a trusted source.
        # NOTE(review): recent PyTorch releases default to weights_only=True,
        # which may reject whole-model checkpoints - confirm against the
        # installed version.
        net_ = torch.load(resume_from)
    net_ = net_.to(device)

    loss_fn = nn.SmoothL1Loss(reduction='mean')
    opt_ = torch.optim.Adam(net_.parameters(), lr=learning_rate)

    if model_path:
        os.makedirs(model_path, exist_ok=True)

    def checkpoint_file(epoch):
        # Checkpoints live under model_path, named by stem and epoch number.
        return os.path.join(model_path, '%s-%03d.pth' % (model_name, epoch))

    # load dataset
    data_loader = PairedGraySet(data_path, num_epoch, batch_size)

    net_.train()
    counter_ = 0
    for i_epo, i_itr, im1, im2 in data_loader:
        # The dataset never raises StopIteration; it signals the end of the
        # last epoch by returning None batches.
        if im1 is None or im2 is None:
            break
        counter_ += 1

        # torch.Tensor(...) copies the dataset's reusable batch buffers.
        x1 = torch.Tensor(im1).to(device)
        x2 = torch.Tensor(im2).to(device)
        x1r, x2r, _, _ = net_(x1, x2)

        # Reconstruction loss on both branches, in (prediction, target)
        # order; the smooth-L1 value is symmetric in its two arguments.
        loss_ = loss_fn(x1r, x1) + loss_fn(x2r, x2)

        opt_.zero_grad()
        loss_.backward()
        opt_.step()

        if counter_ % print_freq == 0:
            print('Epoch: %03d Iter: %5d Loss %8.5f'
                  % (i_epo, i_itr, loss_.item()))

        # i_epo/i_itr describe the batch just processed, so this fires once
        # per save_freq epochs, right after the first batch of that epoch.
        if i_epo > 0 and (i_epo % save_freq == 0) and (i_itr == 0):
            torch.save(net_, checkpoint_file(i_epo))
            print('model saved.')

    # Final checkpoint once all epochs are done.
    torch.save(net_, checkpoint_file(num_epoch))
    print('model saved.')
    print('training done.')
if __name__ == '__main__':
    # Script entry point: train on the Bremen data set with fixed locations.
    model_dir = '../Models/SSN/Bremen/'
    model_name = 'ssn_bremen-cross_decoding'
    data_dir = '../Datasets/SSN/Bremen/'
    train_ssn(model_dir, model_name, data_dir)
    print('training done!')
    # Pass the directory and the stem as separate components so the message
    # stays correct even when model_dir has no trailing separator.  The
    # checkpoint files themselves carry an additional '-<epoch>.pth' suffix.
    print('trained model saved as: ' + os.path.join(model_dir, model_name))
可以看到,数据集的行为表现非常简单,和PyTorch的DataLoader(torch.utils.data.DataLoader)很相似,就是一个迭代器。