tf.train.batch( tensors, batch_size, num_threads=1, capacity=32, enqueue_many=False, shapes=None, dynamic_pad=False, allow_smaller_final_batch=False, shared_name=None, name=None )
函数功能:利用一个tensor的列表或字典来获取一个batch数据
参数介绍:
- tensors:一个列表或字典的tensor用来进行入队
- batch_size:设置每次从队列中获取出队数据的数量
- num_threads:用来控制入队tensors线程的数量,如果num_threads大于1,则batch操作将是非确定性的,输出的batch可能会乱序
- capacity:一个整数,用来设置队列中元素的最大数量
- enqueue_many:在tensors中的tensor是否是单个样本
- shapes:可选,每个样本的shape,默认是tensors的shape
- dynamic_pad:Boolean值.允许输入变量的shape,出队后会自动填补维度,来保持与batch内的shapes相同
- allow_samller_final_batch:可选,Boolean值,如果为True队列中的样本数量小于batch_size时,出队的数量会以最终遗留下来的样本进行出队,如果为Flalse,小于batch_size的样本不会做出队处理
- shared_name:可选,通过设置该参数,可以对多个会话共享队列
- name:可选,操作的名字
从数组中每次获取一个batch_size的数据
import numpy as np import tensorflow as tf def next_batch(): datasets = np.asarray(range(0,20)) input_queue = tf.train.slice_input_producer([datasets],shuffle=False,num_epochs=1) data_batchs = tf.train.batch(input_queue,batch_size=5,num_threads=1, capacity=20,allow_smaller_final_batch=False) return data_batchs if __name__ == "__main__": data_batchs = next_batch() sess = tf.Session() sess.run(tf.initialize_local_variables()) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess,coord) try: while not coord.should_stop(): data = sess.run([data_batchs]) print(data) except tf.errors.OutOfRangeError: print("complete") finally: coord.request_stop() coord.join(threads) sess.close()
输出结果:
[array ([0, 1, 2, 3, 4])] [array ([5, 6, 7, 8, 9])] [array ([10, 11, 12, 13, 14])] [array ([15, 16, 17, 18, 19])]