TensorFlow introduced the Estimator API in version 1.3 and has steadily expanded its support for this high-level style of programming in later releases; Estimators also make multi-GPU training easy to enable. In my previous posts I always built and trained models with the low-level APIs, which is more flexible and exposes the low-level details of the model, but it takes a lot of code and gets tedious, since many details have to be implemented by hand. So I tried the high-level API on the latest TensorFlow 1.14 to see whether it is really more convenient to use, and whether it can reach the same performance as the low-level API.
My test is image classification on the ImageNet data; see my earlier post for how the ImageNet dataset is prepared as TFRecords. The full code is below. It contains two models: Darknet53, the pretrained backbone used in YOLO v3, and AlexNet.
import tensorflow as tf
import os
imageWidth = 224
imageHeight = 224
imageDepth = 3
batch_size = 32
resize_min = 256
train_files_names = os.listdir('/data/AI/train_tf/')
train_files = ['/data/AI/train_tf/'+item for item in train_files_names]
valid_files_names = os.listdir('/data/AI/valid_tf/')
valid_files = ['/data/AI/valid_tf/'+item for item in valid_files_names]
# Parse a TFRecord example and distort the image for training
def _parse_function(example_proto):
    features = {"image": tf.FixedLenFeature([], tf.string, default_value=""),
                "height": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "width": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "channels": tf.FixedLenFeature([1], tf.int64, default_value=[3]),
                "colorspace": tf.FixedLenFeature([], tf.string, default_value=""),
                "img_format": tf.FixedLenFeature([], tf.string, default_value=""),
                "label": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "bbox_xmin": tf.VarLenFeature(tf.float32),
                "bbox_xmax": tf.VarLenFeature(tf.float32),
                "bbox_ymin": tf.VarLenFeature(tf.float32),
                "bbox_ymax": tf.VarLenFeature(tf.float32),
                "text": tf.FixedLenFeature([], tf.string, default_value=""),
                "filename": tf.FixedLenFeature([], tf.string, default_value="")
               }
    parsed_features = tf.parse_single_example(example_proto, features)
    image_decoded = tf.image.decode_jpeg(parsed_features["image"], channels=3)
    # Resize so the shorter side equals resize_min, keeping the aspect ratio
    shape = tf.shape(image_decoded)
    height, width = shape[0], shape[1]
    resized_height, resized_width = tf.cond(height < width,
        lambda: (resize_min, tf.cast(tf.multiply(tf.cast(width, tf.float64), tf.divide(resize_min, height)), tf.int32)),
        lambda: (tf.cast(tf.multiply(tf.cast(height, tf.float64), tf.divide(resize_min, width)), tf.int32), resize_min))
    image_float = tf.image.convert_image_dtype(image_decoded, tf.float32)
    resized = tf.image.resize_images(image_float, [resized_height, resized_width])
    # Random crop from the resized image
    cropped = tf.random_crop(resized, [imageHeight, imageWidth, 3])
    # Flip to add a little more random distortion
    flipped = tf.image.random_flip_left_right(cropped)
    # Standardize the image
    image_train = tf.image.per_image_standardization(flipped)
    return image_train, tf.one_hot(parsed_features["label"][0], 1000)
def train_input_fn():
    dataset_train = tf.data.TFRecordDataset(train_files)
    dataset_train = dataset_train.map(_parse_function, num_parallel_calls=4)
    dataset_train = dataset_train.repeat(10)
    dataset_train = dataset_train.batch(batch_size)
    dataset_train = dataset_train.prefetch(batch_size)
    return dataset_train
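Before handing the dataset to an Estimator, it can be worth pulling one batch by hand to sanity-check the pipeline. This is a standalone debugging snippet, not part of the training script, and assumes the TFRecord files above are in place:

# Standalone check: fetch one batch and verify the shapes.
dataset = train_input_fn()
images, labels = tf.data.make_one_shot_iterator(dataset).get_next()
with tf.Session() as sess:
    img_batch, label_batch = sess.run([images, labels])
    print(img_batch.shape)    # expected: (32, 224, 224, 3)
    print(label_batch.shape)  # expected: (32, 1000)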
# Parse a TFRecord example for validation: resize and center crop, no distortion
def _parse_test_function(example_proto):
    features = {"image": tf.FixedLenFeature([], tf.string, default_value=""),
                "height": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "width": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "channels": tf.FixedLenFeature([1], tf.int64, default_value=[3]),
                "colorspace": tf.FixedLenFeature([], tf.string, default_value=""),
                "img_format": tf.FixedLenFeature([], tf.string, default_value=""),
                "label": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "bbox_xmin": tf.VarLenFeature(tf.float32),
                "bbox_xmax": tf.VarLenFeature(tf.float32),
                "bbox_ymin": tf.VarLenFeature(tf.float32),
                "bbox_ymax": tf.VarLenFeature(tf.float32),
                "text": tf.FixedLenFeature([], tf.string, default_value=""),
                "filename": tf.FixedLenFeature([], tf.string, default_value="")
               }
    parsed_features = tf.parse_single_example(example_proto, features)
    image_decoded = tf.image.decode_jpeg(parsed_features["image"], channels=3)
    shape = tf.shape(image_decoded)
    height, width = shape[0], shape[1]
    resized_height, resized_width = tf.cond(height < width,
        lambda: (resize_min, tf.cast(tf.multiply(tf.cast(width, tf.float64), tf.divide(resize_min, height)), tf.int32)),
        lambda: (tf.cast(tf.multiply(tf.cast(height, tf.float64), tf.divide(resize_min, width)), tf.int32), resize_min))
    image_float = tf.image.convert_image_dtype(image_decoded, tf.float32)
    image_resized = tf.image.resize_images(image_float, [resized_height, resized_width])
    # Calculate the offsets for a center crop
    shape = tf.shape(image_resized)
    height, width = shape[0], shape[1]
    amount_to_be_cropped_h = (height - imageHeight)
    crop_top = amount_to_be_cropped_h // 2
    amount_to_be_cropped_w = (width - imageWidth)
    crop_left = amount_to_be_cropped_w // 2
    image_cropped = tf.slice(image_resized, [crop_top, crop_left, 0], [imageHeight, imageWidth, -1])
    image_valid = tf.image.per_image_standardization(image_cropped)
    return image_valid, tf.one_hot(parsed_features["label"][0], 1000)
def val_input_fn():
    dataset_valid = tf.data.TFRecordDataset(valid_files)
    dataset_valid = dataset_valid.map(_parse_test_function, num_parallel_calls=4)
    dataset_valid = dataset_valid.batch(batch_size)
    dataset_valid = dataset_valid.prefetch(batch_size)
    return dataset_valid
def darknet_53():
    image = tf.keras.Input(shape=(imageHeight, imageWidth, 3))
    l = tf.keras.layers
    def _conv(inputs, filters, kernel_size, strides, padding, bias=False, normalize=True, activation='leaky_relu'):
        output = inputs
        padding_str = 'same'
        if padding > 0:
            output = l.ZeroPadding2D(padding=(padding, padding))(output)
            padding_str = 'valid'
        output = l.Conv2D(filters, kernel_size, strides, padding_str, use_bias=bias,
                          kernel_initializer='he_normal',
                          kernel_regularizer=tf.keras.regularizers.l2(l=5e-4))(output)
        if normalize:
            output = l.BatchNormalization(axis=3)(output, training=True)
        if activation == 'leaky_relu':
            output = l.LeakyReLU(alpha=0.1)(output)
        return output
    def _residual(inputs, filters):
        output = _conv(inputs, filters, 1, (1, 1), 0)
        output = _conv(output, filters*2, 3, (1, 1), 1)
        output = tf.add(inputs, output)
        return output
    net = _conv(image, 32, 3, (1, 1), 1)
    net = _conv(net, 64, 3, (2, 2), 1)
    net = _residual(net, 32)
    net = _conv(net, 128, 3, (2, 2), 1)
    for _ in range(2):
        net = _residual(net, 64)
    net = _conv(net, 256, 3, (2, 2), 1)
    for _ in range(8):
        net = _residual(net, 128)
    # route1 would branch off here in the full YOLO v3 network
    net = _conv(net, 512, 3, (2, 2), 1)
    for _ in range(8):
        net = _residual(net, 256)
    # route2 would branch off here in the full YOLO v3 network
    net = _conv(net, 1024, 3, (2, 2), 1)
    for _ in range(4):
        net = _residual(net, 512)
    # route3 would branch off here in the full YOLO v3 network
    net = l.GlobalAveragePooling2D()(net)
    net = l.Dense(1000, kernel_initializer=tf.initializers.truncated_normal(stddev=1e-1))(net)
    net = tf.keras.activations.softmax(net)
    model = tf.keras.Model(inputs=image, outputs=net)
    return model
def alexnet():
    image = tf.keras.Input(shape=(imageHeight, imageWidth, 3))
    l = tf.keras.layers
    def _conv(inputs, filters, kernel_size, strides, padding, bias=True):
        output = inputs
        padding_str = 'same'
        if padding > 0:
            output = l.ZeroPadding2D(padding=(padding, padding))(output)
            padding_str = 'valid'
        output = l.Conv2D(filters, kernel_size, strides, padding_str, use_bias=bias,
                          kernel_initializer=tf.initializers.truncated_normal(stddev=1e-1),
                          kernel_regularizer=tf.keras.regularizers.l2(l=5e-4))(output)
        output = l.BatchNormalization(axis=3)(output, training=True)
        output = l.ReLU()(output)
        return output
    net = _conv(image, 96, 11, 4, 0)
    net = l.MaxPool2D(3, 2)(net)
    net = _conv(net, 256, 5, 1, 0)
    net = l.MaxPool2D(3, 2)(net)
    net = _conv(net, 384, 3, 1, 0)
    net = _conv(net, 384, 3, 1, 0)
    net = _conv(net, 256, 3, 1, 0)
    net = l.MaxPool2D(3, 2)(net)
    net = l.Flatten()(net)
    net = l.Dense(4096, kernel_initializer=tf.initializers.truncated_normal(stddev=1/4096))(net)
    net = l.Dense(4096, kernel_initializer=tf.initializers.truncated_normal(stddev=1/4096))(net)
    net = l.Dense(1000, kernel_initializer=tf.initializers.truncated_normal(stddev=1/1000))(net)
    net = tf.keras.activations.softmax(net)
    model = tf.keras.Model(inputs=image, outputs=net)
    return model
# Custom loss that adds the L2 terms by hand; kept for reference but not used
# below, since Keras already adds kernel_regularizer losses on its own
def my_loss(y_true, y_pred):
    l2_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    crossentropy = tf.keras.losses.categorical_crossentropy(y_true=y_true, y_pred=y_pred)
    loss = tf.add(l2_loss, crossentropy)/batch_size
    return loss
def main(_):
    epoch_steps = 1281167 // batch_size  # ImageNet train images per epoch / batch size
    boundaries = [epoch_steps*5, epoch_steps*8]
    values = [0.01, 0.001, 0.0001]
    learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries, values)
    model = darknet_53()
    model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate_fn, momentum=0.9),
                  loss='categorical_crossentropy',
                  metrics=['categorical_accuracy', 'top_k_categorical_accuracy'])
    est_model = tf.keras.estimator.model_to_estimator(model, model_dir='imagenet_model_darknet53/')
    for _ in range(5):
        est_model.train(input_fn=train_input_fn, steps=5000)
        eval_results = est_model.evaluate(input_fn=val_input_fn)
        print('\nEvaluation:\n\t%s\n' % eval_results)

if __name__ == "__main__":
    tf.app.run(main)
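As mentioned at the start, one attraction of Estimators is how easy multi-GPU training becomes. I have not benchmarked that here, but as a sketch (untested, using the standard tf.distribute API rather than anything specific to this script), the same compiled model can be distributed by passing a MirroredStrategy through RunConfig, with tf.estimator.train_and_evaluate replacing the manual loop above:

# Sketch only: distribute the compiled Keras model across the local GPUs.
strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(train_distribute=strategy)
est_model = tf.keras.estimator.model_to_estimator(model, model_dir='imagenet_model_darknet53/', config=config)
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=25000)
eval_spec = tf.estimator.EvalSpec(input_fn=val_input_fn)
tf.estimator.train_and_evaluate(est_model, train_spec, eval_spec)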
As the code above shows, it is indeed quite convenient to build and train a model with an Estimator. A few things I noticed while testing, though. First, if the labels are not one-hot encoded, i.e. the classes are given as integers 0-999, and the loss is switched to 'sparse_categorical_crossentropy', training does not seem to learn effectively. Second, once kernel_regularizer is specified on a Keras Conv2D layer, the regularization term appears to be added to the loss automatically, which can be verified with the snippet below. Third, when using Keras BatchNormalization, training=True must be passed when the layer is called.
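On the second point: Keras keeps each layer's kernel_regularizer penalty in model.losses, and compile() folds the sum of those terms into the training objective by itself, which is why the hand-rolled my_loss above is unnecessary (it would count the L2 terms a second time). A minimal check, reusing darknet_53() from above:

# Every Conv2D/Dense built with a kernel_regularizer contributes one
# tensor to model.losses; Keras adds their sum to the data loss itself.
model = darknet_53()
print(len(model.losses))
l2_total = tf.add_n(model.losses)  # total L2 penalty added automatically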