Implementing a pipeline that reads data from CSV (or txt) files during TensorFlow training
The code consists of the following parts:
- Creating the dataset: create_dataset
- Parsing each line the dataset reads: decode_line
- Creating the training/validation iterator: create_iterator
- Creating the prediction iterator: create_iterator_for_predict
- Data preprocessing:
    - Preprocessing features and labels together: reprocessing_feature_label
    - Preprocessing features alone (this can be done independently without breaking the feature-label correspondence): reprocessing_feature
- Testing data iteration during training: train_validation_test
- Testing data iteration during prediction: predict_test
```python
# encoding: utf-8
"""
@author: mry
@contact: 1037129120@qq.com
@time: 2020/03/02 23:50
@file: dataset_FromStringHandle.py
@desc:
    Training set: use make_one_shot_iterator.
        "One shot" does not mean the data is iterated only once; it loops as
        many times as the dataset's repeat() specifies, and loops indefinitely
        if no repeat count was given when the dataset was created.
    Validation set: use make_initializable_iterator.
    Test set: use make_one_shot_iterator.
    Benefit: running validation does not interrupt the training iterator.
"""
import pandas as pd
import numpy as np
import tensorflow as tf
import time
# Per-line parsing applied to each row the dataset reads.
# Specifies each column's dtype and how to split features from labels.
def decode_line(line, n_columns=2):
    # One default value per CSV column; [0.0] parses every column as float32.
    record_defaults = [[0.0] for _ in range(n_columns)]
    items = tf.decode_csv(line, record_defaults)
    features_0 = items[0]
    labels = items[1]
    return features_0, labels
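# Hedged aside (illustration only, not used below): record_defaults also fixes
# each column's dtype, so a hypothetical file with two float features and a
# string label would decode with mixed defaults:
#   f0, f1, label = tf.decode_csv(line, record_defaults=[[0.0], [0.0], ['']])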
# Create the dataset.
# Specifies which files to read, the batch size, the repeat count, whether to
# shuffle, and (optionally) how to filter lines.
def create_dataset(filename, batch_size=32, is_shuffle=False, n_repeats=0):
    # # Skip the header line of the first file only:
    # dataset = tf.data.TextLineDataset(filename).skip(1)  # .filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "0"))
    # Skip the header line of every file:
    dataset = tf.data.Dataset.from_tensor_slices(filename)
    dataset = dataset.flat_map(
        lambda f: (
            tf.data.TextLineDataset(f)
            .skip(1)  # .filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "#"))
        ))
    if n_repeats > 0:
        dataset = dataset.repeat(n_repeats)
    else:
        # No repeat count given: loop over the data indefinitely.
        dataset = dataset.repeat()
    # dataset = dataset.map(decode_line)
    dataset = dataset.map(lambda x: decode_line(x, n_columns=2))
    if is_shuffle:
        dataset = dataset.shuffle(10000)
    dataset = dataset.batch(batch_size)
    return dataset
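# Hedged usage sketch (defined here for illustration, never called): pull one
# batch out of a dataset to sanity-check the pipeline. Assumes ./t1.csv exists
# with the two-column layout shown at the end of this post.
def _peek_one_batch():
    dataset = create_dataset(['./t1.csv'], batch_size=4, is_shuffle=True, n_repeats=1)
    next_batch = dataset.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
        features_0, labels = sess.run(next_batch)
        print(features_0, labels)  # two float32 arrays, each of shape (4,)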
# Create the training and validation iterators.
def create_iterator(training_filenames=None, validation_filenames=None, handle=None):
    '''
    __doc__:
        Create the iterators for the training and validation sets.
    Args:
        training_filenames: list of training CSV paths
        validation_filenames: list of validation CSV paths
        handle: scalar tf.string placeholder that selects which iterator
            feeds next_element (train or validation)
    Returns:
        next_element: the next batch
        training_iterator: iterator over the training set
        validation_iterator: iterator over the validation set
    '''
    if handle is None:
        # Placeholder that switches between train and validation at run time.
        handle = tf.placeholder(tf.string, shape=[])
    # Training dataset.
    training_dataset = create_dataset(training_filenames, batch_size=10, is_shuffle=False)
    # Validation dataset.
    validation_dataset = create_dataset(validation_filenames, batch_size=5, is_shuffle=False)
    training_iterator = training_dataset.make_one_shot_iterator()
    validation_iterator = validation_dataset.make_initializable_iterator()
    # `iterator` is controlled by `handle`: feeding the training handle makes
    # it act as the training iterator, feeding the validation handle makes it
    # act as the validation iterator.
    iterator = tf.data.Iterator.from_string_handle(
        handle, training_dataset.output_types, training_dataset.output_shapes)
    next_element = iterator.get_next()
    return next_element, training_iterator, validation_iterator
# Preprocessing of features and labels together.
def reprocessing_feature_label(features_0, labels):
    features_0 = tf.reshape(features_0, [-1, 1])
    labels = tf.reshape(labels, [-1])
    return features_0, labels
# Preprocessing of features alone (labels are untouched, so the feature-label
# correspondence is preserved).
def reprocessing_feature(features_0):
    features_0 = tf.reshape(features_0, [-1, 1, 1])
    return features_0
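# Hedged aside on the reshape above: a batch of 10 scalar features arrives
# with shape (10,); tf.reshape(features_0, [-1, 1, 1]) turns it into
# (10, 1, 1), i.e. (batch, time_step, feature), the layout an RNN-style model
# expects.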
# Create the prediction iterator.
def create_iterator_for_predict(filenames=None):
    '''
    __doc__:
        Create the iterator for an evaluation or prediction set.
        If the prediction CSV has no labels column, add one to the file in
        advance (decode_line expects two columns); an alternative is sketched
        below.
    '''
    predict_dataset = create_dataset(filenames, batch_size=5, is_shuffle=False, n_repeats=1)
    predict_iterator = predict_dataset.make_one_shot_iterator()
    next_element = predict_iterator.get_next()
    return next_element
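# Hedged alternative (not used in this file): instead of padding a label-free
# CSV with a dummy labels column, the map step could decode a single column:
#   dataset = dataset.map(lambda line: tf.decode_csv(line, record_defaults=[[0.0]])[0])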
# Test of the training + validation iteration.
def train_validation_test():
    training_filenames = ['./t1.csv', './t3.csv']
    validation_filenames = ['./t2.csv']
    sess = tf.Session()
    handle = tf.placeholder(tf.string, shape=[])
    next_element, training_iterator, validation_iterator = create_iterator(
        training_filenames, validation_filenames, handle)
    training_handle = sess.run(training_iterator.string_handle())
    validation_handle = sess.run(validation_iterator.string_handle())
    features_0, labels = next_element
    # features_0, labels = reprocessing_feature_label(features_0, labels)
    features_0 = reprocessing_feature(features_0)
    for _ in range(50):
        # Pull 10 batches from the training set (one-shot iterator: no
        # initialization needed, and its position survives the validation run).
        for _ in range(10):
            f_0, l = sess.run([features_0, labels], feed_dict={handle: training_handle})
            print(f_0)
            print(l)
            print('>' * 20)
        # Re-initialize the validation iterator, then pull 5 batches from it.
        sess.run(validation_iterator.initializer)
        for _ in range(5):
            f_0, l = sess.run([features_0, labels], feed_dict={handle: validation_handle})
            print(f_0)
            print(l)
            print('>' * 50)
        time.sleep(0.5)
# Test of the prediction iteration.
def predict_test():
    filenames = ['./t2.csv']
    g_xyk_predict = tf.Graph()
    with g_xyk_predict.as_default():
        next_element = create_iterator_for_predict(filenames)
        features_0, labels = next_element
        features_0, labels = reprocessing_feature_label(features_0, labels)
        # In a real run the model would go here,
        # e.g. prediction = rnn(features_0)
        sess = tf.Session()
        # In a real run, load the trained weights here from a checkpoint or a pb file.
        result_list = []
        while True:
            try:
                l = sess.run(labels)
                print(l)
                print('>' * 30)
                result_list += l.tolist()
            except tf.errors.OutOfRangeError:
                # n_repeats=1, so the iterator is exhausted after one pass.
                break
        print(result_list)
        return result_list
if __name__ == '__main__':
    # Training + validation pass.
    train_validation_test()
    # Prediction pass.
    predict_test()
```
Input data

`./t1.csv`, `./t2.csv`, and `./t3.csv` each have two columns, features and labels (the actual column names do not matter):

| features | labels |
| -------- | ------ |
| 0        | 0      |
| 1        | 1      |
| 2        | 2      |
| 3        | 3      |
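To reproduce the run below, here is a minimal sketch (not part of the original script) for generating matching toy files; the 50-row count is an assumption chosen so the batches in the output line up:

```python
import pandas as pd

# Hypothetical helper: write ./t1.csv, ./t2.csv and ./t3.csv, each with
# identical features and labels columns. 50 rows per file is an arbitrary
# choice, not something the original post specifies.
def make_toy_csvs(n_rows=50):
    values = list(range(n_rows))
    for name in ['./t1.csv', './t2.csv', './t3.csv']:
        pd.DataFrame({'features': values, 'labels': values}).to_csv(name, index=False)

make_toy_csvs()
```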
Output (abridged; the 3-D feature batches come from train_validation_test, the flat label batches further down from predict_test):

```
[[[0.]]
[[1.]]
[[2.]]
[[3.]]
[[4.]]
[[5.]]
[[6.]]
[[7.]]
[[8.]]
[[9.]]]
[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
[[[10.]]
[[11.]]
[[12.]]
[[13.]]
[[14.]]
[[15.]]
[[16.]]
[[17.]]
....
[0. 1. 2. 3. 4.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[5. 6. 7. 8. 9.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[10. 11. 12. 13. 14.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[15. 16. 17. 18. 19.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[20. 21. 22. 23. 24.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[25. 26. 27. 28. 29.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[30. 31. 32. 33. 34.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[35. 36. 37. 38. 39.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[40. 41. 42. 43. 44.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
....
```