TensorFlow 对 Model 检查点的操作:model.get_layer、从 checkpoint 加载权重、set_weights、model 层属性获取

文章目录

model层属性获取

在tensorflow中,要想获取层的输出的各种信息,可以先获取层对象,再通过层对象的属性获取层输出的其他特性.

获取model对应层的方法为:

get_layer(self, name=None, index=None):

函数功能:根据层的名称(这个名称具有唯一性)或者索引号检索model获取对应的层.

获取层输出的其他特性

  1. model.get_layer(index=0).output # 输出张量
  2. model.get_layer(index=0).output_shape # 输出张量的形状
  3. model.get_layer(index=0).input # 输入张量
  4. model.get_layer(index=0).input_shape # 输入张量的形状
  5. #该层有多个节点时(node_index为节点序号):
  6. layer.get_input_at(node_index)
  7. layer.get_output_at(node_index)
  8. layer.get_input_shape_at(node_index)
  9. layer.get_output_shape_at(node_index)
  10. model.get_layer("word_embeddings").set_weights(weights) #将权重加载到该层
  11. model.get_layer("word_embeddings").get_weights() #返回层的权重(numpy array 列表)
  12. config = model.get_layer("word_embeddings").get_config() #获取该层的配置(字典)

检查点

保存模型并不限于在训练模型后,在训练模型之中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况,我们自然希望能够将训练得到的参数保存下来,否则下次又要重新训练。

这种在训练中保存模型,习惯上称之为保存检查点。

load_checkpoint

tf.train.load_checkpoint(ckpt_dir_or_file):

返回在 `ckpt_dir_or_file` 中找到的检查点所对应的 `CheckpointReader`。
如果 `ckpt_dir_or_file` 解析为包含多个检查点的目录,则返回最新检查点的 reader。

variables = tf.train.load_checkpoint(init_checkpoint)

从 CheckpointReader 对象中按变量名获取张量:

variables.get_tensor("bert/embeddings/word_embeddings")

bert 中load_checkpoint并且get_layer().set_weights操作

# Copy the weights stored in a BERT checkpoint into the matching Keras layers.
# `init_checkpoint` and `model` must already be defined by the surrounding code.
variables = tf.train.load_checkpoint(init_checkpoint)
encoder = model._encoder_layer

# Embedding weights.
encoder.get_layer("word_embeddings").set_weights([
    variables.get_tensor("bert/embeddings/word_embeddings"),
])
encoder.get_layer("position_embeddings").set_weights([
    variables.get_tensor("bert/embeddings/position_embeddings"),
])
encoder.get_layer("type_embeddings").set_weights([
    variables.get_tensor("bert/embeddings/token_type_embeddings"),
])

encoder.get_layer("embeddings/layer_norm").set_weights([
    variables.get_tensor("bert/embeddings/LayerNorm/gamma"),
    variables.get_tensor("bert/embeddings/LayerNorm/beta"),
])

encoder.get_layer("embedding_projection").set_weights([
    variables.get_tensor("bert/encoder/embedding_hidden_mapping_in/kernel"),
    variables.get_tensor("bert/encoder/embedding_hidden_mapping_in/bias"),
])

# Multi-head attention / transformer weights.
# The checkpoint tensors for the attention projections are reshaped to the
# target shapes below before being handed to the Keras layer.
bert_config = model.bert_config
qkv_kernel_shape = [bert_config.hidden_size, bert_config.num_attention_heads, -1]
qkv_bias_shape = [bert_config.num_attention_heads, -1]
attention_output_kernel_shape = [
    bert_config.num_attention_heads, -1, bert_config.hidden_size]

# Tensors (relative to the layer prefix) that are copied without reshaping,
# in the order the list passed to set_weights() requires.
unreshaped_suffixes = (
    "attention/output/dense/bias",
    "attention/output/LayerNorm/gamma",
    "attention/output/LayerNorm/beta",
    "intermediate/dense/kernel",
    "intermediate/dense/bias",
    "output/dense/kernel",
    "output/dense/bias",
    "output/LayerNorm/gamma",
    "output/LayerNorm/beta",
)

for i in range(model._config['bert_config'].num_hidden_layers):
    prefix = f"bert/encoder/layer_{i}"
    layer_weights = []

    # query, key, value: kernel followed by bias for each projection.
    for projection in ("query", "key", "value"):
        layer_weights.append(tf.reshape(
            variables.get_tensor(f"{prefix}/attention/self/{projection}/kernel"),
            qkv_kernel_shape))
        layer_weights.append(tf.reshape(
            variables.get_tensor(f"{prefix}/attention/self/{projection}/bias"),
            qkv_bias_shape))

    # Attention output projection kernel is the only other reshaped tensor.
    layer_weights.append(tf.reshape(
        variables.get_tensor(f"{prefix}/attention/output/dense/kernel"),
        attention_output_kernel_shape))

    layer_weights.extend(
        variables.get_tensor(f"{prefix}/{suffix}")
        for suffix in unreshaped_suffixes)

    encoder.get_layer(f"transformer/layer_{i}").set_weights(layer_weights)

# Pooler weights.
encoder.get_layer("pooler_transform").set_weights([
    variables.get_tensor("bert/pooler/dense/kernel"),
    variables.get_tensor("bert/pooler/dense/bias"),
])

# tf.train.list_variables lists the checkpoint keys and shapes of the
# variables stored in a checkpoint.
tf.train.list_variables(init_checkpoint)

# BERT example: print every checkpoint variable whose key starts with "bert".
init_vars = tf.train.list_variables(init_checkpoint)
for name, shape in init_vars:
    if name.startswith("bert"):
        print(f"{name}, shape={shape}, *INIT FROM CKPT SUCCESS*")
import tensorflow as tf
import os
# Path handed to the CheckpointManager below as its checkpoint directory.
ckpt_directory = "/tmp/training_checkpoints/ckpt"
# Bundle the optimizer and the model into a single checkpoint object.
# NOTE(review): `optimizer` and `model` are not defined in this snippet; they
# must be created beforehand.
ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
# Manage the checkpoints written under ckpt_directory; max_to_keep=3 asks the
# manager to retain only the 3 most recent ones.
manager = tf.train.CheckpointManager(ckpt, ckpt_directory, max_to_keep=3)
# NOTE(review): train_and_checkpoint is defined elsewhere; presumably it trains
# the model and saves through the manager — confirm.
train_and_checkpoint(model, manager)
# List the variables stored in the manager's latest checkpoint.
tf.train.list_variables(manager.latest_checkpoint)

保存检查点

# NOTE(review): process_data and mode are not defined in this snippet;
# presumably they load the dataset and build/train the model — confirm.
x_train,y_train,x_test,y_test=process_data()
model=mode(x_train,y_train,x_test,y_test)
checkpoint=tf.train.Checkpoint(A=model)    # track the model under the key "A"
checkpoint.save('./checkpoint/01.ckpt')   # the argument is the directory plus the checkpoint file prefix; files are written under ./checkpoint
上一篇:机器翻译界的BERT:可快速得到任意机翻模型的mRASP


下一篇:自然语言处理笔记02 -- Bert模型解读和实战