DL with VGG16: Transfer Learning of the VGG16 (Keras) Network Architecture on the Knifey-Spoony Dataset (Part 2)

Design approach

1. Base model

[Figure: the VGG16 base model]
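The base model is the VGG16 network that ships with Keras. As a rough sketch of what cutting it at block5_pool produces, the arithmetic below assumes the standard 224x224 RGB input (the size used when VGG16 is loaded with include_top=True) and the stock channel widths of its five convolutional blocks; it does not load the model itself.

```python
# Sketch: how the spatial size shrinks through VGG16's five convolutional blocks.
# Assumption: standard 224x224 RGB input (VGG16 loaded with include_top=True).
# The 3x3 convolutions use 'same' padding, so only the 2x2 max-pooling layers
# (stride 2) at the end of each block change height and width.

INPUT_SIZE = 224
BLOCK_CHANNELS = [64, 128, 256, 512, 512]  # output channels of block1 .. block5


def pooled_shapes(size, channels):
    """Return the (height, width, channels) shape after each blockN_pool layer."""
    shapes = []
    for c in channels:
        size //= 2  # 2x2 max pooling with stride 2 halves height and width
        shapes.append((size, size, c))
    return shapes


shapes = pooled_shapes(INPUT_SIZE, BLOCK_CHANNELS)
for i, s in enumerate(shapes, start=1):
    print(f"block{i}_pool: {s}")

h, w, c = shapes[-1]
print("Flattened block5_pool features:", h * w * c)
```

Under these assumptions the transfer layer outputs a 7x7x512 feature map, which flattens to 25,088 values; that is the input width the Flatten() and Dense(1024) layers in the core code receive.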


2. Mind map of the approach

[Figure: mind map of the transfer-learning approach]


Core code

 

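# Keras imports (tf.keras, TensorFlow 2.x)
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.models import Model, Sequential

# Defined in the previous part of this series and not shown here:
# generator_train, generator_test, steps_test, class_weight, num_classes, optimizer, plot_training_history

# Assumed loading step (not shown in this post): VGG16 with ImageNet weights and its default 224x224 input.
model_VGG16 = VGG16(include_top=True, weights='imagenet')
model_VGG16.summary()

# Use the output of the last max-pooling layer as the transfer layer.
transfer_layer = model_VGG16.get_layer('block5_pool')
print('transfer_layer.output:', transfer_layer.output)

# Sub-model covering only the convolutional part of VGG16 (from the input up to block5_pool).
conv_model = Model(inputs=model_VGG16.input,
                   outputs=transfer_layer.output)

VGG16_TL_model = Sequential()                        # Start a new Keras Sequential model.
VGG16_TL_model.add(conv_model)                       # Add the convolutional part of the VGG16 model from above.
VGG16_TL_model.add(Flatten())                        # Flatten the output of the VGG16 model because it comes from a convolutional layer.
VGG16_TL_model.add(Dense(1024, activation='relu'))   # Add a dense (a.k.a. fully connected) layer that combines the features VGG16 has recognized in the image.
VGG16_TL_model.add(Dropout(0.5))                     # Add a dropout layer, which may reduce overfitting and improve generalization to unseen data such as the test set.
VGG16_TL_model.add(Dense(num_classes, activation='softmax'))  # Add the final layer for the actual classification.


def print_layer_trainable():
    # Minimal definition (assumed; the helper's body is not shown in this post):
    # print whether each layer of the convolutional base is trainable.
    for layer in conv_model.layers:
        print('{0}:\t{1}'.format(layer.trainable, layer.name))


print_layer_trainable()  # Before freezing.

# Freeze the VGG16 convolutional base so that only the new classification head is trained.
# Setting conv_model.trainable already propagates to its layers; the loop makes it explicit.
conv_model.trainable = False
for layer in conv_model.layers:
    layer.trainable = False

print_layer_trainable()  # After freezing: the layers should now report False.

loss = 'categorical_crossentropy'
metrics = ['categorical_accuracy']

# Compile after changing `trainable`; changes to `trainable` take effect only when the model is compiled.
VGG16_TL_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

epochs = 20
steps_per_epoch = 100

# fit_generator()/evaluate_generator() are deprecated in TensorFlow 2.x; fit()/evaluate() accept generators directly.
history = VGG16_TL_model.fit(generator_train,
                             epochs=epochs,
                             steps_per_epoch=steps_per_epoch,
                             class_weight=class_weight,
                             validation_data=generator_test,
                             validation_steps=steps_test)

plot_training_history(history)

# evaluate() returns [loss, categorical_accuracy]; index 1 is the accuracy.
VGG16_TL_model_result = VGG16_TL_model.evaluate(generator_test, steps=steps_test)
print("Test-set classification accuracy: {0:.2%}".format(VGG16_TL_model_result[1]))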
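The head built above places Dropout(0.5) after the Dense(1024) layer. Keras's Dropout layer uses inverted dropout: during training each unit is zeroed with probability `rate` and the survivors are scaled by 1/(1 - rate), while at inference it passes values through unchanged. The pure-Python sketch below illustrates that behaviour on a made-up activation vector; the function name and arguments are hypothetical, not a Keras API.

```python
import random


def dropout(values, rate, training, rng):
    """Inverted dropout (hypothetical helper): during training, zero each unit with
    probability `rate` and scale survivors by 1 / (1 - rate); at inference, return
    the values unchanged."""
    if not training or rate == 0.0:
        return list(values)
    keep = 1.0 - rate
    # rng.random() is uniform on [0, 1), so a unit is kept with probability 1 - rate
    return [v / keep if rng.random() >= rate else 0.0 for v in values]


rng = random.Random(0)
activations = [1.0] * 8  # made-up stand-in for the Dense(1024) outputs, shortened to 8 units
print(dropout(activations, 0.5, training=True, rng=rng))   # some units zeroed, the rest doubled
print(dropout(activations, 0.5, training=False, rng=rng))  # unchanged
```

Scaling at training time keeps the expected activation the same in both modes, which is why nothing needs to be rescaled when the model is evaluated on the test set.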
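Freezing conv_model means only the new head is updated during training. To make that split concrete, the sketch below counts parameters by hand. It assumes the 224x224x3 input and num_classes = 3 (the forky, knifey and spoony classes of Knifey-Spoony), since neither value is shown in this post.

```python
# Sketch: parameter count of the frozen VGG16 base vs. the new classification head.
# Assumptions (not shown in the post): 224x224x3 input and num_classes = 3.


def conv_params(in_ch, out_ch, k=3):
    """Weights of a k x k convolution plus one bias per output channel."""
    return k * k * in_ch * out_ch + out_ch


def dense_params(n_in, n_out):
    """Weight matrix plus bias vector of a fully connected layer."""
    return n_in * n_out + n_out


# (input channels, output channels) of the 13 convolutional layers in blocks 1-5
VGG16_CONV = [
    (3, 64), (64, 64),
    (64, 128), (128, 128),
    (128, 256), (256, 256), (256, 256),
    (256, 512), (512, 512), (512, 512),
    (512, 512), (512, 512), (512, 512),
]
frozen = sum(conv_params(i, o) for i, o in VGG16_CONV)  # pooling layers have no weights

num_classes = 3     # assumed: forky, knifey, spoony
flat = 7 * 7 * 512  # Flatten() output size of block5_pool
# Flatten and Dropout have no weights; only the two Dense layers contribute.
head = dense_params(flat, 1024) + dense_params(1024, num_classes)

print("Frozen (non-trainable) parameters in conv_model:", frozen)
print("Trainable parameters in the new head:", head)
```

By this count the frozen base holds 14,714,688 parameters and the head 25,694,211, almost all of them in the Dense(1024) layer fed by 25,088 flattened features. These totals should agree with the non-trainable and trainable counts that VGG16_TL_model.summary() reports for this configuration.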
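The final layer uses softmax, and the model is compiled with categorical_crossentropy as the loss and categorical_accuracy as the metric. The pure-Python sketch below shows what these compute for a single sample; the logits and the one-hot label are made-up values, and Keras itself works on batches of tensors and handles numerical edge cases in its own way.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # y_true is one-hot, so only the true class's log-probability contributes;
    # probabilities are clipped below at eps to avoid log(0)
    return -sum(t * math.log(max(p, eps)) for t, p in zip(y_true, y_pred))


def argmax(values):
    return max(range(len(values)), key=values.__getitem__)


def categorical_accuracy(y_true, y_pred):
    # 1.0 if the most probable predicted class matches the one-hot label, else 0.0
    return float(argmax(y_true) == argmax(y_pred))


logits = [2.0, 1.0, 0.1]  # made-up pre-softmax outputs for 3 classes
y_true = [1, 0, 0]        # made-up one-hot label
probs = softmax(logits)
print(probs)
print(categorical_crossentropy(y_true, probs), categorical_accuracy(y_true, probs))
```

The loss falls as the probability assigned to the true class rises, while the accuracy only checks whether that class has the highest probability.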
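The training call passes class_weight, but this post does not show how it was computed. One common choice, assumed here, is "balanced" weighting, the same formula scikit-learn's compute_class_weight uses with class_weight='balanced': each class gets n_samples / (n_classes * count_c), so rarer classes contribute more to the loss. Keras's fit() expects a dict mapping class index to weight. The helper name and the label counts below are hypothetical.

```python
from collections import Counter


def balanced_class_weight(labels):
    """Hypothetical helper: weight each class by n_samples / (n_classes * count_c)."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {c: n_samples / (n_classes * counts[c]) for c in sorted(counts)}


# Made-up, imbalanced training labels for three classes (indices 0, 1, 2)
labels = [0] * 10 + [1] * 20 + [2] * 40
class_weight = balanced_class_weight(labels)
print(class_weight)  # dict usable as the class_weight argument of fit()
```

With these weights every class contributes the same total weight (n_samples / n_classes), so an imbalanced training set does not bias the classifier toward the majority class.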



