当我们把训练好的 TensorFlow 训练图直接拿来做预测时,图中会包含许多只在训练阶段才用到的节点。这些节点在预测时是多余的,应当在部署预测之前将它们删除。
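下面以 BERT 的图为例进行优化:先以 `is_training=False` 重新构建推理图,再从 checkpoint 加载变量并用 `convert_variables_to_constants` 将其固化为常量,最后用 `optimize_for_inference` 去掉推理时用不到的节点。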
def optimize_graph(self, checkpoint_file, model_config, output_dir='optimize'):
    """Build a frozen, inference-only BERT classifier graph and write it to disk.

    The graph is rebuilt from scratch in a private ``tf.Graph``:
    placeholders -> ``modeling.BertModel`` (``is_training=False``) -> pooled
    output -> dense layer -> softmax. Variables are then initialised from
    ``checkpoint_file`` and folded into constants. Finally the graph is run
    through ``optimize_for_inference`` to strip nodes not needed for prediction.

    Args:
        checkpoint_file: TensorFlow checkpoint to load. Besides the BERT
            encoder variables it must contain ``output_weights`` and
            ``output_bias`` (the classification layer read below).
        model_config: Path of the BERT JSON config file.
        output_dir: Directory the serialized graph is written into. It is
            created if missing. Defaults to ``'optimize'`` (relative to the
            current working directory), matching the previous hard-coded value.

    Returns:
        Path (str) of a newly created file holding the serialized, optimized
        ``GraphDef``. Its input nodes are ``input_ids``, ``input_mask`` and
        ``input_type_ids``, and its output node is ``final_encodes``.
    """
    import json
    import os
    import tempfile

    tf = self.import_tf()
    from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference

    # Freezing only needs to evaluate the variables once; do it on the CPU.
    config = tf.ConfigProto(device_count={'GPU': 0}, allow_soft_placement=True)

    with tf.gfile.GFile(model_config, 'r') as f:
        bert_config = modeling.BertConfig.from_dict(json.load(f))

    # Use a private graph so repeated calls neither accumulate nodes in the
    # default graph nor get uniquified names ("input_ids_1", "bert_1"), which
    # would break the checkpoint assignment map.
    graph = tf.Graph()
    with graph.as_default():
        input_ids = tf.placeholder(tf.int32, (None, MAX_SEQ_LENGTH), 'input_ids')
        input_mask = tf.placeholder(tf.int32, (None, MAX_SEQ_LENGTH), 'input_mask')
        input_type_ids = tf.placeholder(tf.int32, (None, MAX_SEQ_LENGTH), 'input_type_ids')
        input_tensors = [input_ids, input_mask, input_type_ids]

        model = modeling.BertModel(
            config=bert_config,
            is_training=False,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=input_type_ids,
            use_one_hot_embeddings=False)

        # Make the encoder variables take their initial values from the checkpoint.
        tvars = tf.trainable_variables()
        assignment_map, _ = modeling.get_assignment_map_from_checkpoint(tvars, checkpoint_file)
        tf.train.init_from_checkpoint(checkpoint_file, assignment_map)

        # The classification head is not created by BertModel: read its
        # parameters straight from the checkpoint and embed them as constants.
        reader = tf.train.NewCheckpointReader(checkpoint_file)
        output_weights = reader.get_tensor('output_weights')
        output_bias = reader.get_tensor('output_bias')

        logits = tf.nn.bias_add(
            tf.matmul(model.get_pooled_output(), output_weights, transpose_b=True),
            output_bias)
        probabilities = tf.identity(tf.nn.softmax(logits), 'final_encodes')
        output_tensors = [probabilities]

        # Graph utilities expect op names (no ":<index>" suffix); op.name is
        # correct for any output index, unlike slicing off two characters.
        input_names = [t.op.name for t in input_tensors]
        output_names = [t.op.name for t in output_tensors]

        # Snapshot before the initializer op below is added to the graph.
        graph_def = graph.as_graph_def()

        with tf.Session(graph=graph, config=config) as sess:
            # Runs the checkpoint-backed initializers configured above.
            sess.run(tf.global_variables_initializer())
            frozen_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, graph_def, output_names)

    optimized_graph_def = optimize_for_inference(
        frozen_graph_def,
        input_names,
        output_names,
        [t.dtype.as_datatype_enum for t in input_tensors],
        False)

    os.makedirs(output_dir, exist_ok=True)
    # mkstemp hands back an open descriptor; write through it so no file
    # handle is left dangling.
    fd, tmp_file = tempfile.mkstemp(dir=output_dir)
    with os.fdopen(fd, 'wb') as f:
        f.write(optimized_graph_def.SerializeToString())
    return tmp_file
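该函数返回的是一个文件路径,文件中保存了序列化后的 GraphDef。我们可以像导入原来的模型文件那样读取它并恢复图(例如先用 `tf.GraphDef().ParseFromString(...)` 解析,再调用 `tf.import_graph_def`),只不过这次得到的是经过优化的图。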