TF Serving
Workflow
The serving workflow breaks down into the following steps:
- The Source creates a Loader for each model that needs to be loaded; the Loader holds all the information required to load that model;
- The Source notifies the Manager that a new model needs to be loaded;
- The Manager uses the version policy (Version Policy) to decide which model versions should be unloaded and which should be loaded (a config sketch follows this list);
- Once the Manager confirms that the model satisfies the loading policy, it tells the Loader to load the new version;
- When a client requests predictions from the server, it can either pin a specific model version or simply use the latest one.
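The Version Policy can be set per model through a model config file passed to the server with --model_config_file. A minimal sketch, assuming a model named half_plus_two served from /models/half_plus_two; latest { num_versions: 2 } keeps the two newest versions loaded side by side:
model_config_list {
  config {
    name: "half_plus_two"
    base_path: "/models/half_plus_two"
    model_platform: "tensorflow"
    model_version_policy {
      latest {
        num_versions: 2
      }
    }
  }
}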
Installation
Official TF Serving documentation: https://www.tensorflow.org/tfx/guide/serving
docker
docker pull tensorflow/serving
Download
Clone the serving repo, which contains the test model used below (cloning into /tmp/tfserving so the paths in the run command line up):
git clone https://github.com/tensorflow/serving /tmp/tfserving/serving
Run
Serve the half_plus_two test model over the REST API (port 8501; 8500 would be the gRPC port). MODEL_NAME determines the URL path under /v1/models/:
docker run -p 8501:8501 \
--mount type=bind,\
source=/tmp/tfserving/serving/tensorflow_serving/servables/tensorflow/testdata/saved_model_half_plus_two_cpu,\
target=/models/half_plus_two \
-e MODEL_NAME=half_plus_two -t tensorflow/serving &
Verify
curl -d '{"instances": [1.0, 2.0, 5.0]}' \
-X POST http://localhost:8501/v1/models/half_plus_two:predict
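The test model computes y = 0.5x + 2 over each instance, so the request should return:
{ "predictions": [2.5, 3.0, 4.5] }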
Before deployment
Model export
import os
import shutil
import tensorflow as tf

# Load the trained Keras model
model = tf.keras.models.load_model('./cnn_model.h5')

# TF Serving discovers versions through numbered subdirectories, hence the
# trailing /1; remove any stale export at that path first
export_path = './Models/CNN/1'
if os.path.exists(export_path):
    shutil.rmtree(export_path)

# Export the model in SavedModel format for deployment
tf.saved_model.save(model, export_path)
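The export produces the standard SavedModel layout, which is what TF Serving scans for numbered version directories (an assets/ directory may also appear, depending on the model):
./Models/CNN/
└── 1/
    ├── saved_model.pb
    └── variables/
        ├── variables.data-00000-of-00001
        └── variables.index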
Inspect and test
List the exported model's tag-sets and signatures:
saved_model_cli show --dir ./Models/CNN/1 --all
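The output lists the tag-sets and signatures; what the client request later needs to match is the serving_default signature and its input name. For this model it should look roughly like the following (exact names and dtypes depend on how the network was built):
signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['input_1'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 100)
        name: serving_default_input_1:0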
Feed it some random sample input:
saved_model_cli run --dir ./Models/CNN/1 --tag_set serve --signature_def serving_default --input_exprs 'input_1=np.random.rand(1,100)'
Deployment
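The client script below assumes the exported model is already being served. A minimal sketch of the corresponding serving command, chosen to match the port (8505) and model name (default) that the client script uses; the host path is an assumption, not from the original:
docker run -p 8505:8501 \
  --mount type=bind,source=$(pwd)/Models/CNN,target=/models/default \
  -e MODEL_NAME=default -t tensorflow/serving &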
import os
import json

import jieba
import requests
from tensorflow.keras.preprocessing import sequence

# Project helpers: build_vocab, read_category, read_vocab, encode_sentences
from utils import *
# Paths and vocabulary configuration
data_dir = "./processed_data"
vocab_file = "./vocab/vocab.txt"
vocab_size = 40000
# Network configuration (kept from training; only maxlen matters here)
max_features = 40001
maxlen = 100
batch_size = 256
embedding_dims = 50
epochs = 8
print('Preprocessing and loading data...')
# Rebuild the vocabulary if it does not exist yet
if not os.path.exists(vocab_file):
    build_vocab(data_dir, vocab_file, vocab_size)
# Get the word/category to id mapping dictionaries
categories, cat_to_id = read_category()
words, word_to_id = read_vocab(vocab_file)
# Sample sentence to classify ("This is the country's largest military exercise in history")
text = "这是该国史上最大的一次军事演习"
# Tokenize with jieba, map tokens to ids, and pad to the model's input length
text_seg = encode_sentences([jieba.lcut(text)], word_to_id)
text_input = sequence.pad_sequences(text_seg, maxlen=maxlen)

# Build the REST request body for the serving_default signature
data = json.dumps({"signature_name": "serving_default",
                   "instances": text_input.reshape(1, maxlen).tolist()})
headers = {"content-type": "application/json"}
json_response = requests.post('http://localhost:8505/v1/models/default:predict',
                              data=data, headers=headers)
print(json_response.text)
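A successful response is JSON of the form {"predictions": [[...]]}, one score per category. A minimal post-processing sketch to turn the scores into a label, assuming the model's output order matches the category list returned by read_category():
import numpy as np

scores = json.loads(json_response.text)["predictions"][0]
print(categories[int(np.argmax(scores))])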