notice:tensorflow 1.14.0版本
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
d = np.arange(0, 60).reshape([6, 10])
data = tf.data.Dataset.from_tensor_slices(d) // numpy 转为 tensor
#从data数据集中按顺序抽取buffer_size个样本放在buffer中,然后打乱buffer中的样本
# buffer中样本个数不足buffer_size,继续从data数据集中安顺序填充至buffer_size,
# 此时会再次打乱
data = data.shuffle(buffer_size=3)
## 每次从buffer中抽取4个样本
data = data.batch(4)
# 将data数据集重复,其实就是2个epoch数据集
data = data.repeat(2)
# 构造获取数据的迭代器
iters = data.make_one_shot_iterator()
# 每次从迭代器中获取一批数据
batch = iters.get_next()
sess = tf.Session()
while True:
try:
print(sess.run(batch))
except tf.errors.OutOfRangeError:
break
res:
[[20 21 22 23 24 25 26 27 28 29]
[ 0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[50 51 52 53 54 55 56 57 58 59]]
[[40 41 42 43 44 45 46 47 48 49]
[30 31 32 33 34 35 36 37 38 39]]
[[ 0 1 2 3 4 5 6 7 8 9]
[20 21 22 23 24 25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]
[10 11 12 13 14 15 16 17 18 19]]
[[50 51 52 53 54 55 56 57 58 59]
[40 41 42 43 44 45 46 47 48 49]]