tensorflow_tflite专题

tensorflow_tflite专题
本文章主要包括两大问题:

tflite的转换:如何转换得到tflite?
tflite的测试:如何测试或者说如何在PC端使用tflite?

问题一:如何转换得到tflite

分为两个过程,步骤:cheakpoint→pb模型→tflite模型

  • step1:cheakpoint→tflite_graph.pb:
    使用object_detection的export_tflite_ssd_graph.py,结果生成tflite_graph.pb和tflite_graph.pbtxt两个文件

超参数:
"output_directory":输出的文件夹
"pipeline_config_path":网络配置文件
"trained_checkpoint_prefix":你的cheakpoint文件

  • step2:tflite_graph.pb→out_put.tflite:
    使用convert.py程序讲pb转换为tflite,这里的pb是上一步转换得到了,不能是其他来源的pb模型
import tensorflow as tf

# 需要配置
in_path = "tflite_graph.pb"

# 模型输入节点对于object_detection是固定的,不需改动,但是shape是和网络有关
input_tensor_name = ["normalized_input_image_tensor"]
input_tensor_shape = {"normalized_input_image_tensor":[1,256,256,3]}
# 模型输出节点,对于object_detection是固定的,不需改动
classes_tensor_name = ['TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1', 'TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3']

converter = tf.lite.TFLiteConverter.from_frozen_graph(in_path,input_tensor_name, classes_tensor_name,input_tensor_shape)

converter.allow_custom_ops=True
converter.post_training_quantize = True
tflite_model = converter.convert()

open("output.tflite", "wb").write(tflite_model)

print("done")

问题二:如何测试或者说如何在PC端使用tflite?

这里给出代码:

import numpy as np
import tensorflow as tf
import cv2 #用来读取图片并进行预处理
import glob #读取某文件夹所有测试图片
import time #主要是用来计算推理花费时间

# Load TFLite model and allocate tensors.
model_path="output_fp16.tflite"  #tflite路径
interpreter = tf.lite.Interpreter(model_path)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print(input_details)
print(output_details) #在这里可以看到tflite的输入输出的节点信息

def detection(img_src):
    img = cv2.resize(img_src, (256, 256))
    img = img / 128 - 1
    input_data = np.expand_dims(img, 0)
    input_data = input_data.astype(np.float32)
    #以上是对图片经行尺寸变换、归一化、添加维度和类型转换,以便和输入节点对应

    index = input_details[0]['index']
    interpreter.set_tensor(index, input_data)
    interpreter.invoke() #启动
    
    output0 = interpreter.get_tensor(output_details[0]['index'])  # bbox
    output1 = interpreter.get_tensor(output_details[1]['index'])  # bbox
    output2 = interpreter.get_tensor(output_details[2]['index'])  # bbox
    output3 = interpreter.get_tensor(output_details[3]['index'])  # 概率
    #在这里你可以通过print查看4个输出的信息
    #分别时object_detection的信息:
    #对于我来讲,人脸检测不涉及类别,所以我只用到
    # output0:位置信息
    # output2:对应的概率
    
    #我只要概率最大的人脸,且概率>0.6保持,否则讲概率置为0
    output3=output3[0][0] if output3[0][0] > 0.6 else 0
    
    return bbox,output3 #返回概率信息和其位置信息

imgs_path = glob.glob('../../test_iamge/*')


for img_path in imgs_path:
    t1=time.time()
    img=cv2.imread(img_path)
    sp = img.shape
    bbox,confidence=detection(img)
    if confidence!=0:
        print('置信度=',confidence,'   bbox=',bbox,end='   ')
        y1 = int(bbox[0][0][0] * sp[0])
        x1 = int(bbox[0][0][1] * sp[1])
        y2 = int(bbox[0][0][2] * sp[0])
        x2 = int(bbox[0][0][3] * sp[1])

        cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 3)
        print('time=',time.time()-t1)
        cv2.namedWindow(str(confidence*100)[2:6]+'%', 0)
        cv2.imshow(str(confidence*100)[2:6]+'%', img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    else:
        print('time=',time.time()-t1)

上一篇:【C/C++学院】(3)二维数组/二分查找法/指针/模块注射


下一篇:(Deeplabv3+MobilenetV2)语义分割模型部署手机端(ckpt-pb-tflite)