2.2.1 Transfer learning - the concepts & coding
迁移学习就是把已经训练好的模型及其参数迁移到一个新模型上,使我们不需要从零开始训练新模型。
这里使用 ImageNet(image-net.org)——世界上最大的图像识别数据库。
已经训练好的模型“快照”存放在这里:https://storage.googleapis.com/mledu-datasets/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5
1.下载模型"快照"
import ssl
import urllib.request

# URL of the pre-trained InceptionV3 weight snapshot (the "notop" variant).
url = 'https://storage.googleapis.com/mledu-datasets/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5'

# NOTE(security): this turns off HTTPS certificate verification for every
# urllib request made by this process. It is kept only as a workaround for
# environments with a broken local certificate store; remove it if the
# download works without it.
ssl._create_default_https_context = ssl._create_unverified_context

# Download with the standard library and save under the extension-less file
# name that the weight-loading step reads. urlretrieve streams the response
# to disk and closes the connection when it is done.
urllib.request.urlretrieve(
    url, 'inception_v3_weights_tf_dim_ordering_tf_kernels_notop')
可将参数加载到模型的骨架中,使之变成训练好的模型。Keras 有内置的模型定义,并且可以指定不需要加载的层。
Inception-V3:在顶部具有全连接层。将 include_top 设置为 False 会忽略全连接层,直接使用卷积层的输出。
这里建议把文件名末尾的 '.h5' 去掉,否则加载时容易报错。
2. 加载模型"快照"
from tensorflow.keras.applications.inception_v3 import InceptionV3

# Local file holding the downloaded weight snapshot.
local_weights_file = 'inception_v3_weights_tf_dim_ordering_tf_kernels_notop'

# Build the bare InceptionV3 architecture: 150x150 RGB input, no fully
# connected classifier on top, and weights=None so that Keras does not load
# any weights itself -- they are read from the local file just below.
pre_trained_model = InceptionV3(
    input_shape=(150, 150, 3),
    include_top=False,
    weights=None,
)
pre_trained_model.load_weights(local_weights_file)

# Walk over every layer and freeze it so the pre-trained weights are not
# updated when the new model is trained.
for layer in pre_trained_model.layers:
    layer.trainable = False

# Uncomment to print a summary of the pre-trained model.
# pre_trained_model.summary()
InceptionV3 的最后一层输出的特征图已经缩小到 3x3。
为了保留更多信息,这里改用输出为 7x7 的 'mixed7' 层作为预训练模型的最后一层。
# Use the intermediate 'mixed7' layer as the end of the pre-trained base.
last_layer = pre_trained_model.get_layer('mixed7')
last_output = last_layer.output

# Read the shape from the output tensor instead of `last_layer.output_shape`:
# the tensor's `.shape` is available across Keras versions, whereas the
# layer-level `output_shape` attribute is not present in newer releases.
print('last layer output shape: ', last_output.shape)
3. 编译
from tensorflow.keras import Model, layers
from tensorflow.keras.optimizers import RMSprop

# Flatten the selected pre-trained output to one dimension.
x = layers.Flatten()(last_output)
# Add a fully connected layer with 1024 units.
x = layers.Dense(1024, activation='relu')(x)
# Dropout with rate 0.2: during training a random 20% of the units are
# zeroed on each step, which helps reduce overfitting.
x = layers.Dropout(0.2)(x)
# Final single-unit sigmoid layer for binary classification.
x = layers.Dense(1, activation='sigmoid')(x)

# New model: the pre-trained model's input through to the new classifier head.
model = Model(pre_trained_model.input, x)

# `learning_rate` replaces the deprecated `lr` keyword, which newer Keras
# releases no longer accept. The metric is named 'acc' because the plotting
# code reads history.history['acc'] / ['val_acc'].
model.compile(
    optimizer=RMSprop(learning_rate=0.0001),
    loss='binary_crossentropy',
    metrics=['acc'],
)
4. 训练猫狗数据集
# -------------------------------------------------------- #
# 4. Training and validation datasets
# -------------------------------------------------------- #
import os
import zipfile

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# The archive must already be present at this path. It can be downloaded from
# https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
local_zip = './tmp/cats_and_dogs_filtered.zip'

# Extract next to the archive, into './tmp', so that the result matches
# base_dir below (the previous target './temp' did not).
# Assumes the archive contains a top-level 'cats_and_dogs_filtered' folder.
# The context manager closes the archive even if extraction fails.
with zipfile.ZipFile(local_zip, 'r') as zip_ref:
    zip_ref.extractall('./tmp')

# Directory layout
base_dir = './tmp/cats_and_dogs_filtered'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

train_cat_dir = os.path.join(train_dir, 'cats')
train_dog_dir = os.path.join(train_dir, 'dogs')
validation_cat_dir = os.path.join(validation_dir, 'cats')
validation_dog_dir = os.path.join(validation_dir, 'dogs')

train_cat_fnames = os.listdir(train_cat_dir)
train_dog_fnames = os.listdir(train_dog_dir)

# Training image generator: rescales pixel values by 1/255 and applies
# random augmentation (rotation, shifts, shear, zoom, horizontal flip).
train_datagen = ImageDataGenerator(
    rescale=1.0/255.,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)

# Validation images are only rescaled, never augmented.
test_datagen = ImageDataGenerator(rescale=1.0/255.)

# Read images in batches of 20, resized to the 150x150 input the model
# expects, with binary labels taken from the class sub-directories.
train_generator = train_datagen.flow_from_directory(
    train_dir,
    batch_size=20,
    class_mode='binary',
    target_size=(150, 150),
)

validation_generator = test_datagen.flow_from_directory(
    validation_dir,
    batch_size=20,
    class_mode='binary',
    target_size=(150, 150),
)
# -------------------------------------------------------- #
# 5. Training
# -------------------------------------------------------- #
# Model.fit accepts generators directly; fit_generator is deprecated and
# has been removed from newer Keras releases.
# With the generators' batch size of 20, each epoch runs 100 training
# batches (2000 images) and 50 validation batches (1000 images).
history = model.fit(
    train_generator,
    validation_data=validation_generator,
    steps_per_epoch=100,
    epochs=20,
    validation_steps=50,
    verbose=2,
)
# -------------------------------------------------------- #
# 6. Show the results
# -------------------------------------------------------- #
import matplotlib.pyplot as plt

# Per-epoch metrics recorded during training.
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

# Figure 1: training vs. validation accuracy.
plt.figure()
plt.plot(epochs, acc, 'r', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()  # without this call the curve labels are never displayed

# Figure 2: training vs. validation loss.
plt.figure()
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()

# No further plt.figure() here: it would only open an extra empty window.
plt.show()