The fruits-360 dataset is already present in the current working directory. For how to obtain and load the dataset, see my other article.
Preparation
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.preprocessing.image import load_img, img_to_array, array_to_img, ImageDataGenerator
Create the Generators
Create an ImageDataGenerator for each split. Since this dataset is large enough, no image augmentation is needed; the generators only rescale pixel values to [0, 1].
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
    "fruits-360/Training",
    target_size=(100, 100),
    batch_size=32,
    class_mode='categorical')
validation_generator = test_datagen.flow_from_directory(
    "fruits-360/Test",
    target_size=(100, 100),
    batch_size=32,
    class_mode='categorical')
If you see output like the following after running the code, the generators were created successfully.
Found 67692 images belonging to 131 classes.
Found 22688 images belonging to 131 classes.
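If you want to double-check that the generators behave as expected, you can pull one batch and inspect it. This is just a quick sanity check, not part of the training pipeline; the variable names are only for illustration.
x_batch, y_batch = next(train_generator)
print(x_batch.shape)  # (32, 100, 100, 3), pixel values rescaled to [0, 1]
print(y_batch.shape)  # (32, 131), one-hot labels because class_mode='categorical'
print(len(train_generator.class_indices))  # 131, the class-name -> index mapping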
Model
Here we use the Xception model as the base model.
from tensorflow.keras.applications.xception import preprocess_input
from tensorflow.keras.applications.xception import decode_predictions
from tensorflow.keras.applications.xception import Xception
tf.keras.backend.clear_session()
base_model = tf.keras.applications.Xception(
    weights='imagenet',         # Load weights pre-trained on ImageNet.
    input_shape=(100, 100, 3),
    include_top=False)          # Do not include the ImageNet classifier at the top.
input_layer = tf.keras.Input(shape=(100, 100, 3))
base_model.trainable = False
# x = data_augmentation(input_layer)
x = base_model(input_layer, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dropout(0.2)(x)  # Regularize with dropout
output_layer = tf.keras.layers.Dense(131, activation='softmax')(x)
model = tf.keras.Model(input_layer, output_layer)
model.summary()
base_model.trainable = False
freezes the weights of the Xception model so that they are not updated during training; in other words, the pre-trained weights are used as-is.
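If you want to confirm that the freeze took effect, a quick check (just a sketch for inspection) is to count the trainable weights; only the new Dense head should contribute any.
print(len(base_model.trainable_weights))  # expected 0: every Xception weight is frozen
print(len(model.trainable_weights))       # expected 4: kernel and bias of the two new Dense layers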
x = base_model(input_layer, training=False)
Here training=False ensures that the base model runs in inference mode rather than training mode, which matters for layers that behave differently during training.
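For example, Xception contains many BatchNormalization layers; in inference mode their moving mean and variance stay fixed instead of being re-estimated from the new data. A small sketch to list them:
bn_layers = [layer for layer in base_model.layers
             if isinstance(layer, tf.keras.layers.BatchNormalization)]
print(len(bn_layers))  # number of BatchNormalization layers kept in inference mode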
opt = tf.keras.optimizers.Adam(learning_rate=0.01)
model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])
model.fit(train_generator, epochs=5, steps_per_epoch=67692 // 32, validation_data=validation_generator)
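Once training finishes, the model can be tried on a single image. A minimal sketch, reusing the load_img / img_to_array imports from the Preparation section; the path some_fruit.jpg is only a placeholder.
import numpy as np

img = load_img("some_fruit.jpg", target_size=(100, 100))  # placeholder path, replace with a real image
x = img_to_array(img) / 255.0   # same rescaling as the generators
x = np.expand_dims(x, axis=0)   # add the batch dimension: (1, 100, 100, 3)

probs = model.predict(x)[0]
# Map the predicted index back to a class name via the generator's mapping
index_to_class = {v: k for k, v in train_generator.class_indices.items()}
print(index_to_class[int(np.argmax(probs))], float(np.max(probs)))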