将 checkpoint 转换成 pb 模型，并设定多个输出节点
def fun(output_graph='pb_model/query_encoder.pb',
        ckpt_model=r'best_ckpt',
        bert_config_file=r'model/chinese_L-12_H-768_A-12/bert_config.json',
        max_seq_length=10):
    """Freeze the latest fine-tuned BERT checkpoint into a multi-output .pb file.

    Rebuilds the inference graph with ``create_model``, restores the newest
    checkpoint found in ``ckpt_model`` and folds all variables into constants
    so that these nodes can be fetched from the written GraphDef:

    * ``loss/Softmax``           -- softmax output of the classifier head.
    * ``bert/pooler/dense/Tanh`` -- BERT pooler output.
    * ``Mean``                   -- a mean node built inside ``create_model``;
      per the original note ("keep BERT first/second layer information") it
      presumably averages hidden layers -- TODO confirm in ``create_model``.

    Args:
        output_graph: Path of the frozen ``.pb`` file to write. Its parent
            directory is created if missing.
        ckpt_model: Directory searched with ``tf.train.latest_checkpoint``.
        bert_config_file: Path to the BERT ``bert_config.json``.
        max_seq_length: Fixed second dimension of the ``input_ids`` and
            ``input_mask`` int32 placeholders baked into the frozen graph.

    Raises:
        ValueError: If no checkpoint can be found in ``ckpt_model``.
    """
    import os

    output_node = ["loss/Softmax", "bert/pooler/dense/Tanh", "Mean"]

    # Fail early with a clear message instead of letting saver.restore()
    # choke on a None path.
    checkpoint_path = tf.train.latest_checkpoint(ckpt_model)
    if checkpoint_path is None:
        raise ValueError("No checkpoint found in %r" % (ckpt_model,))

    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True  # do not grab all GPU memory up front

    # A fresh graph (instead of the process-wide default one) keeps node
    # names such as "input_ids" stable and avoids variable collisions if this
    # function is called more than once in the same process.
    graph = tf.Graph()
    with graph.as_default(), tf.Session(graph=graph, config=gpu_config) as sess:
        print("going to restore checkpoint")
        input_ids_p = tf.placeholder(tf.int32, [None, max_seq_length], name="input_ids")
        input_mask_p = tf.placeholder(tf.int32, [None, max_seq_length], name="input_mask")
        bert_config = modeling.BertConfig.from_json_file(bert_config_file)
        # Called only to build the graph; the output nodes are selected by
        # name through `output_node`. No segment ids are fed (None is passed).
        create_model(bert_config=bert_config, is_training=False,
                     input_ids=input_ids_p, input_mask=input_mask_p,
                     segment_ids=None, labels=None,
                     num_labels=len(label_list),
                     use_one_hot_embeddings=False, fp16=FLAGS.use_fp16)
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint_path)
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, output_node)

    output_dir = os.path.dirname(output_graph)
    if output_dir:
        tf.gfile.MakeDirs(output_dir)
    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(frozen_graph_def.SerializeToString())
    print('extract vector pb model saved!')
推理部分：加载 pb 模型
class BertEncoder(object):
    """Inference wrapper around a frozen BERT ``.pb`` graph.

    Loads ``<CURRENT_DIR>/pb_model/<OUTPUT_GRAPH>`` into a private
    ``tf.Graph`` (imported under the default ``import/`` name scope), opens a
    session on it and exposes the input placeholders and output tensors as
    attributes. Call :meth:`close` to release the session.
    """

    # Fallback padding length, used only when the graph's ``input_ids``
    # placeholder does not fix its sequence dimension.
    _DEFAULT_MAX_LENGTH = 30

    def __init__(self, OUTPUT_GRAPH):
        """Load the frozen graph and resolve its input/output tensors.

        Args:
            OUTPUT_GRAPH: File name of the frozen graph inside
                ``CURRENT_DIR/pb_model``.
        """
        self.max_length = self._DEFAULT_MAX_LENGTH
        self.tokenizer = TOKENIZER
        self.out_graph = os.path.join(CURRENT_DIR, "pb_model", OUTPUT_GRAPH)
        self.model_graph = {}

        graph_def = tf.compat.v1.GraphDef()
        with open(self.out_graph, "rb") as f:
            graph_def.ParseFromString(f.read())
        self.model_graph['output_graph_def'] = graph_def

        graph = tf.Graph()
        with graph.as_default():
            # The frozen graph holds only constants, so no variable
            # initialisation is needed. `return_elements` makes the import
            # fail fast if one of the configured node names is missing from
            # the .pb; the handles are re-fetched by name below.
            tf.compat.v1.import_graph_def(
                graph_def,
                return_elements=[INPUT_1, INPUT_2, SOFTMAX_OUTPUT,
                                 FIRST_LAST_OUTPUT, CLS_OUTPUT])
        self.model_graph['sess'] = tf.compat.v1.Session(graph=graph)

        self.input_ids_p = graph.get_tensor_by_name("import/input_ids:0")
        self.input_mask_p = graph.get_tensor_by_name("import/input_mask:0")
        self.output_1 = graph.get_tensor_by_name("import/loss/Softmax:0")
        self.output_2 = graph.get_tensor_by_name("import/Mean:0")
        self.output_3 = graph.get_tensor_by_name("import/bert/pooler/dense/Tanh:0")

        # A frozen graph exported with a fixed-width placeholder (e.g.
        # [None, 10]) only accepts feeds of exactly that width, so keep
        # max_length in sync with it rather than a hard-coded value.
        static_len = self._static_seq_len(self.input_ids_p)
        if static_len is not None:
            self.max_length = static_len

    @staticmethod
    def _static_seq_len(tensor):
        """Return the static size of dimension 1 of a rank-2 tensor, else None."""
        shape = tensor.shape
        if shape.ndims != 2:
            return None
        return shape.as_list()[1]

    def close(self):
        """Release the TensorFlow session held by this encoder."""
        self.model_graph['sess'].close()