tensorflow模型部署

TF Serving

工作流程

主要分为以下几个步骤:

  1. Source会针对需要进行加载的模型创建一个Loader,Loader中会包含要加载模型的全部信息;
  2. Source通知Manager有新的模型需要进行加载;
  3. Manager通过版本管理策略(Version Policy)来确定哪些模型需要被下架,哪些模型需要被加载;
  4. Manger在确认需要加载的模型符合加载策略,便通知Loader来加载最新的模型;
  5. 客户端像服务端请求模型结果时,可以指定模型的版本,也可以使用最新模型的结果;

安装

TF Serving官方文档:https://www.tensorflow.org/tfx/guide/serving

docker

docker pull tensorflow/serving

下载

git clone https://github.com/tensorflow/serving

运行

docker run -p 8501:8501 \
  --mount type=bind,\
   source=/tmp/tfserving/serving/tensorflow_serving/servables/tensorflow/testdata/saved_model_half_plus_two_cpu,\
target=/models/half_plus_two \
-e MODEL_NAME=half_plus_two -t tensorflow/serving &

验证

curl -d '{"instances": [1.0, 2.0, 5.0]}' \
  -X POST http://localhost:8501/v1/models/half_plus_two:predict 

部署前

模型导出

import tensorflow as tf
import shutil 

model = tf.keras.models.load_model('./cnn_model.h5')

# 指定路径
if os.path.exists('./Models/CNN/1'):
    shutil.rmtree('./Models/CNN/1')
    
export_path = './Models/CNN/1'

# 导出tensorflow模型以便部署
tf.saved_model.save(model,export_path)

检查和测试

saved_model_cli show --dir ./Models/CNN/1 --all

输入部分数据

saved_model_cli run --dir ./Models/CNN/1 --tag_set serve --signature_def serving_default --input_exp 'input_1=np.random.rand(1,100)'

部署

from tensorflow.keras.preprocessing import sequence
import random
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.utils import to_categorical
from utils import *
import json
import numpy
import requests
import jieba

# 路径等配置
data_dir = "./processed_data"
vocab_file = "./vocab/vocab.txt"
vocab_size = 40000

# 神经网络配置
max_features = 40001
maxlen = 100
batch_size = 256
embedding_dims = 50
epochs = 8

print('数据预处理与加载数据...')
# 如果不存在词汇表,重建
if not os.path.exists(vocab_file):  
    build_vocab(data_dir, vocab_file, vocab_size)
# 获得 词汇/类别 与id映射字典
categories, cat_to_id = read_category()
words, word_to_id = read_vocab(vocab_file)

text = "这是该国史上最大的一次军事演习"
text_seg = encode_sentences([jieba.lcut(text)], word_to_id)
text_input = sequence.pad_sequences(text_seg, maxlen=maxlen)

data = json.dumps({"signature_name": "serving_default",
                   "instances": text_input.reshape(1,100).tolist()})
headers = {"content-type": "application/json"}
json_response = requests.post('http://localhost:8505/v1/models/default:predict',
                              data=data, headers=headers)
#print(json.loads(json_response.text))
print(json_response.text)
上一篇:【Java】红黑树的删除操作概述和代码实现


下一篇:LTE网络接口的类型及定义