Image Recognition on ImageNet with the TensorFlow Estimator

TensorFlow introduced the Estimator API in version 1.3, and support for this high-level style of programming has grown with each release; Estimators also make it very easy to enable multi-GPU training. In my previous blog posts I always built and trained models with the low-level APIs. That approach is more flexible and exposes the low-level details of the model, but it requires a lot of code and gets tedious, since many details have to be implemented by hand. So I tried the high-level API on the latest TensorFlow 1.14 to see whether it is really easier to use, and whether it can reach the same performance as the low-level APIs.
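As an aside, hooking multi-GPU training into an Estimator amounts to little more than attaching a distribution strategy to its RunConfig. Below is a minimal sketch of what that looks like on TF 1.14; my_model_fn here is a hypothetical placeholder for an ordinary Estimator model_fn, not something defined in the code that follows.

import tensorflow as tf

# Sketch only: replicate training across all local GPUs.
# my_model_fn is a placeholder for a standard Estimator model_fn.
strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(train_distribute=strategy)
estimator = tf.estimator.Estimator(model_fn=my_model_fn, config=config)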

My test uses the ImageNet image-classification data; see my earlier blog post for how the ImageNet data is prepared. The full code is below. It includes two models: Darknet53, the pre-training model used in YOLO V3, and AlexNet.

import os
import tensorflow as tf

# Input image geometry and pipeline constants
imageWidth = 224
imageHeight = 224
imageDepth = 3
batch_size = 32
resize_min = 256    # the shorter image side is resized to this before cropping

train_files_names = os.listdir('/data/AI/train_tf/')
train_files = ['/data/AI/train_tf/'+item for item in train_files_names]
valid_files_names = os.listdir('/data/AI/valid_tf/')
valid_files = ['/data/AI/valid_tf/'+item for item in valid_files_names]

# Parse TFRECORD and distort the image for train
def _parse_function(example_proto):
    features = {"image": tf.FixedLenFeature([], tf.string, default_value=""),
                "height": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "width": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "channels": tf.FixedLenFeature([1], tf.int64, default_value=[3]),
                "colorspace": tf.FixedLenFeature([], tf.string, default_value=""),
                "img_format": tf.FixedLenFeature([], tf.string, default_value=""),
                "label": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "bbox_xmin": tf.VarLenFeature(tf.float32),
                "bbox_xmax": tf.VarLenFeature(tf.float32),
                "bbox_ymin": tf.VarLenFeature(tf.float32),
                "bbox_ymax": tf.VarLenFeature(tf.float32),
                "text": tf.FixedLenFeature([], tf.string, default_value=""),
                "filename": tf.FixedLenFeature([], tf.string, default_value="")
               }
    parsed_features = tf.parse_single_example(example_proto, features)
    image_decoded = tf.image.decode_jpeg(parsed_features["image"], channels=3)
    # Resize so that the shorter side equals resize_min, preserving the aspect ratio
    shape = tf.shape(image_decoded)
    height, width = shape[0], shape[1]
    resized_height, resized_width = tf.cond(height<width,
        lambda: (resize_min, tf.cast(tf.multiply(tf.cast(width, tf.float64),tf.divide(resize_min,height)), tf.int32)),
        lambda: (tf.cast(tf.multiply(tf.cast(height, tf.float64),tf.divide(resize_min,width)), tf.int32), resize_min))
    image_float = tf.image.convert_image_dtype(image_decoded, tf.float32)
    resized = tf.image.resize_images(image_float, [resized_height, resized_width])
    # Random crop from the resized image
    cropped = tf.random_crop(resized, [imageHeight, imageWidth, 3])
    # Flip to add a little more random distortion in.
    flipped = tf.image.random_flip_left_right(cropped)
    # Standardize the image to zero mean and unit variance
    image_train = tf.image.per_image_standardization(flipped)
    return image_train, tf.one_hot(parsed_features["label"][0], 1000)

def train_input_fn():
    dataset_train = tf.data.TFRecordDataset(train_files)
    dataset_train = dataset_train.map(_parse_function, num_parallel_calls=4)
    dataset_train = dataset_train.repeat(10)
    dataset_train = dataset_train.batch(batch_size)
    dataset_train = dataset_train.prefetch(batch_size)
    return dataset_train

def _parse_test_function(example_proto):
    features = {"image": tf.FixedLenFeature([], tf.string, default_value=""),
                "height": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "width": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "channels": tf.FixedLenFeature([1], tf.int64, default_value=[3]),
                "colorspace": tf.FixedLenFeature([], tf.string, default_value=""),
                "img_format": tf.FixedLenFeature([], tf.string, default_value=""),
                "label": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "bbox_xmin": tf.VarLenFeature(tf.float32),
                "bbox_xmax": tf.VarLenFeature(tf.float32),
                "bbox_ymin": tf.VarLenFeature(tf.float32),
                "bbox_ymax": tf.VarLenFeature(tf.float32),
                "text": tf.FixedLenFeature([], tf.string, default_value=""),
                "filename": tf.FixedLenFeature([], tf.string, default_value="")
               }
    parsed_features = tf.parse_single_example(example_proto, features)
    image_decoded = tf.image.decode_jpeg(parsed_features["image"], channels=3)
    shape = tf.shape(image_decoded)
    height, width = shape[0], shape[1]
    resized_height, resized_width = tf.cond(height<width,
        lambda: (resize_min, tf.cast(tf.multiply(tf.cast(width, tf.float64),tf.divide(resize_min,height)), tf.int32)),
        lambda: (tf.cast(tf.multiply(tf.cast(height, tf.float64),tf.divide(resize_min,width)), tf.int32), resize_min))
    image_float = tf.image.convert_image_dtype(image_decoded, tf.float32)
    image_resized = tf.image.resize_images(image_float, [resized_height, resized_width])
    
    # Compute the offsets for a center crop
    shape = tf.shape(image_resized)  
    height, width = shape[0], shape[1]
    amount_to_be_cropped_h = (height - imageHeight)
    crop_top = amount_to_be_cropped_h // 2
    amount_to_be_cropped_w = (width - imageWidth)
    crop_left = amount_to_be_cropped_w // 2
    image_cropped = tf.slice(image_resized, [crop_top, crop_left, 0], [imageHeight, imageWidth, -1])
    image_valid = tf.image.per_image_standardization(image_cropped)
    return image_valid, tf.one_hot(parsed_features["label"][0], 1000)

def val_input_fn():
    dataset_valid = tf.data.TFRecordDataset(valid_files)
    dataset_valid = dataset_valid.map(_parse_test_function, num_parallel_calls=4)
    dataset_valid = dataset_valid.batch(batch_size)
    dataset_valid = dataset_valid.prefetch(batch_size)
    return dataset_valid

# Darknet53 backbone (the pre-training model used in YOLO V3) with a 1000-class head
def darknet_53():
    image = tf.keras.Input(shape=(imageHeight,imageWidth,3))
    l = tf.keras.layers
    def _conv(inputs, filters, kernel_size, strides, padding, bias=False, normalize=True, activation='leaky_relu'):
        output = inputs
        padding_str = 'same'
        if padding>0:
            output = l.ZeroPadding2D(padding=(padding, padding))(output)
            padding_str = 'valid'
        output = l.Conv2D(filters, kernel_size, strides, padding_str, use_bias=bias, \
                          kernel_initializer='he_normal', \
                          kernel_regularizer=tf.keras.regularizers.l2(l=5e-4))(output)
        if normalize:
            output = l.BatchNormalization(axis=3)(output, training=True)
        if activation=='leaky_relu':
            output = l.LeakyReLU(alpha=0.1)(output)
        return output

    def _residual(inputs, filters):
        output = _conv(inputs, filters, 1, (1,1), 0)
        output = _conv(output, filters*2, 3, (1,1), 1)
        output = tf.add(inputs, output)
        return output

    net = _conv(image, 32, 3, (1,1), 1)
    net = _conv(net, 64, 3, (2,2), 1)
    net = _residual(net, 32)
    net = _conv(net, 128, 3, (2,2), 1)
    for _ in range(2):
        net = _residual(net, 64)
    net = _conv(net, 256, 3, (2,2), 1)
    for _ in range(8):
        net = _residual(net, 128)
    #add route1
    net = _conv(net, 512, 3, (2,2), 1)
    for _ in range(8):
        net = _residual(net, 256)
    #add route2
    net = _conv(net, 1024, 3, (2,2), 1)
    for _ in range(4):
        net = _residual(net, 512)
    #add route3
    net = l.GlobalAveragePooling2D()(net)
    net = l.Dense(1000, kernel_initializer=tf.initializers.truncated_normal(stddev=1e-1))(net)
    net = tf.keras.activations.softmax(net)
    model = tf.keras.Model(inputs=image, outputs=net)
    
    return model

# AlexNet, with batch normalization after each convolution
def alexnet():
    image = tf.keras.Input(shape=(imageHeight,imageWidth,3))
    l = tf.keras.layers
    def _conv(inputs, filters, kernel_size, strides, padding, bias=True):
        output = inputs
        padding_str = 'same'
        if padding>0:
            output = l.ZeroPadding2D(padding=(padding, padding))(output)
            padding_str = 'valid'
        output = l.Conv2D(filters, kernel_size, strides, padding_str, use_bias=bias, \
                          kernel_initializer=tf.initializers.truncated_normal(stddev=1e-1), \
                          kernel_regularizer=tf.keras.regularizers.l2(l=5e-4))(output)
        output = l.BatchNormalization(axis=3)(output, training=True)
        output = l.ReLU()(output)
        return output
    net = _conv(image, 96, 11, 4, 0)
    net = l.MaxPool2D(3, 2)(net)
    net = _conv(net, 256, 5, 1, 0)
    net = l.MaxPool2D(3, 2)(net)
    net = _conv(net, 384, 3, 1, 0)
    net = _conv(net, 384, 3, 1, 0)
    net = _conv(net, 256, 3, 1, 0)
    net = l.MaxPool2D(3, 2)(net)
    net = l.Flatten()(net)
    net = l.Dense(4096, kernel_initializer=tf.initializers.truncated_normal(stddev=1/4096))(net)
    net = l.Dense(4096, kernel_initializer=tf.initializers.truncated_normal(stddev=1/4096))(net)
    net = l.Dense(1000, kernel_initializer=tf.initializers.truncated_normal(stddev=1/1000))(net)
    net = tf.keras.activations.softmax(net)
    model = tf.keras.Model(inputs=image, outputs=net)
    
    return model
    
# Alternative custom loss that folds in the collected L2 regularization terms
# (kept for reference; not used in main() below)
def my_loss(y_true, y_pred):
    l2_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    crossentropy = tf.keras.losses.categorical_crossentropy(y_true=y_true, y_pred=y_pred)
    loss = tf.add(l2_loss, crossentropy)/batch_size
    return loss

def main(_):
    # 1,281,167 training images in ImageNet; keep the LR boundaries as integer step counts
    epoch_steps = 1281167 // batch_size
    boundaries = [epoch_steps*5, epoch_steps*8]
    values = [0.01, 0.001, 0.0001]
    learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries, values)

    model = darknet_53()
    model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate_fn, momentum=0.9), \
                  loss='categorical_crossentropy', \
                  metrics=['categorical_accuracy', 'top_k_categorical_accuracy'])
          
    est_model = tf.keras.estimator.model_to_estimator(model, model_dir='imagenet_model_darknet53/')
    for _ in range(5):
        est_model.train(input_fn=train_input_fn, steps=5000)
        eval_results = est_model.evaluate(input_fn=val_input_fn)
        print('\nEvaluation:\n\t%s\n' % eval_results)
    
if __name__ == "__main__":
    tf.app.run(main)

As the code above shows, an Estimator does make it quite convenient to build and train a model. During testing, however, I noticed a few things. First, if I leave the labels without one-hot encoding, i.e. represent the image class as an integer from 0 to 999, and choose 'sparse_categorical_crossentropy' as the loss, the model does not seem to learn effectively during training. Second, once kernel_regularizer is specified on a Keras Conv2D layer, the regularization term appears to be added to the loss automatically. Third, when calling Keras BatchNormalization here, training=True has to be specified.
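For reference, the sparse-label variant mentioned above would look roughly like this (a sketch only, and as noted it did not learn well for me): the parse functions return the raw integer label instead of a one-hot vector, and the model is compiled with the sparse loss and metric.

# In _parse_function / _parse_test_function, return the integer label directly:
#     return image_train, parsed_features["label"][0]
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate_fn, momentum=0.9),
              loss='sparse_categorical_crossentropy',
              metrics=['sparse_categorical_accuracy'])

The second observation can also be checked directly: in this Keras setup, model.losses holds the L2 terms collected from every Conv2D that was given a kernel_regularizer, so print(len(model.losses)) shows how many regularization terms will be folded into the training loss.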
