基于tensor2tensor的注意力可视化

根据训练好的 Transformer 模型获取注意力矩阵，并对注意力权重进行可视化。

首先安装依赖：tensorflow 1.13.1 和 tensor2tensor 1.13.1。

 

# coding=utf-8
# Copyright 2020 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared code for visualizing transformer attentions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np

# To register the hparams set
from tensor2tensor import models  # pylint: disable=unused-import
from tensor2tensor import problems
from tensor2tensor.utils import registry
from tensor2tensor.utils import trainer_lib

import tensorflow.compat.v1 as tf
from tensor2tensor.utils import usr_dir
EOS_ID = 1

class AttentionVisualizer2(object):
  """Helper object for creating Attention visualizations.

  Builds the model graph once (via `build_model`) and keeps the input/target
  placeholders, the decoding output tensor, the attention tensors and the
  problem's feature encoders, so `get_vis_data_from_string` can be called
  repeatedly with an open session.
  """

  def __init__(
      self, hparams_set, hparams, t2t_usr_dir, model_name, data_dir,
      problem_name, beam_size=1):
    """Builds the graph and loads the problem's feature encoders.

    Args:
      hparams_set: Name of the registered HParams set.
      hparams: HParams override string, e.g. "max_length=128,...".
      t2t_usr_dir: Directory with user-defined t2t code to import.
      model_name: Registered model name.
      data_dir: Directory containing the problem's data / vocabulary files.
      problem_name: Registered problem name.
      beam_size: Beam size used for decoding; 1 means greedy decoding.
    """
    inputs, targets, samples, att_mats = build_model(
        hparams_set, hparams, t2t_usr_dir, model_name, data_dir, problem_name,
        beam_size=beam_size)

    # Fetch the problem and its encoders (vocabularies).
    problem = problems.problem(problem_name)
    encoders = problem.feature_encoders(data_dir)

    self.inputs = inputs
    self.targets = targets
    self.att_mats = att_mats
    self.samples = samples
    self.encoders = encoders

  @staticmethod
  def _flatten_ids(integers):
    """Returns an array-like of token ids as a flat 1-D list.

    `np.squeeze` would turn a single-id array into a 0-d array, which `list()`
    cannot iterate; `np.ravel` always yields a 1-D array.
    """
    return list(np.ravel(integers))

  def encode(self, input_str):
    """Encodes `input_str` into a features array ready for inference.

    Appends EOS_ID and reshapes to [1, length, 1, 1] (a 4-D batch of one),
    matching the shape of the `inputs` placeholder.
    """
    inputs = self.encoders["inputs"].encode(input_str) + [EOS_ID]
    return np.reshape(inputs, [1, -1, 1, 1])

  def decode(self, integers):
    """Decodes target-side ids (any array shape) into a string."""
    return self.encoders["targets"].decode(self._flatten_ids(integers))

  def encode_list(self, integers):
    """Decodes input-side ids (any array shape) into a list of token strings.

    Despite its name, this maps ids back to tokens using the inputs encoder.
    """
    return self.encoders["inputs"].decode_list(self._flatten_ids(integers))

  def decode_list(self, integers):
    """Decodes target-side ids (any array shape) into a list of token strings."""
    return self.encoders["targets"].decode_list(self._flatten_ids(integers))

  def get_vis_data_from_string(self, sess, input_string):
    """Constructs the data needed for visualizing attentions.

    First decodes `input_string` with the inference graph, then feeds the
    decoded ids back through the training graph as targets to fetch the
    attention weights for that (input, translation) pair.

    Args:
      sess: A tf.Session object.
      input_string: The input sentence to be translated and visualized.
    Returns:
      Tuple of (
          output_string: The translated sentence.
          input_list: Tokenized input sentence.
          output_list: Tokenized translation.
          att_mats: Tuple of attention matrices; (
              enc_atts: Encoder self attention weights.
                A list of `num_layers` numpy arrays of size
                (batch_size, num_heads, inp_len, inp_len)
              dec_atts: Decoder self attention weights.
                A list of `num_layers` numpy arrays of size
                (batch_size, num_heads, out_len, out_len)
              encdec_atts: Encoder-Decoder attention weights.
                A list of `num_layers` numpy arrays of size
                (batch_size, num_heads, out_len, inp_len)
          )
    """
    encoded_inputs = self.encode(input_string)

    # Run inference graph to get the translation.
    out = sess.run(self.samples, {
        self.inputs: encoded_inputs,
    })

    # Run the decoded translation through the training graph to get the
    # attention tensors.
    att_mats = sess.run(self.att_mats, {
        self.inputs: encoded_inputs,
        self.targets: np.reshape(out, [1, -1, 1, 1]),
    })

    output_string = self.decode(out)
    input_list = self.encode_list(encoded_inputs)
    output_list = self.decode_list(out)

    return output_string, input_list, output_list, att_mats


def build_model(hparams_set, hparams, t2t_usr_dir, model_name, data_dir,
                problem_name, beam_size=1):
  """Build the graph required to fetch the attention weights.

  The training (teacher-forced) graph is built first so that the attention
  tensors can be collected from it; the inference graph is then built on top
  of the same variables.

  Args:
    hparams_set: Name of the registered HParams set to build the model with.
    hparams: HParams override string applied on top of `hparams_set`,
        e.g. "max_length=128,num_hidden_layers=6".
    t2t_usr_dir: Directory with user-defined t2t code (custom models /
        hparams sets) that is imported before the registry lookups below.
    model_name: Name of the registered model.
    data_dir: Path to directory containing training data.
    problem_name: Name of problem.
    beam_size: (Optional) Number of beams to use when decoding a translation.
        If set to 1 (default) then greedy decoding is used.
  Returns:
    Tuple of (
        inputs: Input placeholder to feed in ids to be translated.
        targets: Targets placeholder to feed to translation when fetching
            attention weights.
        samples: Tensor representing the ids of the translation.
        att_mats: Tensors representing the attention weights.
    )
  """
  print(model_name)
  usr_dir.import_usr_dir(t2t_usr_dir)
  # Keep the override string (`hparams`) and the resulting HParams object
  # under separate names.
  model_hparams = trainer_lib.create_hparams(
      hparams_set, hparams, data_dir=data_dir, problem_name=problem_name)

  model_cls = registry.model(model_name)
  translate_model = model_cls(model_hparams, tf.estimator.ModeKeys.EVAL)

  placeholder_shape = (1, None, 1, 1)
  inputs = tf.placeholder(tf.int32, shape=placeholder_shape, name="inputs")
  targets = tf.placeholder(tf.int32, shape=placeholder_shape, name="targets")
  train_features = {"inputs": inputs, "targets": targets}
  translate_model(train_features)

  # Collect the attention tensors now: after the training graph fills the
  # model's attention dict, but BEFORE the inference graph is created, since
  # decoding refills the dict with tensors from inside a tf.while_loop that
  # cannot be fetched.
  atts = get_att_mats(translate_model, model_name)

  with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    infer_result = translate_model.infer({"inputs": inputs},
                                         beam_size=beam_size)
  samples = infer_result["outputs"]

  return inputs, targets, samples, atts


def get_att_mats(translate_model, model_name):
  """Gets the tensors representing the attentions from a built model.

  The attentions are stored in a dict (`attention_weights`) on the model
  object while building the graph; this function looks them up by name.

  Args:
    translate_model: Model object to fetch the attention weights from. Must
      expose `hparams` and an `attention_weights` dict.
    model_name: Name of the model; used as the leading prefix of the
      attention keys ("<model_name>/body/...").
  Returns:
  Tuple of attention tensors (numpy arrays once fetched with `sess.run`); (
      enc_atts: Encoder self attention weights.
        A list with one entry per encoder layer, each of size
        (batch_size, num_heads, inp_len, inp_len)
      dec_atts: Decoder self attention weights.
        A list with one entry per decoder layer, each of size
        (batch_size, num_heads, out_len, out_len)
      encdec_atts: Encoder-Decoder attention weights.
        A list with one entry per decoder layer, each of size
        (batch_size, num_heads, out_len, inp_len)
  )
  Raises:
    KeyError: If an expected attention tensor is missing, e.g. because the
      model's scope name differs from `model_name`. The message lists the
      keys that are available.
  """
  hparams = translate_model.hparams
  weights = translate_model.attention_weights

  prefix = "%s/body/" % model_name
  postfix_self_attention = "/multihead_attention/dot_product_attention"
  if hparams.self_attention_type == "dot_product_relative":
    postfix_self_attention = ("/multihead_attention/"
                              "dot_product_attention_relative")
  postfix_encdec = "/multihead_attention/dot_product_attention"

  # tensor2tensor's Transformer builds `num_encoder_layers` /
  # `num_decoder_layers` layers when those hparams are non-zero and falls back
  # to `num_hidden_layers` otherwise; hparams sets without these fields just
  # use `num_hidden_layers`.
  num_enc_layers = (getattr(hparams, "num_encoder_layers", 0) or
                    hparams.num_hidden_layers)
  num_dec_layers = (getattr(hparams, "num_decoder_layers", 0) or
                    hparams.num_hidden_layers)

  def _lookup(key):
    """Fetches one attention tensor, naming the available keys on failure."""
    if key not in weights:
      raise KeyError("Attention tensor %r not found in attention_weights; "
                     "available keys: %s" % (key, sorted(weights)))
    return weights[key]

  enc_atts = [
      _lookup("%sencoder/layer_%i/self_attention%s"
              % (prefix, i, postfix_self_attention))
      for i in range(num_enc_layers)]
  dec_atts = [
      _lookup("%sdecoder/layer_%i/self_attention%s"
              % (prefix, i, postfix_self_attention))
      for i in range(num_dec_layers)]
  encdec_atts = [
      _lookup("%sdecoder/layer_%i/encdec_attention%s"
              % (prefix, i, postfix_encdec))
      for i in range(num_dec_layers)]

  return enc_atts, dec_atts, encdec_atts


import os
from tensor2tensor import problems
from tensor2tensor.bin import t2t_decoder  # To register the hparams set
# from tensor2tensor.utils import registry
from tensor2tensor.utils import trainer_lib
from tensor2tensor.visualization import attention
# from src.visualization import visualization
# Run on GPU 2 unless the caller already selected devices via the
# environment. This must be set before the tf.Session below is created.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")

# Problem / model configuration.
problem_name = 'translate_envi_iwslt32k'
data_dir = os.path.expanduser(
    '/home/usrname/collaboration/t2t_data/%s' % (problem_name))
model_name = "collaboration"
hparams_set = "collaboration_tiny"
# HParams overrides; presumably these must match the configuration the
# checkpoint was trained with — confirm.
hparams = 'max_length=128,num_hidden_layers=6,usedegray=0.5,reuse_n=0'
t2t_usr_dir = './src/'  # Path to the user's custom model code.

visualizer = AttentionVisualizer2(hparams_set, hparams, t2t_usr_dir,
                                  model_name, data_dir, problem_name,
                                  beam_size=1)

# Vocabulary file used by this problem, for reference:
# /home/usrname/collaboration/t2t_data/translate_envi_iwslt32k/vocab.translate_envi_iwslt32k.32768.subwords

# NOTE(review): creates a `global_step` variable before the Saver is built,
# presumably to mirror the training graph; confirm whether it is required.
tf.Variable(0, dtype=tf.int64, trainable=False, name='global_step')

saver = tf.train.Saver()
with tf.Session() as sess:
  # To restore the newest checkpoint of a training directory instead, pass
  # tf.train.latest_checkpoint(<train_dir>) as `ckpt`.
  ckpt = 'averaged.ckpt-0'  # Model checkpoint to restore.
  print(ckpt)
  saver.restore(sess, ckpt)

  input_sentence = "My family was not poor , and myself , I had never experienced hunger ."
  output_string, inp_text, out_text, att_mats = (
      visualizer.get_vis_data_from_string(sess, input_sentence))
  print(output_string)
  print(att_mats)

  # NOTE(review): attention.show presumably renders through IPython display,
  # i.e. it is meant to run inside a Jupyter notebook — confirm.
  attention.show(inp_text, out_text, *att_mats)

  

上一篇:百度飞桨开源业内首个口罩人脸检测及分类模型【转】


下一篇:基于GPU的Pytorch——CNN实现MNIST数据识别