Tensorflow踩坑系列---数据读取文件队列

一:总括文件读取方式

1.供给数据(Feeding): 由占位符placeholder代替数据,运行时使用feed_dict填入数据

Tensorflow踩坑系列---数据读取文件队列

Tensorflow踩坑系列---数据读取文件队列

2.预加载数据: 数据直接嵌入graph,由graph传入session中运行

 Tensorflow踩坑系列---数据读取文件队列

3.从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据,这就是这篇文将要讲的内容。

前两种方法很方便,但是遇到大型数据的时候就会很吃力,即使是Feeding,中间环节的增加也是不小的开销,比如数据类型转换等等。 最优的方案就是在Graph定义好文件读取的方法,让TF自己去从文件中读取数据,并解码成可使用的样本集。

对于大的数据集很难用numpy数组保存,所以这里介绍一下Tensorflow读取很大数据集的方法:string_input_producer()和slice_input_producer()。

这种直接从文件中读取数据的方式需要设计成Queue的方式才能较好的解决IO瓶颈的问题。
Queue机制有如下三个特点:

(1)producer-consumer pattern(生产消费模式)
(2)独立于主线程执行
(3)异步IO: reader.read(queue) tf.train.batch()

一:string_input_producer队列使用(单个Reader、单文件读取)

import tensorflow as tf
IMAGE_DIR = "./Images/SourceImgs/"
QUEUE_DIR = "./Images/QueueImgs/"
FILELIST = ["1100.jpg","1101.jpg","1102.jpg","1104.jpg","1105.jpg",
           "1110.jpg","1114.jpg","1115.jpg","1116.jpg","1118.jpg"]

Tensorflow踩坑系列---数据读取文件队列

(一)获取文件列表

def getFileList(rootDir=IMAGE_DIR,files=FILELIST):
    fsl = []
    for fn in files:
        fsl.append(IMAGE_DIR+fn)
    return fsl

(二)使用队列读取文件

with tf.Session() as sess:
    files_list = getFileList()
    #string_input_producer产生文件名队列
    filename_queue = tf.train.string_input_producer(files_list,shuffle=True,num_epochs=3)
    #reader从文件名队列中读取数据
    reader = tf.WholeFileReader()
    key,value = reader.read(filename_queue) #返回文件名和文件内容
    
    sess.run(tf.local_variables_initializer()) #初始化上面的局部变量
    
    #启动start_queue_runners之后,才会开始填充队列
    threads = tf.train.start_queue_runners(sess=sess)
    i = 1
    while True:
        try:
            image_data = sess.run(value)
            with open(QUEUE_DIR+"%d.jpg"%i,"wb") as f:
                f.write(image_data)
            i+=1
        except BaseException:
            print("read all files, numbers:%d"%i)
            break

(三)参数说明

tf.train.string_input_producer(files_list,shuffle=False,num_epochs=2)

shuffle=False:表示按序获得文件

num_epochs=2:表示会遍历两遍全部文件,当我们不设置数值的时候,表示我们可以一直遍历下去,会循环所有文件

Tensorflow踩坑系列---数据读取文件队列

tf.train.string_input_producer(files_list,shuffle=True,num_epochs=3)

shuffle=False:表示打乱顺序获得文件(是本轮所有文件列表中乱序,不是全局)

num_epochs=2:表示会遍历三遍全部文件

Tensorflow踩坑系列---数据读取文件队列

二:string_input_producer队列使用(单个Reader、批文件读取)

import tensorflow as tf
IMAGE_DIR = "./Images/SourceImgs/"
QUEUE_DIR = "./Images/QueueImgs/"
FILELIST = ["1100.jpg","1101.jpg","1102.jpg","1104.jpg","1105.jpg",
           "1110.jpg","1114.jpg","1115.jpg","1116.jpg","1118.jpg"]
def getFileList(rootDir=IMAGE_DIR,files=FILELIST):
    fsl = []
    for fn in files:
        fsl.append(IMAGE_DIR+fn)
    return fsl

(一)按批次获取文件

files_list = getFileList()
#string_input_producer产生文件名队列
filename_queue = tf.train.string_input_producer(files_list,shuffle=False,num_epochs=1)

def decode_img(fileQueue):
    #reader从文件名队列中读取数据
    reader = tf.WholeFileReader()
    key,value = reader.read(fileQueue) #返回文件名和文件内容
    return value #返回一个文件

img = decode_img(filename_queue)

image_batch = tf.train.batch([img],batch_size=8,num_threads=2,allow_smaller_final_batch=True) 

(二)线程调用

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer()) #初始化上面的全局变量
    sess.run(tf.local_variables_initializer()) #初始化上面的局部变量
    
    coord = tf.train.Coordinator()
    #启动start_queue_runners之后,才会开始填充队列
    threads = tf.train.start_queue_runners(sess=sess,coord=coord)
    j = 1
    try:
        while not coord.should_stop():
            images_data = sess.run(image_batch)
            print(images_data.shape)
            for img_data in images_data:
                with open(QUEUE_DIR+"%d.jpg"%j,"wb") as f:
                    f.write(img_data)
                j+=1
    except BaseException:
            print("read all files")
    finally:
        coord.request_stop() #将读取文件的线程关闭
    coord.join(threads) #线程回收,将读取文件的子线程加入主线程

Tensorflow踩坑系列---数据读取文件队列

(三)参数说明

tf.train.batch([img],batch_size=8,num_threads=2,allow_smaller_final_batch=True)

使用tf.train.batch,按序获取:

batch_size每一个批次大小为8,
num_threads使用2线程读取数据,虽然这里只有一个Reader,但可以设置多线程,相应增加线程数会提高读取速度,但并不是线程越多越好。
allow_smaller_final_batch,默认为false,剩余数据小于batch_size则会被丢弃。

tf.train.shuffle_batch() 将队列中数据打乱后再读取出来,其他与batch方法类似。

需要设置:
capacity:队列中元素的最大数量。
min_after_dequeue:出队后队列中元素的最小数量,用于确保元素的混合级别。

 

补充:

TensorFlow学习--tf.train.batch与tf.train.shuffle_batch

tf.train.string_input_producer()和tf.train.slice_input_producer()

string_input_producer:

加载图片的reader是reader = tf.WholeFileReader()

key,value = reader.read(path_queue)其中key是文件名,value是byte类型的文件流二进制。

slice_input_producer:

加载图片的reader使用tf.read_file(filename)直接读取。这是两者的一个不同之处!!!

TensorFlow基础3:数据读取的三种方式

上一篇:SpringCloud使用Kafka消费者


下一篇:RocketMQ(一)原理和实战!