tensorflow 训练过程,使用tf.data迭代数据

tensorflow 训练过程中,从csv(也可以是 txt) 读取数据的流程实现

代码主要有以下几部分:

  1. 创建dataset create_dataset
    1. dataset 中,对读取到的一行 进行的操作 decode_line
  2. 创建 训练与验证 迭代器 create_iterator
  3. 创建 预测 迭代器 create_iterator_for_predict
  4. 数据预处理过程
    1. 特征与标签的预处理 reprocessing_feature_label
    2. 特征的预处理(可以单独对特征进行预处理,不会打乱特征与标签的对应关系) reprocessing_feature
  5. 训练过程数据迭代的测试 train_validation_test
  6. 预测过程数据迭代测试 predict_validation_test

# encoding: utf-8
"""
@author: mry
@contact: 1037129120@qq.com
@time: 2020/03/02 23:50
@file: dataset_FromStringHandle.py
@desc: 
        训练集 使用 make_one_shot_iterator 
            并不是只对数据集循环一次 而是循环dataset指定的次数
            如果创建dataset时不指定循环次数,那么就可以无限循环
        验证集 使用 make_initializable_iterator 

        测试集 使用 make_one_shot_iterator 
好处: 验证集 不会打断训练集 
"""

import pandas as pd 
import numpy as np 
import tensorflow as tf 
import time


# dataset 中,对读取到的一行 进行的操作 
#   主要是指定每一列的数据类型  以及如何切分 features 和 labels
def decode_line(line,len=2):

    record_defaults = [[0.0] for i in range(len)]
    items = tf.decode_csv(line, record_defaults)
    features_0 = items[0]
    labels = items[1]
    return features_0,labels


# 创建dataset 
#   主要是指定 从哪里读取文件 指定批次,迭代次数, 是否打乱,是否对文件进行过滤等操作
def create_dataset(filename,batch_size=32,is_shuffle=False,n_repeats=0):
    # # # 只有第一个文件去掉第一行
    # dataset = tf.data.TextLineDataset(filename).skip(1)   #.filter(lambda line: tf.not_equal(tf.substr(line,0,1),"0"))
    # # 所有的文件都去除第一行
    dataset = tf.data.Dataset.from_tensor_slices(filename)
    dataset = dataset.flat_map(
        lambda f: (
            tf.data.TextLineDataset(f)
            .skip(1)       #.filter(lambda line: tf.not_equal(tf.substr(line,0,1),"#"))
            ))

    if n_repeats >0:
        dataset =dataset.repeat(n_repeats)
    else:
        dataset =dataset.repeat()
    # dataset = dataset.map(decode_line)
    dataset = dataset.map(lambda x:decode_line(x,len=2))

    if is_shuffle:
        dataset = dataset.shuffle(10000)
    dataset = dataset.batch(batch_size)
    return dataset


# 创建 训练与验证 迭代器
def create_iterator(training_filenames=None,validation_filenames=None,
    handle = tf.placeholder(tf.string, shape=[])):
    '''
        __doc__:
            创建训练集+验证集 的迭代器
        Args:
            training_filenames:
            validation_filenames:

        Returns:
            next_element: 下一个batch 
            training_iterator: 初始化为训练集迭代器
            validation_iterator: 初始化为验证集迭代器
    '''
    # 训练集的dataset
    training_dataset = create_dataset(training_filenames,batch_size=10,is_shuffle=False)    
    # 测试集的dataset
    validation_dataset = create_dataset(validation_filenames,batch_size=5,is_shuffle=False)
    training_iterator = training_dataset.make_one_shot_iterator()
    validation_iterator = validation_dataset.make_initializable_iterator()  

    # # 操作占位符,控制 train or validation
    # handle = tf.placeholder(tf.string, shape=[])
    # iterator 是一个由handle控制的迭代器,可以切换为 训练集迭代器 或 验证集迭代器
    iterator = tf.data.Iterator.from_string_handle(
        handle, training_dataset.output_types, training_dataset.output_shapes)
    next_element = iterator.get_next()

    return next_element,training_iterator,validation_iterator

# 特征与标签的预处理过程
def reprocessing_feature_label(features_0,labels):
    features_0 = tf.reshape(labels,[-1,1])
    labels = tf.reshape(labels,[-1])
    return features_0,labels

# 特征的预处理过程 
def reprocessing_feature(features_0):

    features_0 = tf.reshape(features_0,[-1,1,1])
    return features_0


# 创建预测的迭代器
def create_iterator_for_predict(filenames=None):
    '''
    __doc__:
        创建 评价集 或 预测集 的迭代器
        预测集 如果没有结果标签列(labels),需要在csv中提前添加labels列 
    '''
    predict_dataset = create_dataset(filenames,batch_size=5,is_shuffle=False,n_repeats=1)
    predict_iterator = predict_dataset.make_one_shot_iterator()
    next_element = predict_iterator.get_next()
    return next_element


# 训练过程测试
def train_validation_test():
    training_filenames = ['./t1.csv','./t3.csv']
    validation_filenames = ['./t2.csv']

    sess = tf.Session()
    handle = tf.placeholder(tf.string, shape=[])
    next_element,training_iterator,validation_iterator = create_iterator(training_filenames,validation_filenames,handle)
    
    training_handle = sess.run(training_iterator.string_handle())
    validation_handle = sess.run(validation_iterator.string_handle())   

    features_0, labels=next_element
    # features_0, labels = reprocessing_feature_label(features_0,labels)
    features_0 = reprocessing_feature(features_0)
    for _ in range(50):
        # Initialize an iterator over the training dataset.
        
        for _ in range(10):
            f_0, l=sess.run([features_0,labels],feed_dict={handle:training_handle})
            print(f_0)
            print(l)
        print('>'*20)
        # Initialize an iterator over the validation dataset.
        sess.run(validation_iterator.initializer)
        for _ in range(5):
            f_0, ne=sess.run([features_0,labels],feed_dict={handle:validation_handle})
            print(f_0)
            print(ne)
            # print(ne)
        print('>'*50)
        time.sleep(0.5)


# 预测过程测试
def predict_test():

    filenames = ['./t2.csv']
    g_xyk_predict=tf.Graph()
    with g_xyk_predict.as_default():

        next_element = create_iterator_for_predict(filenames)
        features_0,labels=next_element
        features_0,labels = reprocessing_feature_label(features_0,labels)
        # 真实情况下,该位置放模型
        # 例如 prediction = rnn(features_0)
        sess = tf.Session()

        # 真实情况下,该位置从check point 或者 pb 模型加载训练的模型
        
        result_list = []
        while True:
            try:
                l = sess.run(labels)
                print(l)
                print('>'*30)
                result_list+=l.tolist()
            except tf.errors.OutOfRangeError:
                break
    print(result_list)
    return result_list


if __name__ =='__main__':
    # 训练 与 验证 的过程
    train_validation_test()
    # 预测的过程
    predict_test()

读取的数据

./t1.csv ./t2.csv ./t3.csv

有两列 features和labels(列名是什么没关系)

【此处有个图片没上传成功】用表格代替了

features label
0 0
1 1
2 2
3 3

运行结果

[[[0.]]
 [[1.]]
 [[2.]]
 [[3.]]
 [[4.]]
 [[5.]]
 [[6.]]
 [[7.]]
 [[8.]]
 [[9.]]]
[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
[[[10.]]
 [[11.]]
 [[12.]]
 [[13.]]
 [[14.]]
 [[15.]]
 [[16.]]
 [[17.]]
....
[0. 1. 2. 3. 4.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[5. 6. 7. 8. 9.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[10. 11. 12. 13. 14.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[15. 16. 17. 18. 19.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[20. 21. 22. 23. 24.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[25. 26. 27. 28. 29.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[30. 31. 32. 33. 34.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[35. 36. 37. 38. 39.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[40. 41. 42. 43. 44.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
....

 

上一篇:SprintBoot使用Validation


下一篇:java validation 验证器