tensorflow从训练自定义CNN网络模型到Android端部署tflite

网上有很多关于tensorflow lite在安卓端部署的教程,但是大多只讲如何把训练好的模型部署到安卓端,不讲如何训练,而实际上在部署的时候,需要知道训练模型时预处理的细节,这就导致了自己训练的模型在部署到安卓端的时候出现各种问题。因此,本文会记录从PC端训练、导出到安卓端部署的各种细节。欢迎大家讨论、指教。

PC端系统:Ubuntu14

tensorflow版本:tensroflow1.14

安卓版本:9.0

PC端训练过程

数据集:自定义生成

训练框架:tensorflow slim  关于tensorflow slim如何安装,这里不再赘述,大家自行百度解决。

数据生成代码:生成50000张28*28大小三通道的验证码图片,共分10类,0-9,生成的数据保存在datasets/images/里面

 

# -*- coding: utf-8 -*-

import cv2
import numpy as np

from captcha.image import ImageCaptcha


def generate_captcha(text='1'):
    """Generate a digit image."""
    capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
    image = capt.generate_image(text)
    image = np.array(image, dtype=np.uint8)
    return image
    
    
if __name__ == '__main__':
    output_dir = './datasets/images/'
    for i in range(50000):
        label = np.random.randint(0, 10)
        image = generate_captcha(str(label))
        image_name = 'image{}_{}.jpg'.format(i+1, label)
        output_path = output_dir + image_name
        cv2.imwrite(output_path, image)

 

训练:本次训练我用tensorflow slim 搭建了一个七层卷积的网络,最后测试准确率在96%~99%左右,模型1.2M,适合在移动端部署。训练的时候我做了两点工作

1、指明了模型的输入和输出节点的名字,PC端部署测试模型的时候要用到,也便于快速确定模型的输出数据到底是什么格式,移动端代码要与其保持一致

 

inputs = tf.placeholder(tf.float32, shape=[None, 28, 28, 3], name='inputs')
.......
.......
prob_ = tf.identity(prob, name='prob')

2、训练结束的时候直接把模型保存成PB格式

 

        constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['inputs','prob']) #训练完毕直接把模型保存为PB格式
        with tf.gfile.FastGFile('model3.pb', mode='wb') as f: #模型的名字是model.pb
            f.write(constant_graph.SerializeToString())

 

训练代码如下

 

# -*- coding: utf-8 -*-

"""Train a CNN model to classifying 10 digits.

Example Usage:
---------------
python3 train.py \
    --images_path: Path to the training images (directory).
    --model_output_path: Path to model.ckpt.
"""

import cv2
import glob
import numpy as np
import os
import tensorflow as tf

import model
from tensorflow.python.framework import graph_util

flags = tf.app.flags

flags.DEFINE_string('images_path', None, 'Path to training images.')
flags.DEFINE_string('model_output_path', None, 'Path to model checkpoint.')
FLAGS = flags.FLAGS


def get_train_data(images_path):
    """Get the training images from images_path.
    
    Args: 
        images_path: Path to trianing images.
        
    Returns:
        images: A list of images.
        lables: A list of integers representing the classes of images.
        
    Raises:
        ValueError: If images_path is not exist.
    """
    if not os.path.exists(images_path):
        raise ValueError('images_path is not exist.')
        
    images = []
    labels = []
    images_path = os.path.join(images_path, '*.jpg')
    count = 0
    for image_file in glob.glob(images_path):
        count += 1
        if count % 100 == 0:
            print('Load {} images.'.format(count))
        image = cv2.imread(image_file)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Assume the name of each image is imagexxx_label.jpg
        label = float(image_file.split('_')[-1].split('.')[0])
        images.append(image)
        labels.append(label)
    images = np.array(images)
    labels = np.array(labels)
    return images, labels


def next_batch_set(images, labels, batch_size=128):
    """Generate a batch training data.
    
    Args:
        images: A 4-D array representing the training images.
        labels: A 1-D array representing the classes of images.
        batch_size: An integer.
        
    Return:
        batch_images: A batch of images.
        batch_labels: A batch of labels.
    """
    indices = np.random.choice(len(images), batch_size)
    batch_images = images[indices]
    batch_labels = labels[indices]
    return batch_images, batch_labels


def main(_):
    inputs = tf.placeholder(tf.float32, shape=[None, 28, 28, 3], name='inputs')
    labels = tf.placeholder(tf.int32, shape=[None], name='labels')
    
    cls_model = model.Model(is_training=True, num_classes=10)
    preprocessed_inputs = cls_model.preprocess(inputs)#预处理
    prediction_dict = cls_model.predict(preprocessed_inputs)
    loss_dict = cls_model.loss(prediction_dict, labels)
    loss = loss_dict['loss']
    postprocessed_dict = cls_model.postprocess(prediction_dict)
    classes = postprocessed_dict['classes']
    prob = postprocessed_dict['prob']
    classes_ = tf.identity(classes, name='classes')
    prob_ = tf.identity(prob, name='prob')
    acc = tf.reduce_mean(tf.cast(tf.equal(classes, labels), 'float'))
    
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(0.05, global_step, 150, 0.9)
    
    optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9)
    train_step = optimizer.minimize(loss, global_step)
    
    saver = tf.train.Saver()
    
    images, targets = get_train_data(FLAGS.images_path)
    
    init = tf.global_variables_initializer()
    
    with tf.Session() as sess:
        sess.run(init)
        
        for i in range(6000):
            batch_images, batch_labels = next_batch_set(images, targets)
            train_dict = {inputs: batch_images, labels: batch_labels}
            
            sess.run(train_step, feed_dict=train_dict)
            
            loss_, acc_,prob__,classes__ = sess.run([loss, acc, prob_,classes_], feed_dict=train_dict)
            
            train_text = 'step: {}, loss: {}, acc: {},classes:{}'.format(
                i+1, loss_, acc_,classes__)
            print(train_text)
            print (prob__)
        saver.save(sess, FLAGS.model_output_path)
        constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['inputs','prob']) #训练完毕直接把模型保存为PB格式
        with tf.gfile.FastGFile('model3.pb', mode='wb') as f: #模型的名字是model.pb
            f.write(constant_graph.SerializeToString())    
if __name__ == '__main__':
    tf.app.run()

 

这里尤其要注意,训练的时候图片是否做过预处理,比如减去均值和除法归一化操作,因为移动端需要保持和训练时候一样的操作。我的在训练的时候,预处理工作中包含了减去均值和除法归一化,并且把这两个OP打包直接放进了模型里面,也就是说图片数据进入模型之后会先进行预处理然后再进行正式的卷积等系列操作。所以,移动端的数据不需要单独写预处理的代码。很多时候,导出模型的时候并没有把预处理操作打包进模型,所以移动端要单独写几行关于减去均值和归一化的代码,然后再把数据送到分类模型当中。

另外一种把ckpt模型导出为pb模型的方式,代码如下

import tensorflow as tf
from tensorflow.python.framework import graph_util
def freeze_graph(input_checkpoint,output_graph):
    '''
    :param input_checkpoint:
    :param output_graph: PB模型保存路径
    :return:
    '''
    # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
    # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
    # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
    #input_node_names = "inputs"
    output_node_names = "inputs,classes"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph() # 获得默认的图
    input_graph_def = graph.as_graph_def()  # 返回一个序列化的图代表当前的图
 
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint) #恢复图并得到数据
        output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
            sess=sess,
            input_graph_def=input_graph_def,# 等于:sess.graph_def
            output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
        with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
            f.write(output_graph_def.SerializeToString()) #序列化输出
        print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
        # for op in graph.get_operations():
        #     print(op.name, op.values())
# 输入ckpt模型路径
input_checkpoint='model/model.ckpt'
# 输出pb模型的路径
out_pb_path="frozen_model.pb"
# 调用freeze_graph将ckpt转为pb
freeze_graph(input_checkpoint,out_pb_path)

把PB模型导出为tflite格式代码

import tensorflow as tf
#把pb文件路径改成自己的pb文件路径即可
path = "model2.pb"
 
#如果是不知道自己的模型的输入输出节点,建议用tensorboard做可视化查看计算图,计算图里有输入输出的节点名称
inputs = ["inputs"]
outputs = ["prob"]
#转换pb模型到tflite模型
converter = tf.lite.TFLiteConverter.from_frozen_graph(path, inputs, outputs)
#converter.post_training_quantize = True
tflite_model = converter.convert()
open("model3.tflite", "wb").write(tflite_model)

还有另外一种利用bazel把模型导出为tflite的办法

进入tensorflow源码目录,两步编译
bazel build tensorflow/python/tools:freeze_graph
bazel build tensorflow/lite/toco:toco
./bazel-bin/tensorflow/contrib/lite/toco/toco
--input_file=/media/bayes/69da5b29-ae56-4feb-93a1-2ce24323aa78/project/model2.pb
--output_file=/media/bayes/69da5b29-ae56-4feb-93a1-2ce24323aa78/project/model2.tflite
--input_format=TENSORFLOW_GRAPHDEF
--output_format=TFLITE
--inference_type=FLOAT
--input_shape=1,28,28,3
--input_array=inputs
--output_array=prob

PB模型测试模型准确率

 

# -*- coding: utf-8 -*-

"""Evaluate the trained CNN model.
Example Usage:
---------------
python3 infrence_pb.py \
    --frozen_graph_path: Path to model frozen graph.
"""

import numpy as np
import tensorflow as tf

from captcha.image import ImageCaptcha

flags = tf.app.flags
flags.DEFINE_string('frozen_graph_path', None, 'Path to model frozen graph.')
FLAGS = flags.FLAGS


def generate_captcha(text='1'):
    capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
    image = capt.generate_image(text)
    image = np.array(image, dtype=np.uint8)
    return image


def main(_):
    model_graph = tf.Graph()
    with model_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(FLAGS.frozen_graph_path, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
    
    with model_graph.as_default():
        with tf.Session(graph=model_graph) as sess:
            inputs = model_graph.get_tensor_by_name('inputs:0')
            classes = model_graph.get_tensor_by_name('classes:0')
            prob = model_graph.get_tensor_by_name('prob:0')
            for i in range(10):
                label = np.random.randint(0, 10)
                image = generate_captcha(str(label))
                image = 
                image_np = np.expand_dims(image, axis=0)
                predicted_label,probs = sess.run([classes,prob], 
                                           feed_dict={inputs: image_np})
                print(predicted_label, ' vs ', label)
                print(probs)
            
            
if __name__ == '__main__':
    tf.app.run()

 

tflite格式测试模型准确率

 

# -*- coding:utf-8 -*-
import os
import cv2
import numpy as np
import time

import tensorflow as tf

test_image_dir = './test_images/'
#model_path = "./model/quantize_frozen_graph.tflite"
model_path = "./model3.tflite"

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
print(str(input_details))
output_details = interpreter.get_output_details()
print(str(output_details))

#with tf.Session( ) as sess:
if 1:
    file_list = os.listdir(test_image_dir)
    
    model_interpreter_time = 0
    start_time = time.time()
    # 遍历文件
    for file in file_list:
        print('=========================')
        full_path = os.path.join(test_image_dir, file)
        print('full_path:{}'.format(full_path))
        

        img = cv2.imread(full_path )
        res_img = cv2.resize(img,(28,28),interpolation=cv2.INTER_CUBIC) 
        # 变成长784的一维数据
        #new_img = res_img.reshape((784))
        new_img = np.array(res_img, dtype=np.uint8)
        # 增加一个维度,变为 [1, 784]
        image_np_expanded = np.expand_dims(new_img, axis=0)
        image_np_expanded = image_np_expanded.astype('float32') # 类型也要满足要求
        
        # 填装数据
        model_interpreter_start_time = time.time()
        interpreter.set_tensor(input_details[0]['index'], image_np_expanded)
        
        # 注意注意,我要调用模型了
        interpreter.invoke()
        output_data = interpreter.get_tensor(output_details[0]['index'])
        model_interpreter_time += time.time() - model_interpreter_start_time
        
        # 出来的结果去掉没用的维度
        result = np.squeeze(output_data)
        print('result:{}'.format(result))
        #print('result:{}'.format(sess.run(output, feed_dict={newInput_X: image_np_expanded})))
        
        # 输出结果是长度为10(对应0-9)的一维数据,最大值的下标就是预测的数字
        #print('result:{}'.format( (np.where(result==np.max(result)))[0][0]  ))
    used_time = time.time() - start_time
    print('used_time:{}'.format(used_time))
    print('model_interpreter_time:{}'.format(model_interpreter_time))

模型训练好以后,接下来要把模型部署到安卓端,其实这步很简单,只要替换安卓代码相应部分即可,安卓代码我会上传到CSDN,大家按需下载即可。那么主要留意更改哪些代码呢

#模型的输入大小
private int[] ddims = {1, 3, 28, 28};
#模型的名称
private static final String[] PADDLE_MODEL = {
"model3",
"mobilenet_quant_v1_224",
"mobilenet_v1_1.0_224",
"mobilenet_v2"
};

#标签的名称
BufferedReader reader = new BufferedReader(new InputStreamReader(assetManager.open("cacheLabel1.txt")));
#模型输出的数据类型,在PC端可以清楚地看到
float[][] labelProbArray = new float[1][10];

#输入数据预处理工作是否已经包含在模型里面
//  imgData.putFloat(((((val >> 16) & 0xFF) - 128f) / 128f));
// imgData.putFloat(((((val >> 8) & 0xFF) - 128f) / 128f));
// imgData.putFloat((((val & 0xFF) - 128f) / 128f));
imgData.putFloat(((val >> 16) & 0xFF) );
imgData.putFloat(((val >> 8) & 0xFF) );
imgData.putFloat((val & 0xFF) );

留一张测试图片,大家可以拿去测试,正确结果应该是0.0,安卓代码地址是这里,CSDN下载请搜索 anquangan

tensorflow从训练自定义CNN网络模型到Android端部署tflite

查看PB模型节点代码

#coding:utf-8
 
import tensorflow as tf
from tensorflow.python.framework import graph_util
tf.reset_default_graph()  # 重置计算图
output_graph_path = 'model3.pb'
with tf.Session() as sess:
 
    tf.global_variables_initializer().run()
    output_graph_def = tf.GraphDef()
    # 获得默认的图
    graph = tf.get_default_graph()
    with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(output_graph_def, name="")
        # 得到当前图有几个操作节点
        print("%d ops in the final graph." % len(output_graph_def.node))
 
        tensor_name = [tensor.name for tensor in output_graph_def.node]
        print(tensor_name)
        print('---------------------------')
        # 在log_graph文件夹下生产日志文件,可以在tensorboard中可视化模型
        #summaryWriter = tf.summary.FileWriter('log_graph/', graph)
 
 
        for op in graph.get_operations():
            # print出tensor的name和值
            print(op.name, op.values())

 

上一篇:阿里云服务器配置不够可以升级CPU内存带宽磁盘均可升级配置


下一篇:TensorFlow Lite for Android示例