Reference: Retraining an Image Classifier
Introduction
Image classification models have millions of parameters. Training them from scratch requires a lot of labeled training data and a lot of computing power. Transfer learning is a technique that shortcuts much of this by taking a piece of a model that has already been trained on a related task and reusing it in a new model.
This Colab demonstrates how to build a Keras model for classifying five species of flowers, using a pre-trained TF2 SavedModel from TensorFlow Hub for image feature extraction that was trained on the much larger and more general ImageNet dataset. Optionally, the feature extractor can be trained ("fine-tuned") alongside the newly added classifier.
Looking for a tool instead?
This is a TensorFlow coding tutorial. If you want a tool that just builds the TensorFlow or TF Lite model for it, take a look at the make_image_classifier command-line tool that gets installed by the PIP package tensorflow-hub[make_image_classifier], or at this TF Lite Colab.
Setup
import itertools
import os
import matplotlib.pylab as plt
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
print("TF version:", tf.__version__)
print("Hub version:", hub.__version__)
print("GPU is", "available" if tf.test.is_gpu_available() else "NOT AVAILABLE")
TF version: 2.4.1
Hub version: 0.11.0
WARNING:tensorflow:From <ipython-input-1-0831fa394ed3>:12: is_gpu_available (from tensorflow.python.framework.test_util) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.config.list_physical_devices('GPU')` instead.
GPU is available
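The warning above already names the replacement: tf.test.is_gpu_available() is deprecated. A minimal equivalent check with the non-deprecated API (same result, different call):

# Non-deprecated GPU check, as suggested by the warning above.
print("GPU is", "available" if tf.config.list_physical_devices('GPU') else "NOT AVAILABLE")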
Select the TF2 SavedModel module to use
For starters, use https://hub.tensorflow.google.cn/google/imagenet/mobilenet_v2_100_224/feature_vector/4. The same URL can be used in code to identify the SavedModel and in your browser to show its documentation. (Note that models in TF1 Hub format won't work here.)
You can find more TF2 models that generate image feature vectors here.
There are multiple possible models to try. All you need to do is select a different one in the cell below and follow along with the rest of the notebook.
model_name = "mobilenet_v3_small_100_224" # @param ['bit_s-r50x1', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'inception_v3', 'inception_resnet_v2', 'mobilenet_v2_100_224', 'mobilenet_v2_130_224', 'mobilenet_v2_140_224', 'mobilenet_v3_large_100_224', 'mobilenet_v3_large_075_224', 'mobilenet_v3_small_100_224', 'mobilenet_v3_small_075_224', 'nasnet_large', 'nasnet_mobile', 'pnasnet_large', 'resnet_v1_50', 'resnet_v1_101', 'resnet_v1_152', 'resnet_v2_50', 'resnet_v2_101', 'resnet_v2_152']
model_handle_map = {
"efficientnet_b0": "https://hub.tensorflow.google.cn/tensorflow/efficientnet/b0/feature-vector/1",
"efficientnet_b1": "https://hub.tensorflow.google.cn/tensorflow/efficientnet/b1/feature-vector/1",
"efficientnet_b2": "https://hub.tensorflow.google.cn/tensorflow/efficientnet/b2/feature-vector/1",
"efficientnet_b3": "https://hub.tensorflow.google.cn/tensorflow/efficientnet/b3/feature-vector/1",
"efficientnet_b4": "https://hub.tensorflow.google.cn/tensorflow/efficientnet/b4/feature-vector/1",
"efficientnet_b5": "https://hub.tensorflow.google.cn/tensorflow/efficientnet/b5/feature-vector/1",
"efficientnet_b6": "https://hub.tensorflow.google.cn/tensorflow/efficientnet/b6/feature-vector/1",
"efficientnet_b7": "https://hub.tensorflow.google.cn/tensorflow/efficientnet/b7/feature-vector/1",
"bit_s-r50x1": "https://hub.tensorflow.google.cn/google/bit/s-r50x1/1",
"inception_v3": "https://hub.tensorflow.google.cn/google/imagenet/inception_v3/feature-vector/4",
"inception_resnet_v2": "https://hub.tensorflow.google.cn/google/imagenet/inception_resnet_v2/feature-vector/4",
"resnet_v1_50": "https://hub.tensorflow.google.cn/google/imagenet/resnet_v1_50/feature-vector/4",
"resnet_v1_101": "https://hub.tensorflow.google.cn/google/imagenet/resnet_v1_101/feature-vector/4",
"resnet_v1_152": "https://hub.tensorflow.google.cn/google/imagenet/resnet_v1_152/feature-vector/4",
"resnet_v2_50": "https://hub.tensorflow.google.cn/google/imagenet/resnet_v2_50/feature-vector/4",
"resnet_v2_101": "https://hub.tensorflow.google.cn/google/imagenet/resnet_v2_101/feature-vector/4",
"resnet_v2_152": "https://hub.tensorflow.google.cn/google/imagenet/resnet_v2_152/feature-vector/4",
"nasnet_large": "https://hub.tensorflow.google.cn/google/imagenet/nasnet_large/feature_vector/4",
"nasnet_mobile": "https://hub.tensorflow.google.cn/google/imagenet/nasnet_mobile/feature_vector/4",
"pnasnet_large": "https://hub.tensorflow.google.cn/google/imagenet/pnasnet_large/feature_vector/4",
"mobilenet_v2_100_224": "https://hub.tensorflow.google.cn/google/imagenet/mobilenet_v2_100_224/feature_vector/4",
"mobilenet_v2_130_224": "https://hub.tensorflow.google.cn/google/imagenet/mobilenet_v2_130_224/feature_vector/4",
"mobilenet_v2_140_224": "https://hub.tensorflow.google.cn/google/imagenet/mobilenet_v2_140_224/feature_vector/4",
"mobilenet_v3_small_100_224": "https://hub.tensorflow.google.cn/google/imagenet/mobilenet_v3_small_100_224/feature_vector/5",
"mobilenet_v3_small_075_224": "https://hub.tensorflow.google.cn/google/imagenet/mobilenet_v3_small_075_224/feature_vector/5",
"mobilenet_v3_large_100_224": "https://hub.tensorflow.google.cn/google/imagenet/mobilenet_v3_large_100_224/feature_vector/5",
"mobilenet_v3_large_075_224": "https://hub.tensorflow.google.cn/google/imagenet/mobilenet_v3_large_075_224/feature_vector/5",
}
model_image_size_map = {
"efficientnet_b0": 224,
"efficientnet_b1": 240,
"efficientnet_b2": 260,
"efficientnet_b3": 300,
"efficientnet_b4": 380,
"efficientnet_b5": 456,
"efficientnet_b6": 528,
"efficientnet_b7": 600,
"inception_v3": 299,
"inception_resnet_v2": 299,
"nasnet_large": 331,
"pnasnet_large": 331,
}
model_handle = model_handle_map.get(model_name)
pixels = model_image_size_map.get(model_name, 224)
print(f"Selected model: {model_name} : {model_handle}")
IMAGE_SIZE = (pixels, pixels)
print(f"Input size {IMAGE_SIZE}")
BATCH_SIZE = 32
Selected model: mobilenet_v3_small_100_224 : https://hub.tensorflow.google.cn/google/imagenet/mobilenet_v3_small_100_224/feature_vector/5
Input size (224, 224)
Set up the Flowers dataset
Inputs are suitably resized for the selected module. Dataset augmentation (i.e., random distortions of an image each time it is read) improves training, especially when fine-tuning.
data_dir = tf.keras.utils.get_file(
'flower_photos',
'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
untar=True)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 14s 0us/step
datagen_kwargs = dict(rescale=1./255, validation_split=.20)
dataflow_kwargs = dict(target_size=IMAGE_SIZE, batch_size=BATCH_SIZE,
interpolation="bilinear")
valid_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
**datagen_kwargs)
valid_generator = valid_datagen.flow_from_directory(
data_dir, subset="validation", shuffle=False, **dataflow_kwargs)
do_data_augmentation = False
if do_data_augmentation:
  train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
      rotation_range=40,
      horizontal_flip=True,
      width_shift_range=0.2, height_shift_range=0.2,
      shear_range=0.2, zoom_range=0.2,
      **datagen_kwargs)
else:
  train_datagen = valid_datagen
train_generator = train_datagen.flow_from_directory(
    data_dir, subset="training", shuffle=True, **dataflow_kwargs)
Found 731 images belonging to 5 classes.
Found 2939 images belonging to 5 classes.
Defining the model
All it takes is to put a linear classifier on top of the feature_extractor_layer with the Hub module.
For speed, we start out with a non-trainable feature_extractor_layer, but you can also enable fine-tuning for greater accuracy.
do_fine_tuning = False
Fine-tuning is not used here.
Model structure:
- tf.keras.layers.InputLayer
- hub.KerasLayer
- Dropout
- Dense
built with model.build((None,) + IMAGE_SIZE + (3,)).
print("Building model with", model_handle)
model = tf.keras.Sequential([
    # Explicitly define the input shape so the model can be properly
    # loaded by the TFLiteConverter.
    tf.keras.layers.InputLayer(input_shape=IMAGE_SIZE + (3,)),
    hub.KerasLayer(model_handle, trainable=do_fine_tuning),
    tf.keras.layers.Dropout(rate=0.2),
    tf.keras.layers.Dense(train_generator.num_classes,
                          kernel_regularizer=tf.keras.regularizers.l2(0.0001))
])
model.build((None,)+IMAGE_SIZE+(3,))
model.summary()
Building model with https://hub.tensorflow.google.cn/google/imagenet/mobilenet_v3_small_100_224/feature_vector/5
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
keras_layer (KerasLayer) (None, 1024) 1529968
_________________________________________________________________
dropout (Dropout) (None, 1024) 0
_________________________________________________________________
dense (Dense) (None, 5) 5125
=================================================================
Total params: 1,535,093
Trainable params: 5,125
Non-trainable params: 1,529,968
_________________________________________________________________
Training the model
model.compile(
optimizer=tf.keras.optimizers.SGD(lr=0.005, momentum=0.9),
loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
metrics=['accuracy'])
steps_per_epoch = train_generator.samples // train_generator.batch_size
validation_steps = valid_generator.samples // valid_generator.batch_size
hist = model.fit(
train_generator,
epochs=5, steps_per_epoch=steps_per_epoch,
validation_data=valid_generator,
validation_steps=validation_steps).history
For a generator input, steps_per_epoch is the number of steps (i.e., batches of size batch_size) that make up one epoch:
steps_per_epoch = samples // batch_size
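With the numbers from this run (2939 training images, 731 validation images, BATCH_SIZE = 32), that works out to:

steps_per_epoch  = 2939 // 32   # = 91, matching the "91/91" progress bars below
validation_steps =  731 // 32   # = 22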
Epoch 1/5
91/91 [==============================] - 30s 174ms/step - loss: 1.1899 - accuracy: 0.5899 - val_loss: 0.7321 - val_accuracy: 0.8423
Epoch 2/5
91/91 [==============================] - 15s 160ms/step - loss: 0.6630 - accuracy: 0.8935 - val_loss: 0.7036 - val_accuracy: 0.8651
Epoch 3/5
91/91 [==============================] - 14s 158ms/step - loss: 0.6405 - accuracy: 0.9095 - val_loss: 0.6973 - val_accuracy: 0.8580
Epoch 4/5
91/91 [==============================] - 15s 160ms/step - loss: 0.6143 - accuracy: 0.9156 - val_loss: 0.6817 - val_accuracy: 0.8722
Epoch 5/5
91/91 [==============================] - 15s 160ms/step - loss: 0.5917 - accuracy: 0.9323 - val_loss: 0.6795 - val_accuracy: 0.8778
Plot the training curves
plt.figure()
plt.ylabel("Loss (training and validation)")
plt.xlabel("Training Steps")
plt.ylim([0,2])
plt.plot(hist["loss"])
plt.plot(hist["val_loss"])
plt.figure()
plt.ylabel("Accuracy (training and validation)")
plt.xlabel("Training Steps")
plt.ylim([0,1])
plt.plot(hist["accuracy"])
plt.plot(hist["val_accuracy"])
[<matplotlib.lines.Line2D at 0x7f0f2030b198>]
Try out the model on an image from the validation data:
Steps:
- Define a function that converts between a class index and its class-name string.
- Use next() to get one sample batch x, y from valid_generator.
- Take the first image and label out of the batch (dropping the batch dimension).
- Display the image with plt.imshow().
- Predict with model.predict (expanding the image back to a batch of 1 first).
- Get the index of the highest score with np.argmax.
- Convert that index back to its class-name string.
# Given a predicted index (the argmax of the prediction vector), return the
# corresponding class-name string; valid_generator.class_indices is a dict
# mapping class-name strings to indices.
def get_class_string_from_index(index):
  for class_string, class_index in valid_generator.class_indices.items():
    if class_index == index:
      return class_string
x, y = next(valid_generator)
image = x[0, :, :, :]
true_index = np.argmax(y[0])
plt.imshow(image)
plt.axis('off')
plt.show()
# Expand the validation image to (1, 224, 224, 3) before predicting the label
prediction_scores = model.predict(np.expand_dims(image, axis=0))
predicted_index = np.argmax(prediction_scores)
print("True label: " + get_class_string_from_index(true_index))
print("Predicted label: " + get_class_string_from_index(predicted_index))
True label: daisy
Predicted label: daisy
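If you want an aggregate score rather than a single example, a minimal sketch (assuming the compiled model, valid_generator and validation_steps from above) is to call model.evaluate on the validation generator:

# Evaluate on the whole validation split; returns [loss, accuracy] because the
# model was compiled with metrics=['accuracy'].
val_loss, val_accuracy = model.evaluate(valid_generator, steps=validation_steps)
print(f"Validation accuracy: {val_accuracy:.3f}")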
Finally, the trained model can be saved for deployment to TF Serving or TF Lite (on mobile devices) as follows.
saved_model_path = f"/tmp/saved_flowers_model_{model_name}"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/saved_flowers_model_mobilenet_v3_small_100_224/assets
Optional: Deployment to TensorFlow Lite
TensorFlow Lite lets you deploy TensorFlow models to mobile and IoT devices. The code below shows how to convert the trained model to TF Lite and apply post-training tools from the TensorFlow Model Optimization Toolkit. Finally, it runs the converted model in the TF Lite Interpreter to examine the resulting quality:
- Converting without optimization provides the same results as before (up to roundoff error).
- Converting with optimization but without any data quantizes the model weights to 8 bits, but inference still uses floating-point computation for the neural network activations. This reduces model size almost by a factor of 4 and improves CPU latency on mobile devices.
- On top of that, computation of the neural network activations can be quantized to 8-bit integers as well if a small reference dataset is provided to calibrate the quantization range. On a mobile device, this accelerates inference further and makes it possible to run on accelerators like EdgeTPU.
Optimization settings
optimize_lite_model = False
num_calibration_examples = 60
representative_dataset = None
if optimize_lite_model and num_calibration_examples:
  # Use a bounded number of training examples without labels for calibration.
  # TFLiteConverter expects a list of input tensors, each with batch size 1.
  representative_dataset = lambda: itertools.islice(
      ([image[None, ...]] for batch, _ in train_generator for image in batch),
      num_calibration_examples)

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
if optimize_lite_model:
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  if representative_dataset:  # This is optional, see above.
    converter.representative_dataset = representative_dataset
lite_model_content = converter.convert()

with open(f"/tmp/lite_flowers_model_{model_name}.tflite", "wb") as f:
  f.write(lite_model_content)
print("Wrote %sTFLite model of %d bytes." %
      ("optimized " if optimize_lite_model else "", len(lite_model_content)))
Wrote TFLite model of 6097236 bytes.
interpreter = tf.lite.Interpreter(model_content=lite_model_content)
# This little helper wraps the TF Lite interpreter as a numpy-to-numpy function.
def lite_model(images):
  interpreter.allocate_tensors()
  interpreter.set_tensor(interpreter.get_input_details()[0]['index'], images)
  interpreter.invoke()
  return interpreter.get_tensor(interpreter.get_output_details()[0]['index'])
num_eval_examples = 50
eval_dataset = ((image, label) # TFLite expects batch size 1.
for batch in train_generator
for (image, label) in zip(*batch))
count = 0
count_lite_tf_agree = 0
count_lite_correct = 0
for image, label in eval_dataset:
  probs_lite = lite_model(image[None, ...])[0]
  probs_tf = model(image[None, ...]).numpy()[0]
  y_lite = np.argmax(probs_lite)
  y_tf = np.argmax(probs_tf)
  y_true = np.argmax(label)
  count += 1
  if y_lite == y_tf: count_lite_tf_agree += 1
  if y_lite == y_true: count_lite_correct += 1
  if count >= num_eval_examples: break
print("TF Lite model agrees with original model on %d of %d examples (%g%%)." %
(count_lite_tf_agree, count, 100.0 * count_lite_tf_agree / count))
print("TF Lite model is accurate on %d of %d examples (%g%%)." %
(count_lite_correct, count, 100.0 * count_lite_correct / count))
TF Lite model agrees with original model on 50 of 50 examples (100%).
TF Lite model is accurate on 47 of 50 examples (94%).
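As an aside not covered above, the .tflite file written earlier can also be loaded back from disk by path instead of from the in-memory bytes; a hedged sketch using the same Interpreter API and the helpers defined above:

# Load the TF Lite model from the file written earlier and classify one image.
lite_path = f"/tmp/lite_flowers_model_{model_name}.tflite"
file_interpreter = tf.lite.Interpreter(model_path=lite_path)
file_interpreter.allocate_tensors()

x, _ = next(valid_generator)  # one batch of already rescaled validation images
input_index = file_interpreter.get_input_details()[0]['index']
output_index = file_interpreter.get_output_details()[0]['index']
file_interpreter.set_tensor(input_index, x[:1].astype(np.float32))
file_interpreter.invoke()
probs = file_interpreter.get_tensor(output_index)[0]
print("Predicted label:", get_class_string_from_index(np.argmax(probs)))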
Summary
Two dictionaries are used up front:
- model_handle_map: maps the model name to its TF Hub URL
- model_image_size_map: maps the model name to the pixel size of its input images
From these, prepare IMAGE_SIZE = (pixels, pixels) and BATCH_SIZE = 32.
---
Create the keyword-argument dictionaries datagen_kwargs and dataflow_kwargs and pass them in with **datagen_kwargs / **dataflow_kwargs.
Download the data with tf.keras.utils.get_file().
Read the images with tf.keras.preprocessing.image.ImageDataGenerator(**datagen_kwargs) and flow_from_directory(**dataflow_kwargs).
The image generators expose a num_classes attribute:
train_generator.num_classes
steps_per_epoch = train_generator.samples // train_generator.batch_size
validation_steps = valid_generator.samples // valid_generator.batch_size
hist = model.fit(..., steps_per_epoch=steps_per_epoch, ...).history directly gives the training history as a dict of per-epoch metric lists.
Save the model:
tf.saved_model.save(model, saved_model_path)
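A quick, hedged follow-up to the last step: the exported directory can be loaded back with tf.saved_model.load for serving-style inference (the signature name in the comment is the usual default, not something this notebook prints):

# Reload the exported SavedModel and inspect its signatures.
reloaded = tf.saved_model.load(saved_model_path)
print(list(reloaded.signatures.keys()))  # typically ['serving_default']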