TensorFlow 2.0: datasets.shuffle(buffer_size).batch(batch_size)

Some personal notes on how this call works, recorded for reference.

import tensorflow as tf

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
print('Loaded the original training images, train_images.shape:', train_images.shape)

train_images = train_images[0:5000, :, :]
print('Kept the first 5000 training images, train_images.shape:', train_images.shape)

train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')  # add a channel axis to make the array 4-D
print('After reshaping to 4-D, train_images.shape:', train_images.shape)

Output:

Loaded the original training images, train_images.shape: (60000, 28, 28)
Kept the first 5000 training images, train_images.shape: (5000, 28, 28)
After reshaping to 4-D, train_images.shape: (5000, 28, 28, 1)

BATCH_SIZE = 512
BUFFER_SIZE = 5000

print('==========')
# To pair images with labels, pass a single tuple; the labels must first be cut to the same 5000 samples:
# datasets = tf.data.Dataset.from_tensor_slices((train_images, train_labels[0:5000]))
datasets = tf.data.Dataset.from_tensor_slices(train_images)
print('datasets:', datasets)

datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
print('datasets:', datasets)

times = 1  # batch number
total = 0  # number of images seen
for item in datasets:  # one full pass over the training data, i.e. one epoch
    print(f'=======Current batch: {times}========')
    print(item.shape)

    batch_count = item.shape[0]  # BATCH_SIZE is 512, but 5000 is not a multiple of 512, so the last batch is smaller
    total += batch_count
    times += 1

print('Number of samples seen:', total)

Output:

datasets: <TensorSliceDataset shapes: (28, 28, 1), types: tf.float32>
datasets: <BatchDataset shapes: (None, 28, 28, 1), types: tf.float32>
=======Current batch: 1========
(512, 28, 28, 1)
=======Current batch: 2========
(512, 28, 28, 1)
=======Current batch: 3========
(512, 28, 28, 1)
=======Current batch: 4========
(512, 28, 28, 1)
=======Current batch: 5========
(512, 28, 28, 1)
=======Current batch: 6========
(512, 28, 28, 1)
=======Current batch: 7========
(512, 28, 28, 1)
=======Current batch: 8========
(512, 28, 28, 1)
=======Current batch: 9========
(512, 28, 28, 1)
=======Current batch: 10========
(392, 28, 28, 1)
Number of samples seen: 5000

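The batch sizes in the output above (nine batches of 512 and a final batch of 392) follow from simple integer division. The sketch below reproduces that arithmetic in plain Python; `batch_sizes` is a hypothetical helper written for this note, not a TensorFlow function. Its `drop_remainder` flag mirrors the parameter of the same name on `Dataset.batch`, which discards the short final batch when set to `True`.

```python
import math


def batch_sizes(num_samples, batch_size, drop_remainder=False):
    """Return the size of each batch when num_samples items are split into batches.

    Hypothetical helper for illustration; it only models the counting,
    not how tf.data actually builds batches.
    """
    full, rest = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if rest and not drop_remainder:
        sizes.append(rest)  # the short final batch
    return sizes


sizes = batch_sizes(5000, 512)
print(len(sizes), sizes[-1], sum(sizes))  # 10 392 5000
print(math.ceil(5000 / 512))              # 10
print(len(batch_sizes(5000, 512, drop_remainder=True)))  # 9
```

If every batch needs the same shape, for example for a model with a fixed batch dimension, `batch(BATCH_SIZE, drop_remainder=True)` would give 9 batches here, at the cost of skipping 392 samples per epoch.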

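My understanding of datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE): shuffle keeps a buffer of BUFFER_SIZE elements and draws each output element at random from that buffer. Because BUFFER_SIZE (5000) equals the number of training samples here, the whole dataset is shuffled; a smaller buffer would only shuffle locally. batch then groups the shuffled samples into batches of BATCH_SIZE, so each batch has shape (batch_size, 28, 28, 1), except the last one, which holds the remaining 392 samples. The total number of samples is still 5000, but the dataset now yields 10 batches rather than 5000 individual images. In for item in datasets:, each item is one batch of training data.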
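To make the role of the buffer size more concrete, here is a simplified pure-Python model of a shuffle buffer followed by batching. It is a sketch of the general technique only, not TensorFlow's actual implementation; `buffered_shuffle` and `batched` are hypothetical names introduced for this illustration.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in an order produced by a fixed-size shuffle buffer.

    Simplified model (an assumption, not TensorFlow code): fill the buffer,
    then repeatedly emit a random buffered element and refill its slot.
    """
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)          # still filling the buffer
            continue
        idx = rng.randrange(len(buffer))
        yield buffer[idx]                # emit a random buffered element
        buffer[idx] = item               # refill the freed slot
    rng.shuffle(buffer)                  # input exhausted: drain the rest in random order
    yield from buffer


def batched(items, batch_size):
    """Group items into lists of batch_size; the last list may be shorter."""
    batch = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


data = list(range(10))
# With a buffer of 1 there is only ever one candidate, so the order is unchanged.
print(list(buffered_shuffle(data, buffer_size=1)))
# With a buffer as large as the data, every permutation is possible.
for batch in batched(buffered_shuffle(data, buffer_size=10, seed=0), 4):
    print(batch)
```

In this model an element can move forward by at most buffer_size - 1 positions, which is why a buffer smaller than the dataset only mixes nearby elements, and why setting the buffer size to at least the dataset size is the usual way to get a full shuffle.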
