# keras读取预训练模型统一接口 (A unified interface for loading pretrained Keras models)

import numpy as np
from tensorflow.keras.layers import Flatten, Dense
from keras.models import Model
import tensorflow.keras.applications as KerasModel

# Each row is a (submodule, constructor) pair looked up under
# tensorflow.keras.applications. loadModel only accepts a combination that
# appears here as a complete row.
supported_model = np.array([
    ['xception', 'Xception'],
    ['vgg16', 'VGG16'],
    ['vgg19', 'VGG19'],
    ['resnet', 'ResNet50'],
    ['resnet', 'ResNet101'],
    ['resnet', 'ResNet152'],
    ['resnet_v2', 'ResNet50V2'],
    ['resnet_v2', 'ResNet101V2'],
    ['resnet_v2', 'ResNet152V2'],
    # NOTE(review): resnext may not be exposed by every tensorflow.keras
    # release; loadModel returns None when the attribute is missing.
    ['resnext', 'ResNeXt50'],
    ['resnext', 'ResNeXt101'],
    ['inception_v3', 'InceptionV3'],
    ['inception_resnet_v2', 'InceptionResNetV2'],
    ['mobilenet', 'MobileNet'],
    ['mobilenet_v2', 'MobileNetV2'],
    ['densenet', 'DenseNet121'],
    ['densenet', 'DenseNet169'],
    ['densenet', 'DenseNet201'],
    ['nasnet', 'NASNetLarge'],
    ['nasnet', 'NASNetMobile'],
])


def loadModel(model_name: str, model_name_subclass: str,
              classes_sum: int, input_shape=(224, 224, 3), pretrained: bool = False,
              activation: str = 'logsoftmax'):
    """Build a classifier on top of a headless Keras application backbone.

    The backbone is created with ``include_top=False``; its output is
    flattened and fed into a single ``Dense`` layer with ``classes_sum`` units.

    Args:
        model_name: Submodule of ``tensorflow.keras.applications``
            (first column of ``supported_model``), e.g. ``'resnet'``.
        model_name_subclass: Constructor inside that submodule
            (second column of ``supported_model``), e.g. ``'ResNet50'``.
        classes_sum: Number of units in the output ``Dense`` layer.
        input_shape: Input shape passed to the backbone constructor.
        pretrained: When true the backbone is created with
            ``weights='imagenet'``; otherwise with ``weights=None``.
        activation: Activation of the output layer. The default
            ``'logsoftmax'`` is mapped to ``tf.nn.log_softmax``.

    Returns:
        The assembled ``Model``, or ``None`` when the
        ``(model_name, model_name_subclass)`` pair is not a row of
        ``supported_model`` or is not available in the installed Keras.
    """
    # Validate the pair as a whole: checking each column independently would
    # accept mismatched combinations such as ('vgg16', 'ResNet50').
    requested = [model_name, model_name_subclass]
    if not any(row == requested for row in supported_model.tolist()):
        return None

    # Resolve the constructor by attribute lookup instead of eval() on a
    # formatted string. The eval form also emitted `weights = imagenet`
    # (an unquoted name) for pretrained=True, which raised NameError.
    submodule = getattr(KerasModel, model_name, None)
    constructor = getattr(submodule, model_name_subclass, None)
    if constructor is None:
        return None

    weights = 'imagenet' if pretrained else None
    base_model = constructor(include_top=False, weights=weights,
                             input_shape=input_shape)

    # NOTE(review): 'logsoftmax' does not appear to be a registered Keras
    # activation name, so the default is translated to the TensorFlow op.
    # Saving/reloading a model built this way may need custom_objects — confirm.
    if activation == 'logsoftmax':
        import tensorflow as tf
        activation = tf.nn.log_softmax

    # Take Model from tensorflow.keras so that it comes from the same Keras
    # implementation as Flatten/Dense; mixing it with the standalone `keras`
    # package can fail depending on the installed versions.
    from tensorflow.keras.models import Model

    features = Flatten()(base_model.output)
    predictions = Dense(classes_sum, activation=activation)(features)
    return Model(base_model.input, predictions)

# Example: a 10-class classifier on a ResNet50 backbone, using the default
# input shape and no pretrained weights. Runs when the module is imported.
model = loadModel(model_name='resnet', model_name_subclass='ResNet50', classes_sum=10)
上一篇:7.dom


下一篇:[LeetCode&Python] Problem 202. Happy Number