Generating an .h5 File with a Python Script
import os, json, argparse
from threading import Thread
from Queue import Queue  # Python 2 stdlib module; renamed to queue in Python 3

import numpy as np
from scipy.misc import imread, imresize  # deprecated and removed in newer SciPy releases
import h5py

"""
Create an HDF5 file of images for training a feedforward style transfer model.
"""

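# The resulting file holds one uint8 dataset per split (e.g. /train2014/images
# and /val2014/images), each with shape (num_images, 3, height, width).
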
parser = argparse.ArgumentParser()
parser.add_argument('--train_dir', default='/media/wangxiao/WangXiao_Dataset/CoCo/train2014')
parser.add_argument('--val_dir', default='/media/wangxiao/WangXiao_Dataset/CoCo/val2014')
parser.add_argument('--output_file', default='/media/wangxiao/WangXiao_Dataset/CoCo/coco-256.h5')
parser.add_argument('--height', type=int, default=256)
parser.add_argument('--width', type=int, default=256)
parser.add_argument('--max_images', type=int, default=-1)
parser.add_argument('--num_workers', type=int, default=2)
parser.add_argument('--include_val', type=int, default=1)
parser.add_argument('--max_resize', type=int, default=16)
args = parser.parse_args()

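# Example invocation (the filename make_coco_h5.py is illustrative; the flags
# correspond to the arguments defined above):
#
#   python make_coco_h5.py \
#     --train_dir /path/to/coco/train2014 \
#     --val_dir /path/to/coco/val2014 \
#     --output_file /path/to/coco-256.h5 \
#     --num_workers 4
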
def add_data(h5_file, image_dir, prefix, args):
    # Make a list of all images in the source directory
    image_list = []
    image_extensions = {'.jpg', '.jpeg', '.JPG', '.JPEG', '.png', '.PNG'}
    for filename in os.listdir(image_dir):
        ext = os.path.splitext(filename)[1]
        if ext in image_extensions:
            image_list.append(os.path.join(image_dir, filename))
    num_images = len(image_list)

    # Resize all images and copy them into the hdf5 file.
    # We'll bravely try multithreading.
    dset_name = os.path.join(prefix, 'images')
    dset_size = (num_images, 3, args.height, args.width)
    imgs_dset = h5_file.create_dataset(dset_name, dset_size, np.uint8)

    # input_queue stores (idx, filename) tuples,
    # output_queue stores (idx, resized_img) tuples
    input_queue = Queue()
    output_queue = Queue()

    # Read workers pull images off disk and resize them
    def read_worker():
        while True:
            idx, filename = input_queue.get()
            img = imread(filename)
            try:
                # First crop the image so its size is a multiple of max_resize
                H, W = img.shape[0], img.shape[1]
                H_crop = H - H % args.max_resize
                W_crop = W - W % args.max_resize
                img = img[:H_crop, :W_crop]
                img = imresize(img, (args.height, args.width))
            except (ValueError, IndexError) as e:
                print filename
                print img.shape, img.dtype
                print e
                # Skip images that fail to resize rather than enqueueing them
                # with the wrong shape (which would crash the write worker);
                # their slot in the dataset stays zero-filled.
                input_queue.task_done()
                continue
            input_queue.task_done()
            output_queue.put((idx, img))

    # Write workers write resized images to the hdf5 file
    def write_worker():
        num_written = 0
        while True:
            idx, img = output_queue.get()
            if img.ndim == 3:
                # RGB image, transpose from H x W x C to C x H x W
                imgs_dset[idx] = img.transpose(2, 0, 1)
            elif img.ndim == 2:
                # Grayscale image; it is H x W so broadcasting to C x H x W
                # will just copy grayscale values into all channels.
                imgs_dset[idx] = img
            output_queue.task_done()
            num_written = num_written + 1
            if num_written % 100 == 0:
                print 'Copied %d / %d images' % (num_written, num_images)

    # Start the read workers.
    for i in xrange(args.num_workers):
        t = Thread(target=read_worker)
        t.daemon = True
        t.start()

    # h5py locks internally, so we can only use a single write worker =(
    t = Thread(target=write_worker)
    t.daemon = True
    t.start()

    for idx, filename in enumerate(image_list):
        if args.max_images > 0 and idx >= args.max_images: break
        input_queue.put((idx, filename))

    input_queue.join()
    output_queue.join()


if __name__ == '__main__':
    with h5py.File(args.output_file, 'w') as f:
        add_data(f, args.train_dir, 'train2014', args)

        if args.include_val != 0:
            add_data(f, args.val_dir, 'val2014', args)
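
Once the script finishes, the output can be sanity-checked by opening the file with h5py and inspecting the dataset shapes. A minimal sketch, assuming the default --output_file path from above (adjust it for your setup):

import h5py

with h5py.File('/media/wangxiao/WangXiao_Dataset/CoCo/coco-256.h5', 'r') as f:
    train = f['train2014/images']
    print(train.shape)   # (num_images, 3, 256, 256)
    print(train.dtype)   # uint8
    # Read a single image back; h5py loads only the requested slice from disk.
    img = train[0].transpose(1, 2, 0)  # C x H x W -> H x W x C
    print(img.shape)     # (256, 256, 3)

Because h5py reads only the slices you index, downstream training code can stream images from this file without loading the whole dataset into memory.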