__author__ = 'Guillaume Genthial'
__date__ = '2018-09-22'
See the README.md for an overview of this notebook and its goals.
from distutils.version import LooseVersion
import sys
if LooseVersion(sys.version) < LooseVersion('3.4'):
    raise Exception('You need python>=3.4, but you have {}'.format(sys.version))
# Standard
from pathlib import Path
# External
import numpy as np
import tensorflow as tf
if LooseVersion(tf.__version__) < LooseVersion('1.9'):
    raise Exception('You need tensorflow>=1.9, but you have {}'.format(tf.__version__))
Eager execution is compatible with NumPy (with behavior similar to PyTorch). For a full review, see this notebook from the TensorFlow team. It's a great tool for debugging, and it allows dynamic graph building (if you really need it...).
# You need to activate it at program startup
tf.enable_eager_execution()
X = tf.random_normal([2, 4])
h = tf.layers.dense(X, 2, activation=tf.nn.relu)
y = tf.nn.softmax(h)
print(y)
print(y.numpy())
Here, X, h and y are nodes of the computational graph, but you can actually get the values of these nodes! In the past you would have written:
X = tf.placeholder(dtype=tf.float32, shape=[2, 4])
h = tf.layers.dense(X, 2, activation=tf.nn.relu)
y = tf.nn.softmax(h)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(y, feed_dict={X: np.random.normal(size=[2, 4])})
tf.data: feeding data into the graph
tf.placeholder is replaced by tf.data.Dataset.
x = tf.placeholder(dtype=tf.int32, shape=[None, 5])
with tf.Session() as sess:
    x_eval = sess.run(x, feed_dict={x: x_np})
    print(x_eval)
np.array
Below is a simple example where we have a np.array, one row = one example.
x_np = np.array([[i]*5 for i in range(10)])
x_np
We create a Dataset from this array. This dataset is a node of the graph. Each time you query its value, it will move to the next row of the underlying np.array.
dataset = tf.data.Dataset.from_tensor_slices(x_np)
for el in dataset:
    print(el)
el is the equivalent of the former tf.placeholder. It's a node of the graph, to which you can apply any TensorFlow operation.
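For instance (a quick sketch, with eager execution still enabled), you can apply a standard op such as tf.reduce_sum directly to el:
# Any TensorFlow op can be applied to the dataset's elements
for el in dataset:
    print(tf.reduce_sum(el))  # sum of each row of x_np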
Let's just display the content of the text file test.txt.
path = 'test.txt'
with Path(path).open() as f:
    print(f.read())
The following does just the same as above, but now el is a tf.Tensor of dtype=tf.string!
dataset = tf.data.TextLineDataset([path])
for el in dataset:
    print(el)
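As a quick aside, el.numpy() returns the raw bytes behind the string tensor, which you can decode back into a Python str:
for el in dataset:
    # el is a scalar tf.string tensor; .numpy() gives bytes
    print(el.numpy().decode())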
tf.data.Dataset.from_generator gives you the best of both worlds, perfect for NLP: it allows you to put all your logic in pure Python, in your generator_fn, before feeding it to the graph.
def generator_fn():
    for _ in range(2):
        yield b'Hello world'

dataset = (tf.data.Dataset.from_generator(
    generator_fn,
    output_types=(tf.string),  # Define type and shape of your generator_fn output
    output_shapes=()))         # like you would have for your `placeholders`
for el in dataset:
    print(el)
tf.data: Dataset Transforms
Note: buffer_size is the number of elements you load into RAM before starting to sample from it. If it's too small (a buffer_size of 1 means no shuffling at all), shuffling won't be effective. Ideally, buffer_size is the same as the number of elements in your dataset, but because not all datasets fit in RAM, you need to be able to set it manually.
dataset = dataset.shuffle(buffer_size=10)
for el in dataset:
    print(el)
Repeat your dataset to perform multiple epochs!
dataset = dataset.repeat(2) # 2 epochs
for el in dataset:
    print(el)
Note: while map is super handy when working with images (TensorFlow has a lot of image preprocessing functions and efficiency is crucial), it's not as practical for NLP, because you're now working with tensors. We found it easier to write most of the preprocessing logic in Python, in a generator_fn, before feeding it to the graph.
dataset = dataset.map(
    lambda t: tf.string_split([t], delimiter=' ').values,
    num_parallel_calls=4)  # Multithreading
for el in dataset:
    print(el)
dataset = dataset.batch(batch_size=3)
for el in dataset:
    print(el)
In NLP, we usually work with sentences of different lengths. When building a batch, we need to 'pad', i.e. add some fake elements at the end of the shorter sentences. You can perform this operation easily in TensorFlow. Here is a dummy example:
def generator_fn():
    yield [1, 2]
    yield [1, 2, 3]

dataset = tf.data.Dataset.from_generator(
    generator_fn,
    output_types=(tf.int32),
    output_shapes=([None]))
dataset = dataset.padded_batch(
    batch_size=2,
    padded_shapes=([None]),
    padding_values=(4))  # Optional, if not set will default to 0
for el in dataset:
    print(el)
Notice that a 4 has been appended at the end of the first row.
And much more: prefetch, zip, concatenate, skip, take, etc. See the documentation.
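For instance, here is a small sketch of skip, take and prefetch on an array-backed dataset (the values in the comments assume no shuffling):
ds = tf.data.Dataset.from_tensor_slices(np.arange(10))
ds = ds.skip(2).take(3)  # drop the first 2 elements, keep the next 3
ds = ds.prefetch(1)      # overlap data preparation with consumption
for el in ds:
    print(el.numpy())    # 2, 3, 4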
Note: the recommended standard workflow is
1. shuffle
2. repeat (repeat after shuffle so that one epoch = all the examples)
3. map, using the num_parallel_calls argument to get multithreading for free
4. batch or padded_batch
5. prefetch (prefetch data so that the GPU doesn't suffer from data starvation and end up using only 80% of your expensive GPU)

Below is an example of why using map is kind of annoying in NLP (a sketch combining the transforms above follows it). It works, but it's not as easy as just using .split() or any other Python code.
def tf_tokenize(t):
    return tf.string_split([t], delimiter=' ').values

dataset = tf.data.TextLineDataset(['test.txt'])
dataset = dataset.map(tf_tokenize)
for el in dataset:
    print(el)
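Putting the recommended workflow together on the text-file dataset above (a sketch; the buffer size, number of epochs and batch size are illustrative, and strings are padded with the default empty string):
dataset = tf.data.TextLineDataset(['test.txt'])
dataset = (dataset
           .shuffle(buffer_size=10)                 # 1. shuffle
           .repeat(2)                               # 2. repeat (2 epochs)
           .map(tf_tokenize, num_parallel_calls=4)  # 3. map with multithreading
           .padded_batch(2, padded_shapes=[None])   # 4. padded_batch
           .prefetch(1))                            # 5. prefetch
for el in dataset:
    print(el)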
You're probably used to performing the lookup token -> token_idx outside TensorFlow. However, tf.contrib.lookup provides exactly this functionality. It's fast, and when exporting the model for serving, it will treat your vocab.txt file as a model resource and keep it with the model!
# One lexeme per line
path_vocab = 'vocab.txt'
with Path(path_vocab).open() as f:
    for idx, line in enumerate(f):
        print(idx, ' -> ', line.strip())
To use it in TensorFlow:
# The last idx (2) will be used for unknown words
lookup_table = tf.contrib.lookup.index_table_from_file(
    path_vocab, num_oov_buckets=1)
for el in dataset:
    print(lookup_table.lookup(el))
# We tokenize by white space and assign these ids
tok_to_idx = {'hello': 0, 'world': 1, '<odd>': 2, '<even>': 3}
def tokens_generator():
    with Path(path).open() as f:
        for line in f:
            # Tokenize by white space
            tokens = line.strip().split()
            token_ids = []
            for tok in tokens:
                # Look for digits
                if tok.isdigit():
                    if int(tok) % 2 == 0:
                        tok = '<even>'
                    else:
                        tok = '<odd>'
                token_ids.append(tok_to_idx.get(tok.lower(), len(tok_to_idx)))
            yield (token_ids, len(token_ids))
def get_label(token_id):
    if token_id == 2:
        return 1
    elif token_id == 3:
        return 2
    else:
        return 0

def labels_generator():
    for token_ids, _ in tokens_generator():
        yield [get_label(tok_id) for tok_id in token_ids]
dataset = tf.data.Dataset.from_generator(
    tokens_generator,
    output_types=(tf.int32, tf.int32),
    output_shapes=([None], ()))
for el in dataset:
    print(el)
Let's build a model that predicts the classes 0, 1 and 2 above. We first test our graph logic here, with eager execution activated.
batch_size = 4
vocab_size = 4
dim = 100
shapes = ([None], ())
defaults = (0, 0)
# The last sentence is longer: need padding
dataset = dataset.padded_batch(
    batch_size, shapes, defaults)

# Define all the variables just once (in eager execution mode, you would
# otherwise create new variables at each loop iteration!)
embeddings = tf.get_variable('embeddings', shape=[vocab_size, dim])
lstm_cell = tf.contrib.rnn.LSTMCell(100)
dense_layer = tf.layers.Dense(3, activation=tf.nn.relu)

for tokens, sequence_length in dataset:
    token_embeddings = tf.nn.embedding_lookup(embeddings, tokens)
    lstm_output, _ = tf.nn.dynamic_rnn(
        lstm_cell, token_embeddings, dtype=tf.float32, sequence_length=sequence_length)
    logits = dense_layer(lstm_output)
    print(logits.shape)
No errors or cryptic messages about some shape mismatch: it seems like our TensorFlow logic is fine.
tf.estimator
tf.estimator uses the traditional graph-based environment (no eager execution). If you use the tf.estimator interface, you get a lot for free: the training loop, weight serialization, TensorBoard summaries, easy serving, etc.
Before, people used to write custom model classes:
class Model:
    def get_feed_dict(self, X, y):
        return {self.X: X, self.y: y}

    def build(self):
        do_stuff()

    def train(self, X, y):
        with tf.Session() as sess:
            do_some_training()
With tf.estimator, there is now a common interface for models:
def input_fn():
    # Return a tf.data.Dataset that yields a tuple (features, labels)
    return dataset

def model_fn(features, labels, mode, params):
    """
    Parameters
    ----------
    features: tf.Tensor or nested structure
        Returned by `input_fn`
    labels: tf.Tensor or nested structure
        Returned by `input_fn`
    mode: tf.estimator.ModeKeys
        Either PREDICT / EVAL / TRAIN
    params: dict
        Hyperparams

    Returns
    -------
    tf.estimator.EstimatorSpec
    """
    if mode == tf.estimator.ModeKeys.TRAIN:
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
    elif mode == ...:
        ...
estimator = tf.estimator.Estimator(
    model_fn=model_fn, params=params)
estimator.train(input_fn)
# Clear all the objects we defined above, to be sure
# we don't mess with anything
tf.reset_default_graph()
input_fn
A callable that returns a dataset that yields tuples of (features, labels).
def input_fn():
    # Create datasets for features and labels
    dataset_tokens = tf.data.Dataset.from_generator(
        tokens_generator,
        output_types=(tf.int32, tf.int32),
        output_shapes=([None], ()))
    dataset_output = tf.data.Dataset.from_generator(
        labels_generator,
        output_types=(tf.int32),
        output_shapes=([None]))
    # Zip features and labels in one Dataset
    dataset = tf.data.Dataset.zip((dataset_tokens, dataset_output))
    # Shuffle, repeat, batch and prefetch
    shapes = (([None], ()), [None])
    defaults = ((0, 0), 0)
    dataset = (dataset
               .shuffle(10)
               .repeat(100)
               .padded_batch(4, shapes, defaults)
               .prefetch(1))
    # Dataset yields tuples of (features, labels)
    return dataset
model_fn
Inputs: (features, labels, mode, params); returns EstimatorSpec objects.
def model_fn(features, labels, mode, params):
    # Args features and labels are the same as returned by the dataset
    tokens, sequence_length = features
    # For serving (ignore this)
    if isinstance(features, dict):
        tokens = features['tokens']
        sequence_length = features['sequence_length']

    # 1. Define the graph
    vocab_size = params['vocab_size']
    dim = params['dim']
    embeddings = tf.get_variable('embeddings', shape=[vocab_size, dim])
    token_embeddings = tf.nn.embedding_lookup(embeddings, tokens)
    lstm_cell = tf.contrib.rnn.LSTMCell(20)
    lstm_output, _ = tf.nn.dynamic_rnn(
        lstm_cell, token_embeddings, dtype=tf.float32)
    logits = tf.layers.dense(lstm_output, 3)
    preds = tf.argmax(logits, axis=-1)

    # 2. Define EstimatorSpec for PREDICT
    if mode == tf.estimator.ModeKeys.PREDICT:
        # predictions is any nested object (a dict is convenient)
        predictions = {'logits': logits, 'preds': preds}
        # export_outputs is for serving (ignore this)
        export_outputs = {
            'predictions': tf.estimator.export.PredictOutput(predictions)}
        return tf.estimator.EstimatorSpec(mode, predictions=predictions,
                                          export_outputs=export_outputs)
    else:
        # 3. Define loss and metrics
        # Define weights to account for padding
        weights = tf.sequence_mask(sequence_length)
        loss = tf.losses.sparse_softmax_cross_entropy(
            logits=logits, labels=labels, weights=weights)
        metrics = {
            'accuracy': tf.metrics.accuracy(labels=labels, predictions=preds),
        }
        # For TensorBoard
        for k, v in metrics.items():
            # v[1] is the update op of the metric
            tf.summary.scalar(k, v[1])

        # 4. Define EstimatorSpec for EVAL
        # Having an eval mode and metrics in TensorFlow allows you to use
        # built-in early stopping (see later)
        if mode == tf.estimator.ModeKeys.EVAL:
            return tf.estimator.EstimatorSpec(mode, loss=loss,
                                              eval_metric_ops=metrics)
        # 5. Define EstimatorSpec for TRAIN
        elif mode == tf.estimator.ModeKeys.TRAIN:
            global_step = tf.train.get_or_create_global_step()
            train_op = (tf.train.AdamOptimizer(learning_rate=0.1)
                        .minimize(loss, global_step=global_step))
            return tf.estimator.EstimatorSpec(mode, loss=loss,
                                              train_op=train_op)
What do you think about this model_fn? It seems like we wrote only the things that matter (not a lot of boilerplate!). Now, let's define our estimator and train it!
params = {
    'vocab_size': 4,
    'dim': 3
}
estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir='model',  # Will save the weights here automatically
    params=params)
estimator.train(input_fn)
train_and_evaluate, predict, etc.
Now the estimator is trained, serialized to disk, etc. You also have access to TensorBoard. (Lots of stuff for free, without having to write boilerplate code!)
To access TensorBoard:
tensorboard --logdir model
Check the documentation for evaluate, train_and_evaluate, etc. Here is an example with early stopping, where we run evaluation every 2 minutes (120 seconds).
hook = tf.contrib.estimator.stop_if_no_increase_hook(
    estimator, 'accuracy', 500, min_steps=8000, run_every_secs=120)
train_spec = tf.estimator.TrainSpec(input_fn=input_fn, hooks=[hook])
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn, throttle_secs=120)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
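You can also run a standalone evaluation pass (a sketch: here we reuse the same input_fn for simplicity, but in practice you would point it at a held-out dataset):
metrics = estimator.evaluate(input_fn)
print(metrics)  # dict containing 'accuracy', 'loss' and 'global_step'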
# Iterate over the first 2 elements of the (shuffled) dataset and yield predictions
# You need to write variants of your input_fn for the eval / predict modes
for idx, predictions in enumerate(estimator.predict(input_fn)):
    print(predictions['preds'])
    if idx > 0:
        break
Exporting an inference graph and the serving signature is "simple" (though the serving_fn interface could be improved). The cool thing is that once you have your tf.estimator and your serving_input_fn, you can just use tensorflow_serving and get a RESTful API serving your model!
def serving_input_fn():
    tokens = tf.placeholder(
        dtype=tf.int32, shape=[None, None], name="tokens")
    sequence_length = tf.size(tokens)
    features = {'tokens': tokens, 'sequence_length': sequence_length}
    return tf.estimator.export.ServingInputReceiver(
        features=features, receiver_tensors=tokens)

estimator.export_savedmodel('export', serving_input_fn)
docker pull tensorflow/serving
docker run -p 8501:8501 \
--mount type=bind,\
source=path_to_your_export_model,\
target=/models/dummy \
-e MODEL_NAME=dummy -t tensorflow/serving &
curl -d '{"instances": [[0, 1, 2],[0, 1, 3]]}' -X POST \
http://localhost:8501/v1/models/dummy:predict
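The same request can also be sent from Python (a sketch assuming the serving container above is running locally and that the requests library is installed):
import requests

response = requests.post(
    'http://localhost:8501/v1/models/dummy:predict',
    json={'instances': [[0, 1, 2], [0, 1, 3]]})
print(response.json())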