Both of these methods fetch elements from a queue in batches, and are commonly used to obtain batches of training samples.
These two APIs are quite unintuitive, and there are some parameters I still haven't figured out; given time constraints, let's start with the common usage.
batch
Fetches a specified number of elements from the queue to form a batch.
```python
def batch(tensors, batch_size, num_threads=1, capacity=32,
          enqueue_many=False, shapes=None, dynamic_pad=False,
          allow_smaller_final_batch=False, shared_name=None, name=None):
    """Creates batches of tensors in `tensors`."""
```
tensors: the queue to fetch from (here, the list of tensors produced by slice_input_producer)
batch_size: the number of elements to fetch per batch
capacity: the capacity of the internal queue [I wasn't sure what this is for; it controls how many elements are prefetched and buffered, so it affects throughput rather than which elements end up in a batch]
```python
import numpy as np
import tensorflow as tf

label = np.asarray(range(0, 100))
# label = tf.cast(label, tf.int32)
input_queue = tf.train.slice_input_producer([label], shuffle=False)
label_batch = tf.train.batch(input_queue, batch_size=19, num_threads=1, capacity=5)

with tf.Session() as sess:
    coord = tf.train.Coordinator()                       # coordinator for the queue threads
    threads = tf.train.start_queue_runners(sess, coord)  # start the queue runners collected in the graph
    for j in range(8):
        out = sess.run([label_batch])
        print(out)
    coord.request_stop()
    coord.join(threads)

# [array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18])]
# [array([19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37])]
# [array([38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56])]
# [array([57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75])]
# [array([76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94])]
# [array([95, 96, 97, 98, 99,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13])]
# [array([14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32])]
# [array([33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51])]
```
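Note the wrap-around in the sixth batch above: with shuffle=False, slice_input_producer cycles through the data indefinitely, so the tail (95..99) is followed immediately by a restart (0..13). If you instead want exactly one pass over the data, the num_epochs argument combined with allow_smaller_final_batch can do it; a minimal sketch, assuming TensorFlow 1.x (where these queue APIs live):

```python
import numpy as np
import tensorflow as tf

label = np.asarray(range(0, 100))
# num_epochs=1: the producer raises OutOfRangeError after one full pass
input_queue = tf.train.slice_input_producer([label], num_epochs=1, shuffle=False)
# allow_smaller_final_batch=True: the last batch may hold fewer than 19 elements
label_batch = tf.train.batch(input_queue, batch_size=19,
                             allow_smaller_final_batch=True)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # num_epochs is tracked as a local variable
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    try:
        while not coord.should_stop():
            print(sess.run(label_batch))  # the last batch holds the 5 leftover labels
    except tf.errors.OutOfRangeError:
        print('one epoch finished')
    finally:
        coord.request_stop()
        coord.join(threads)
```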
shuffle_batch
Randomly fetches a specified number of elements from the queue to form a batch.
```python
def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
                  num_threads=1, seed=None, enqueue_many=False, shapes=None,
                  allow_smaller_final_batch=False, shared_name=None, name=None):
    """Creates batches by randomly shuffling tensors."""
```
capacity: the capacity of the queue; this parameter must be larger than min_after_dequeue.
The recommended value is:
- capacity = min_after_dequeue + (num_threads + a small safety margin) * batch_size

(a small worked example follows the parameter descriptions)
min_after_dequeue: the minimum number of elements left in the queue after a dequeue operation completes; it is often used to define the mixing level of the elements.
It defines the size of the buffer used for random sampling: a larger value gives a higher level of mixing, but also makes startup slower and uses more memory.
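As a quick sanity check of the recommended formula (numbers chosen arbitrarily for illustration): with min_after_dequeue = 1000, num_threads = 4, batch_size = 32, and a safety margin of 3, the recommendation gives capacity = 1000 + (4 + 3) * 32 = 1224.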
```python
import numpy as np
import tensorflow as tf

images = np.random.random([100, 2])
label = np.asarray(range(0, 100))
images = tf.cast(images, tf.float32)
label = tf.cast(label, tf.int32)
input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
image_batch, label_batch = tf.train.shuffle_batch(input_queue, batch_size=10, num_threads=1,
                                                  capacity=64, min_after_dequeue=10)

with tf.Session() as sess:
    coord = tf.train.Coordinator()                       # coordinator for the queue threads
    threads = tf.train.start_queue_runners(sess, coord)  # start the queue runners collected in the graph
    for _ in range(5):
        image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
        print(image_batch_v, label_batch_v)
    coord.request_stop()
    coord.join(threads)
```
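Both signatures also take enqueue_many, which neither example above uses. As far as I understand it, enqueue_many=True tells the op that each input tensor already contains many examples stacked along dimension 0, so the slice_input_producer stage can be dropped; a rough sketch under that assumption:

```python
import numpy as np
import tensorflow as tf

label = tf.constant(np.asarray(range(0, 100)), dtype=tf.int32)

# enqueue_many=True: `label` is treated as 100 scalar examples (sliced along
# dimension 0) rather than as one example of shape [100]
label_batch = tf.train.shuffle_batch([label], batch_size=10, capacity=64,
                                     min_after_dequeue=10, enqueue_many=True)

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    print(sess.run(label_batch))  # ten labels in shuffled order
    coord.request_stop()
    coord.join(threads)
```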