tensorflow中 batch, shuffle, repeat

notice:tensorflow 1.14.0版本

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

d = np.arange(0, 60).reshape([6, 10])

data = tf.data.Dataset.from_tensor_slices(d) // numpy 转为 tensor

#从data数据集中按顺序抽取buffer_size个样本放在buffer中,然后打乱buffer中的样本
# buffer中样本个数不足buffer_size,继续从data数据集中安顺序填充至buffer_size,
# 此时会再次打乱
data = data.shuffle(buffer_size=3)

## 每次从buffer中抽取4个样本
data = data.batch(4)

# 将data数据集重复,其实就是2个epoch数据集
data = data.repeat(2)

# 构造获取数据的迭代器
iters = data.make_one_shot_iterator()

# 每次从迭代器中获取一批数据
batch = iters.get_next()

sess = tf.Session()
while True:
    try:
        print(sess.run(batch))
    except tf.errors.OutOfRangeError:
        break

res:

[[20 21 22 23 24 25 26 27 28 29]
 [ 0  1  2  3  4  5  6  7  8  9]
 [10 11 12 13 14 15 16 17 18 19]
 [50 51 52 53 54 55 56 57 58 59]]
[[40 41 42 43 44 45 46 47 48 49]
 [30 31 32 33 34 35 36 37 38 39]]

[[ 0  1  2  3  4  5  6  7  8  9]
 [20 21 22 23 24 25 26 27 28 29]
 [30 31 32 33 34 35 36 37 38 39]
 [10 11 12 13 14 15 16 17 18 19]]
[[50 51 52 53 54 55 56 57 58 59]
 [40 41 42 43 44 45 46 47 48 49]]

上一篇:TensorFlow实践笔记


下一篇:TensorFlow自定义损失函数