Create Your Own Visualizations!

Instructions:

  1. Install tensor2tensor and train up a Transformer model following the instructions in the repository https://github.com/tensorflow/tensor2tensor (a rough sketch of the commands is shown just after this list).
  2. Update cell 3 to point to your checkpoint; it is currently set up to read from the default checkpoint location that the instructions above would create.
  3. If you used custom hyperparameters, update cell 4.
  4. Run the notebook!
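
For reference, step 1 boils down to roughly the following. This is only a sketch: flag names have shifted between tensor2tensor versions (e.g. --problem vs. --problems), so defer to the repository README for the exact invocation.

%%bash
pip install tensor2tensor

# Generate the WMT En-De data, then train the single-GPU base Transformer.
t2t-datagen --data_dir=$HOME/t2t_data --tmp_dir=/tmp/t2t_datagen \
    --problem=translate_ende_wmt32k
t2t-trainer --data_dir=$HOME/t2t_data --problems=translate_ende_wmt32k \
    --model=transformer --hparams_set=transformer_base_single_gpu \
    --output_dir=$HOME/t2t_train/translate_ende_wmt32k/transformer-transformer_base_single_gpu
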
In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json

import tensorflow as tf
import numpy as np

from tensor2tensor.utils import trainer_utils as utils
from tensor2tensor.visualization import attention
from tensor2tensor.utils import decoding
In [2]:
%%javascript
require.config({
  paths: {
      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min'
  }
});

Data

In [3]:
import os
# PUT THE MODEL YOU WANT TO LOAD HERE!

PROBLEM = 'translate_ende_wmt32k'
MODEL = 'transformer'
HPARAMS = 'transformer_base_single_gpu'

DATA_DIR=os.path.expanduser('~/t2t_data')
TRAIN_DIR=os.path.expanduser('~/t2t_train/%s/%s-%s' % (PROBLEM, MODEL, HPARAMS))
print(TRAIN_DIR)

FLAGS = tf.flags.FLAGS
FLAGS.problems = PROBLEM
FLAGS.hparams_set = HPARAMS
FLAGS.data_dir = DATA_DIR
FLAGS.model = MODEL

FLAGS.schedule = 'train_and_evaluate'
/usr/local/google/home/llion/t2t_train/translate_ende_wmt32k/transformer-transformer_base_single_gpu
In [4]:
hparams = utils.create_hparams(FLAGS.hparams_set, FLAGS.data_dir)

# SET EXTRA HYPER PARAMS HERE!
#hparams.null_slot = True

utils.add_problem_hparams(hparams, PROBLEM)

num_datashards = utils.devices.data_parallelism().n

mode = tf.estimator.ModeKeys.EVAL

input_fn = utils.input_fn_builder.build_input_fn(
      mode=mode,
      hparams=hparams,
      data_dir=DATA_DIR,
      num_datashards=num_datashards,
      worker_replicas=FLAGS.worker_replicas,
      worker_id=FLAGS.worker_id,
      batch_size=1)

inputs, target = input_fn()
features = inputs
features['targets'] = target
INFO:tensorflow:datashard_devices: ['gpu:0']
INFO:tensorflow:caching_devices: None
INFO:tensorflow:batching_scheme = {'min_length': 0, 'window_size': 720, 'shuffle_queue_size': 270, 'boundaries': [8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 24, 26, 28, 30, 33, 36, 39, 42, 46, 50, 55, 60, 66, 72, 79, 86, 94, 103, 113, 124, 136, 149, 163, 179, 196, 215, 236], 'max_length': 1000000000, 'batch_sizes': [240, 180, 180, 180, 144, 144, 144, 120, 120, 120, 90, 90, 90, 90, 80, 72, 72, 60, 60, 48, 48, 48, 40, 40, 36, 30, 30, 24, 24, 20, 20, 18, 18, 16, 15, 12, 12, 10, 10, 9, 8, 8]}
INFO:tensorflow:Updated batching_scheme = {'min_length': 0, 'window_size': 720, 'shuffle_queue_size': 270, 'boundaries': [], 'max_length': 1000000000, 'batch_sizes': [1]}
INFO:tensorflow:Reading data files from /usr/local/google/home/llion/t2t_data/translate_ende_wmt32k-dev*
In [5]:
def encode(string):
    # Subtokenize the raw string, append EOS (id 1) and PAD (id 0), and add a batch dimension.
    subtokenizer = hparams.problems[0].vocabulary['inputs']
    return [subtokenizer.encode(string) + [1] + [0]]

def decode(ids):
    # Turn an array of target subtoken ids back into a readable string.
    return hparams.problems[0].vocabulary['targets'].decode(np.squeeze(ids))

def to_tokens(ids):
    # Map each subtoken id to its string, keeping PAD and EOS visible as placeholders.
    ids = np.squeeze(ids)
    subtokenizer = hparams.problems[0].vocabulary['targets']
    tokens = []
    for _id in ids:
        if _id == 0:
            tokens.append('<PAD>')
        elif _id == 1:
            tokens.append('<EOS>')
        else:
            tokens.append(subtokenizer._subtoken_id_to_subtoken_string(_id))
    return tokens
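
As a quick, optional check of these helpers (the exact subtokens depend on the vocabulary your checkpoint was trained with; decode uses the targets vocabulary, which this problem shares with the inputs vocabulary, so the sentence should largely round-trip):

ids = encode("The weather is nice today.")
print(to_tokens(ids))   # per-id subtoken strings, with '<EOS>' and '<PAD>' made explicit
print(decode(ids))      # decoded back into a readable string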

Model

In [6]:
model_fn=utils.model_builder.build_model_fn(
    MODEL,
    problem_names=[PROBLEM],
    train_steps=FLAGS.train_steps,
    worker_id=FLAGS.worker_id,
    worker_replicas=FLAGS.worker_replicas,
    eval_run_autoregressive=FLAGS.eval_run_autoregressive,
    decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams))
est_spec = model_fn(features, target, mode, hparams)
INFO:tensorflow:datashard_devices: ['gpu:0']
INFO:tensorflow:caching_devices: None
INFO:tensorflow:Doing model_fn_body took 1.881 sec.
INFO:tensorflow:This model_fn took 2.023 sec.
In [7]:
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    beam_out = model_fn(features, target, tf.contrib.learn.ModeKeys.INFER, hparams)
INFO:tensorflow:datashard_devices: ['gpu:0']
INFO:tensorflow:caching_devices: None
INFO:tensorflow:Beam Decoding with beam size 4
INFO:tensorflow:Doing model_fn_body took 1.393 sec.
INFO:tensorflow:This model_fn took 1.504 sec.

Session

In [8]:
sv = tf.train.Supervisor(
    logdir=TRAIN_DIR,
    global_step=tf.Variable(0, dtype=tf.int64, trainable=False, name='global_step'))
sess = sv.PrepareSession(config=tf.ConfigProto(allow_soft_placement=True))
sv.StartQueueRunners(
    sess,
    tf.get_default_graph().get_collection(tf.GraphKeys.QUEUE_RUNNERS))
INFO:tensorflow:Restoring parameters from /usr/local/google/home/llion/t2t_train/translate_ende_wmt32k/transformer-transformer_base_single_gpu/model.ckpt-1
INFO:tensorflow:Starting standard services.
INFO:tensorflow:Starting queue runners.
INFO:tensorflow:Saving checkpoint to path /usr/local/google/home/llion/t2t_train/translate_ende_wmt32k/transformer-transformer_base_single_gpu/model.ckpt
Out[8]:
[]

Visualization

In [9]:
# Get the attention tensors from the graph.
# This needs to be done using the training graph since inference uses a tf.while_loop
# and you can't fetch tensors from inside a while_loop.

enc_atts = []
dec_atts = []
encdec_atts = []

for i in range(hparams.num_hidden_layers):
    enc_att = tf.get_default_graph().get_operation_by_name(
        "body/model/parallel_0/body/encoder/layer_%i/self_attention/multihead_attention/dot_product_attention/attention_weights" % i).values()[0]
    dec_att = tf.get_default_graph().get_operation_by_name(
        "body/model/parallel_0/body/decoder/layer_%i/self_attention/multihead_attention/dot_product_attention/attention_weights" % i).values()[0]
    encdec_att = tf.get_default_graph().get_operation_by_name(
        "body/model/parallel_0/body/decoder/layer_%i/encdec_attention/multihead_attention/dot_product_attention/attention_weights" % i).values()[0]

    enc_atts.append(enc_att)
    dec_atts.append(dec_att)
    encdec_atts.append(encdec_att)

Test translation from the dataset

In [10]:
inp, out, logits = sess.run([inputs['inputs'], target, est_spec.predictions['predictions']])

print("Input:    ", decode(inp[0]))
print("Gold:     ", decode(out[0]))
logits = np.squeeze(logits[0])
tokens = np.argmax(logits, axis=1)
print("Gold out: ", decode(tokens))
INFO:tensorflow:global_step/sec: 0
Input:     For example, during the 2008 general election in Florida, 33% of early voters were African-Americans, who accounted however for only 13% of voters in the State.
Gold:      Beispielsweise waren bei den allgemeinen Wahlen 2008 in Florida 33% der Wähler, die im Voraus gewählt haben, Afro-Amerikaner, obwohl sie nur 13% der Wähler des Bundesstaates ausmachen.
Gold out:  So waren 33 den allgemeinen Wahlen im in der a 33 % der Frühjungdie nur Land die wurden, die ro- Amerikaner, die sie nur 13 % der Wähler im Staates staats betra.
INFO:tensorflow:Recording summary at step 250000.

Visualize Custom Sentence

In [11]:
eng = "I have three dogs."
In [12]:
inp_ids = encode(eng)
beam_decode = sess.run(beam_out.predictions['outputs'], {
    inputs['inputs']: np.expand_dims(np.expand_dims(inp_ids, axis=2), axis=3),
})
trans = decode(beam_decode[0])
print(trans)
Ich habe drei Hunde.
In [13]:
output_ids = beam_decode

# Get attentions
np_enc_atts, np_dec_atts, np_encdec_atts = sess.run([enc_atts, dec_atts, encdec_atts], {
    inputs['inputs']: np.expand_dims(np.expand_dims(inp_ids, axis=2), axis=3),
    target: np.expand_dims(np.expand_dims(output_ids, axis=2), axis=3),
})
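
Optionally, sanity-check what came back: each list should hold one array per layer, and each array (per tensor2tensor's dot_product_attention weights) should have shape [batch, num_heads, query_length, memory_length].

print(len(np_enc_atts))          # one entry per layer
print(np_enc_atts[0].shape)      # (batch, num_heads, input_length, input_length)
print(np_encdec_atts[0].shape)   # (batch, num_heads, output_length, input_length)
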
In [14]:
%%javascript
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

Interpreting the Visualizations

  • The layers dropdown allows you to view the different Transformer layers (0-indexed, of course).
    • Tip: The first layer, the last layer, and the second-to-last layer are usually the most interpretable.
  • The attention dropdown allows you to select which combination of attentions to view:
    • All: Shows all types of attention together. NOTE: Heads of the same color are unrelated between the decoder self-attention and the encoder-decoder attention, since the two do not share parameters.
    • Input - Input: Shows only the encoder self-attention.
    • Input - Output: Shows the decoder's attention over the encoder. NOTE: Every decoder layer attends to the final encoder layer, so the visualization shows attention over the final encoder layer regardless of which layer is selected in the dropdown.
    • Output - Output: Shows only the decoder self-attention. NOTE: The visualization can be slightly misleading at the first layer, since the text shown is the decoder's target; the actual input to the decoder at layer 0 is this text with a GO symbol prepended.
  • The colored squares represent the different attention heads.
    • You can hide or show a given head by clicking on its color.
    • Double-clicking a color hides all the other heads; double-clicking again while it is the only head showing brings all the heads back.
  • You can hover over a word to see the individual attention weights for just that position.
    • Hovering over a word on the left shows what that position attended to.
    • Hovering over a word on the right shows which positions attended to it.
In [ ]:
inp_text = to_tokens(inp_ids)
out_text = to_tokens(output_ids)

attention.show(inp_text, out_text, np_enc_atts, np_dec_atts, np_encdec_atts)