Some personal understanding, recorded here for reference.
import tensorflow as tf

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
print('Loaded raw training images, train_images.shape:', train_images.shape)
train_images = train_images[0:5000, :, :]
print('Kept the first 5000 training images, train_images.shape:', train_images.shape)
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')  # add a channel axis -> 4-D
print('After expanding to 4-D, train_images.shape:', train_images.shape)
Output:
Loaded raw training images, train_images.shape: (60000, 28, 28)
Kept the first 5000 training images, train_images.shape: (5000, 28, 28)
After expanding to 4-D, train_images.shape: (5000, 28, 28, 1)
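The same channel-axis expansion can also be done with NumPy's expand_dims instead of reshape; a minimal sketch, using a zero array as a stand-in for the MNIST slice:

```python
import numpy as np

# Stand-in for the (5000, 28, 28) slice of MNIST training images
images = np.zeros((5000, 28, 28), dtype=np.uint8)

# Append a channel axis so convolutional layers see (batch, height, width, channels)
images_4d = np.expand_dims(images, axis=-1).astype('float32')
print(images_4d.shape)  # (5000, 28, 28, 1)
```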
BATCH_SIZE = 512
BUFFER_SIZE = 5000
print('==========')
# datasets = tf.data.Dataset.from_tensor_slices(train_images, train_labels)  # wrong: labels cannot be a second positional argument; pass a tuple instead
datasets = tf.data.Dataset.from_tensor_slices(train_images)
print('datasets:', datasets)
datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
print('datasets:', datasets)
times = 1  # batch counter
total = 0  # image counter
for item in datasets:  # one full pass over the dataset, i.e. one epoch
    print(f'======= Current batch: {times} ========')
    print(item.shape)
    batch_count = item.shape[0]  # BATCH_SIZE is 512, but 5000 is not divisible by 512, so the last batch is smaller
    total += batch_count
    times += 1
print('Total images seen:', total)
Output:
datasets: <TensorSliceDataset shapes: (28, 28, 1), types: tf.float32>
datasets: <BatchDataset shapes: (None, 28, 28, 1), types: tf.float32>
======= Current batch: 1 ========
(512, 28, 28, 1)
======= Current batch: 2 ========
(512, 28, 28, 1)
======= Current batch: 3 ========
(512, 28, 28, 1)
======= Current batch: 4 ========
(512, 28, 28, 1)
======= Current batch: 5 ========
(512, 28, 28, 1)
======= Current batch: 6 ========
(512, 28, 28, 1)
======= Current batch: 7 ========
(512, 28, 28, 1)
======= Current batch: 8 ========
(512, 28, 28, 1)
======= Current batch: 9 ========
(512, 28, 28, 1)
======= Current batch: 10 ========
(392, 28, 28, 1)
Total images seen: 5000
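The batch count and the odd-sized last batch follow from simple arithmetic; a quick check, with the numbers taken from the run above:

```python
import math

n_images = 5000
batch_size = 512

n_batches = math.ceil(n_images / batch_size)          # 10 batches in total
last_batch = n_images - (n_batches - 1) * batch_size  # 5000 - 9 * 512 = 392
print(n_batches, last_batch)  # 10 392
```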
My understanding: in datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE), shuffle first shuffles elements within a buffer of BUFFER_SIZE (here BUFFER_SIZE equals the dataset size 5000, so this amounts to a full shuffle), and batch then groups the shuffled elements into batches of BATCH_SIZE, each with shape (batch_size, 28, 28, 1). The total number of elements is still 5000. In `for item in datasets:`, each item is one batch of training data.
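As a side note on the commented-out line above: from_tensor_slices does not accept labels as a second positional argument; to keep images and labels paired, pass them together as a tuple. A minimal sketch with toy arrays standing in for the MNIST data:

```python
import numpy as np
import tensorflow as tf

# Toy stand-ins for (train_images, train_labels)
images = np.zeros((100, 28, 28, 1), dtype='float32')
labels = np.arange(100, dtype='int64')

# Pass a tuple so each dataset element is an (image, label) pair
ds = tf.data.Dataset.from_tensor_slices((images, labels))
ds = ds.shuffle(100).batch(32)

# Each iteration now yields a batch of images and the matching batch of labels
for img_batch, label_batch in ds.take(1):
    print(img_batch.shape, label_batch.shape)  # (32, 28, 28, 1) and (32,)
```

Shuffling before batching (as in the post) shuffles individual examples; if batch were called first, whole batches would be shuffled instead.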