#!/usr/bin/env python
# coding: utf-8

# # Good practices in Modern Tensorflow for NLP

# In[1]:

__author__ = 'Guillaume Genthial'
__date__ = '2018-09-22'

# ## Contents
#
# See the [README.md](README.md) for an overview of this notebook and its goals.
#
# 0. [Eager execution](#Eager-execution)
# 0. [`tf.data`: feeding data into the graph](#tf.data:-feeding-data-into-the-graph)
#     0. [Placeholders (before)](#Placeholders-(before))
#     0. [Dataset from `np.array`](#Dataset-from-np.array)
#     0. [Dataset from text file](#Dataset-from-text-file)
#     0. [Dataset from custom generator](#Dataset-from-custom-generator)
# 0. [`tf.data`: Dataset Transforms](#tf.data:-Dataset-Transforms)
#     0. [Shuffle](#Shuffle)
#     0. [Repeat](#Repeat)
#     0. [Map](#Map)
#     0. [Batch](#Batch)
#     0. [Padded batch](#Padded-batch)
# 0. [NLP: preprocessing in Tensorflow](#NLP:-preprocessing-in-Tensorflow)
#     0. [Tokenizing by white space in TensorFlow](#Tokenizing-by-white-space-in-TensorFlow)
#     0. [Lookup token index from vocab file in TensorFlow](#Lookup-token-index-from-vocab-file-in-TensorFlow)
# 0. [Full Example](#Full-Example)
#     0. [Task and Data](#Task-and-Data)
#     0. [Graph (test with eager execution)](#Graph-(test-with-eager-execution))
#     0. [Model (`tf.estimator`)](#Model-(tf.estimator))
#         0. [Before: custom model classes](#Before:-custom-model-classes)
#         0. [Now: `tf.estimator`](#Now:--tf.estimator)
#         0. [`input_fn`](#input_fn)
#         0. [`model_fn`](#model_fn)
#         0. [Instantiate and train your Estimator](#Instantiate-and-train-your-Estimator)
#         0. [TensorBoard, `train_and_evaluate`, `predict` etc.](#TensorBoard,-train_and_evaluate,-predict-etc.)
# 0. [A word about TensorFlow model serving](#A-word-about-TensorFlow-model-serving)
#     0. [Serving interface](#Serving-interface)
#     0. [Docker Image](#Docker-Image)
#         0. [Pull existing image](#Pull-existing-image)
#         0. [Run](#Run)
#         0. [Rest API POST with curl](#Rest-API-POST-with-curl)

# ## Setup

# In[2]:

from distutils.version import LooseVersion
import sys

if LooseVersion(sys.version) < LooseVersion('3.4'):
    raise Exception('You need python>=3.4, but you have {}'.format(sys.version))

# In[3]:

# Standard
from pathlib import Path

# External
import numpy as np
import tensorflow as tf

# In[4]:

if LooseVersion(tf.__version__) < LooseVersion('1.9'):
    raise Exception('You need tensorflow>=1.9, but you have {}'.format(tf.__version__))

# ## Eager execution
#
# Compatible with `numpy` (similar behavior to `pyTorch`). For a full review, see [this notebook from the TensorFlow team](https://colab.research.google.com/github/tensorflow/tensorflow/blob/r1.10/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb).
#
# It's a great tool for __debugging__ and allows dynamic graph building (if you really need it...).

# In[5]:

# You need to activate it at program startup
tf.enable_eager_execution()

# In[6]:

X = tf.random_normal([2, 4])
h = tf.layers.dense(X, 2, activation=tf.nn.relu)
y = tf.nn.softmax(h)
print(y)
print(y.numpy())

# Here, `X, h, y` are nodes of the computational graph, but you can actually get the value of these nodes!
#
# In the past you would have written
#
# ```python
# X = tf.placeholder(dtype=tf.float32, shape=[2, 4])
# h = tf.layers.dense(X, 2, activation=tf.nn.relu)
# y = tf.nn.softmax(h)
# with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     sess.run(y, feed_dict={X: np.random.normal(size=[2, 4])})
# ```

# ## `tf.data`: feeding data into the graph
#
# `tf.placeholder` is replaced by `tf.data.Dataset`.
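# As a quick preview, here is a minimal sketch (the dummy array shape and batch size are illustrative assumptions) of the earlier dense + softmax forward pass fed from a `Dataset` instead of a `feed_dict`, under eager execution; the subsections below walk through each piece.
#
# ```python
# # Sketch: feed the forward pass from a Dataset (eager mode), no placeholders
# x_np = np.random.normal(size=[8, 4]).astype(np.float32)  # assumed dummy data
# dataset = tf.data.Dataset.from_tensor_slices(x_np).batch(2)
# dense = tf.layers.Dense(2, activation=tf.nn.relu)
# for batch in dataset:
#     print(tf.nn.softmax(dense(batch)).numpy())
# ```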
# ### Placeholders (before)
#
# ```python
# x = tf.placeholder(dtype=tf.int32, shape=[None, 5])
# with tf.Session() as sess:
#     x_eval = sess.run(x, feed_dict={x: x_np})
#     print(x_eval)
# ```
#
# ### Dataset from `np.array`
#
# Below is a simple example where we have a `np.array`, one row = one example.

# In[7]:

x_np = np.array([[i]*5 for i in range(10)])
x_np

# We create a `Dataset` from this array.
#
# This dataset is a node of the graph.
# Each time you query its value, it will move to the next row of the underlying `np.array`.

# In[8]:

dataset = tf.data.Dataset.from_tensor_slices(x_np)

# In[9]:

for el in dataset:
    print(el)

# `el` is the equivalent of the former `tf.placeholder`. It's a node of the graph, to which you can apply any TensorFlow operation.

# ### Dataset from text file
#
# Let's just display the content of the file.

# In[10]:

path = 'test.txt'
with Path(path).open() as f:
    print(f.read())

# The following does just the same as above, but now `el` is a `tf.Tensor` of `dtype=tf.string`!

# In[11]:

dataset = tf.data.TextLineDataset([path])
for el in dataset:
    print(el)

# ### Dataset from custom generator
#
# __The best of both worlds, perfect for NLP__
#
# It allows you to put all your logic in pure Python, in your `generator_fn`, before feeding it to the graph.

# In[12]:

def generator_fn():
    for _ in range(2):
        yield b'Hello world'

# In[13]:

dataset = (tf.data.Dataset.from_generator(
    generator_fn,
    output_types=(tf.string),  # Define type and shape of your generator_fn output
    output_shapes=()))         # like you would have for your `placeholders`

# In[14]:

for el in dataset:
    print(el)

# ## `tf.data`: Dataset Transforms

# ### Shuffle
#
# Note: the `buffer_size` is the number of elements you load in RAM before starting to sample from it. If it's too small (a `buffer_size` of 1 means no shuffling at all), the shuffling won't be effective. Ideally, your `buffer_size` is the same as the number of elements in your dataset, but because not all datasets fit in RAM, you need to be able to set it manually.

# In[15]:

dataset = dataset.shuffle(buffer_size=10)

# In[16]:

for el in dataset:
    print(el)

# ### Repeat
#
# Repeat your dataset to perform multiple epochs!

# In[17]:

dataset = dataset.repeat(2)  # 2 epochs

# In[18]:

for el in dataset:
    print(el)

# ### Map
#
# Note: while `map` is super handy when working with images (TensorFlow has a lot of image preprocessing functions and efficiency is crucial), it's not as practical for NLP, because you're now working with tensors. We found it easier to write most of the preprocessing logic in Python, in a `generator_fn`, before feeding it to the graph.

# In[19]:

dataset = dataset.map(
    lambda t: tf.string_split([t], delimiter=' ').values,
    num_parallel_calls=4)  # Multithreading

# In[20]:

for el in dataset:
    print(el)

# ### Batch

# In[21]:

dataset = dataset.batch(batch_size=3)

# In[22]:

for el in dataset:
    print(el)

# ### Padded batch
#
# In NLP, we usually work with sentences of different lengths. When building a batch, we need to 'pad', i.e. add some fake elements at the end of the shorter sentences. You can perform this operation easily in TensorFlow.
#
# Here is a dummy example:

# In[23]:

def generator_fn():
    yield [1, 2]
    yield [1, 2, 3]

dataset = tf.data.Dataset.from_generator(
    generator_fn,
    output_types=(tf.int32),
    output_shapes=([None]))

# In[24]:

dataset = dataset.padded_batch(
    batch_size=2,
    padded_shapes=([None]),
    padding_values=(4))  # Optional, if not set will default to 0

# In[25]:

for el in dataset:
    print(el)

# Notice that a `4` has been appended at the end of the first row.
#
# And much more: `prefetch`, `zip`, `concatenate`, `skip`, `take` etc.
# See the [documentation](https://www.tensorflow.org/api_docs/python/tf/data/Dataset).
#
# Note: the recommended standard workflow is
#
# 1. `shuffle`
# 2. `repeat` (repeat after shuffle so that one epoch = all the examples)
# 3. `map`, using the `num_parallel_calls` argument to get multithreading for free
# 4. `batch` or `padded_batch`
# 5. `prefetch` (prefetches the next batches so that the graph never waits for data – without it, data starvation may leave your expensive GPU at only, say, 80% utilization)
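# Putting the recommended workflow together on the text-file dataset – a minimal sketch, assuming a whitespace-tokenizing `map`, small illustrative buffer and batch sizes, and a `b'<pad>'` padding value:
#
# ```python
# dataset = tf.data.TextLineDataset(['test.txt'])
# dataset = (dataset
#            .shuffle(buffer_size=10)     # 1. shuffle
#            .repeat(2)                   # 2. repeat (2 epochs)
#            .map(lambda t: tf.string_split([t], delimiter=' ').values,
#                 num_parallel_calls=4)   # 3. map, multithreaded
#            .padded_batch(batch_size=2,  # 4. padded_batch
#                          padded_shapes=[None],
#                          padding_values=b'<pad>')
#            .prefetch(1))                # 5. prefetch
# for el in dataset:
#     print(el)
# ```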
# ## NLP: preprocessing in Tensorflow

# ### Tokenizing by white space in TensorFlow
#
# This is an example of why using `map` is kind of annoying in NLP. It works, but it's not as easy as just using `.split()` or any other Python code.

# In[26]:

def tf_tokenize(t):
    return tf.string_split([t], delimiter=' ').values

# In[27]:

dataset = tf.data.TextLineDataset(['test.txt'])
dataset = dataset.map(tf_tokenize)
for el in dataset:
    print(el)

# ### Lookup token index from vocab file in TensorFlow
#
# You're probably used to performing the lookup `token -> token_idx` outside TensorFlow. However, `tf.contrib.lookup` provides exactly this functionality. It's fast, and when exporting the model for serving, it will treat your `vocab.txt` file as a model resource and keep it with the model!

# In[28]:

# One lexeme per line
path_vocab = 'vocab.txt'
with Path(path_vocab).open() as f:
    for idx, line in enumerate(f):
        print(idx, ' -> ', line.strip())

# To use it in TensorFlow:

# In[29]:

# The last idx (2) will be used for unknown words
lookup_table = tf.contrib.lookup.index_table_from_file(
    path_vocab, num_oov_buckets=1)

# In[30]:

for el in dataset:
    print(lookup_table.lookup(el))
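# Putting tokenization and lookup together – a minimal sketch (reusing `tf_tokenize` and `lookup_table` defined above, under eager execution as in this notebook) of a pipeline going from raw lines to token ids entirely inside TensorFlow:
#
# ```python
# dataset = tf.data.TextLineDataset(['test.txt'])
# dataset = dataset.map(tf_tokenize)
# dataset = dataset.map(lambda tokens: lookup_table.lookup(tokens))
# for el in dataset:
#     print(el)  # one vector of tf.int64 token ids per line
# ```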
# ## Full Example
#
# ### Task and Data
#
# The `tokens_generator` returns lists of token ids. We map even and odd numbers to 2 different ids.
#
# The `labels_generator` returns lists of label ids. We want to predict whether a token is
# - a word (label `0`)
# - an odd number (label `1`)
# - an even number (label `2`)

# In[31]:

# We tokenize by white space and assign these ids
# '<odd>' and '<even>' are special tokens standing in for odd/even numbers
tok_to_idx = {'hello': 0, 'world': 1, '<odd>': 2, '<even>': 3}

def tokens_generator():
    with Path(path).open() as f:
        for line in f:
            # Tokenize by white space
            tokens = line.strip().split()
            token_ids = []
            for tok in tokens:
                # Look for digits
                if tok.isdigit():
                    if int(tok) % 2 == 0:
                        tok = '<even>'
                    else:
                        tok = '<odd>'
                token_ids.append(tok_to_idx.get(tok.lower(), len(tok_to_idx)))
            yield (token_ids, len(token_ids))

def get_label(token_id):
    if token_id == 2:
        return 1
    elif token_id == 3:
        return 2
    else:
        return 0

def labels_generator():
    for token_ids, _ in tokens_generator():
        yield [get_label(tok_id) for tok_id in token_ids]

# In[32]:

dataset = tf.data.Dataset.from_generator(
    tokens_generator,
    output_types=(tf.int32, tf.int32),
    output_shapes=([None], ()))

for el in dataset:
    print(el)

# ### Graph (test with eager execution)
#
# Let's build a model that predicts the classes `0`, `1` and `2` above.
#
# Test our graph logic here, with `eager_execution` activated.

# In[33]:

batch_size = 4
vocab_size = 4
dim = 100

# In[34]:

shapes = ([None], ())
defaults = (0, 0)

# The last sentence is longer: we need padding
dataset = dataset.padded_batch(batch_size, shapes, defaults)

# In[35]:

# Define all variables once (in eager execution mode this has to be done just once).
# Otherwise you would create new variables at each loop iteration!
embeddings = tf.get_variable('embeddings', shape=[vocab_size, dim])
lstm_cell = tf.contrib.rnn.LSTMCell(100)
dense_layer = tf.layers.Dense(3, activation=tf.nn.relu)

# In[36]:

for tokens, sequence_length in dataset:
    token_embeddings = tf.nn.embedding_lookup(embeddings, tokens)
    lstm_output, _ = tf.nn.dynamic_rnn(
        lstm_cell, token_embeddings, dtype=tf.float32,
        sequence_length=sequence_length)
    logits = dense_layer(lstm_output)
    print(logits.shape)

# No errors or cryptic messages about some shape mismatch – it seems like our TensorFlow logic is fine.

# ### Model (`tf.estimator`)
#
# `tf.estimator` uses the traditional graph-based environment (no eager execution).
#
# If you use the `tf.estimator` interface, you will get for free:
#
# 1. TensorBoard
# 2. Weights serialization
# 3. Logging
# 4. Model export for serving
# 5. A unified structure compatible with open-source code
#
# #### Before: custom model classes
#
# People used to write custom model classes
#
# ```python
# class Model:
#
#     def get_feed_dict(self, X, y):
#         return {self.X: X, self.y: y}
#
#     def build(self):
#         do_stuff()
#
#     def train(self, X, y):
#         with tf.Session() as sess:
#             do_some_training()
# ```
#
# #### Now: `tf.estimator`
#
# Now there is a common interface for models.
#
# ```python
# def input_fn():
#     # Return a tf.data.Dataset that yields a tuple features, labels
#     return dataset
#
# def model_fn(features, labels, mode, params):
#     """
#     Parameters
#     ----------
#     features: tf.Tensor or nested structure
#         Returned by `input_fn`
#     labels: tf.Tensor or nested structure
#         Returned by `input_fn`
#     mode: tf.estimator.ModeKeys
#         Either PREDICT / EVAL / TRAIN
#     params: dict
#         Hyperparams
#
#     Returns
#     -------
#     tf.estimator.EstimatorSpec
#     """
#     if mode == tf.estimator.ModeKeys.TRAIN:
#         return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
#     elif mode == ...
#         ...
#
# estimator = tf.estimator.Estimator(
#     model_fn=model_fn, params=params)
# estimator.train(input_fn)
# ```

# In[37]:

# Clear all the objects we defined above, to be sure
# we don't mess with anything
tf.reset_default_graph()

# #### `input_fn`
#
# A callable that returns a dataset yielding tuples of (features, labels).

# In[38]:

def input_fn():
    # Create datasets for features and labels
    dataset_tokens = tf.data.Dataset.from_generator(
        tokens_generator,
        output_types=(tf.int32, tf.int32),
        output_shapes=([None], ()))
    dataset_output = tf.data.Dataset.from_generator(
        labels_generator,
        output_types=(tf.int32),
        output_shapes=([None]))

    # Zip features and labels into one Dataset
    dataset = tf.data.Dataset.zip((dataset_tokens, dataset_output))

    # Shuffle, repeat, batch and prefetch
    shapes = (([None], ()), [None])
    defaults = ((0, 0), 0)
    dataset = (dataset
               .shuffle(10)
               .repeat(100)
               .padded_batch(4, shapes, defaults)
               .prefetch(1))

    # The dataset yields tuples of (features, labels)
    return dataset
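# Before writing the `model_fn`, you can sanity-check the structure returned by `input_fn` under eager execution – a minimal sketch that just pulls a single batch:
#
# ```python
# for (tokens, sequence_length), labels in input_fn():
#     print(tokens.shape, sequence_length.shape, labels.shape)
#     break
# ```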
# #### `model_fn`
#
# Takes `(features, labels, mode, params)` as inputs and returns an `EstimatorSpec`.

# In[39]:

def model_fn(features, labels, mode, params):
    # The features and labels arguments are the ones returned by the dataset.
    # For serving, features come in as a dict (ignore this, see serving_input_fn below)
    if isinstance(features, dict):
        tokens = features['tokens']
        sequence_length = features['sequence_length']
    else:
        tokens, sequence_length = features

    # 1. Define the graph
    vocab_size = params['vocab_size']
    dim = params['dim']
    embeddings = tf.get_variable('embeddings', shape=[vocab_size, dim])
    token_embeddings = tf.nn.embedding_lookup(embeddings, tokens)
    lstm_cell = tf.contrib.rnn.LSTMCell(20)
    lstm_output, _ = tf.nn.dynamic_rnn(
        lstm_cell, token_embeddings, dtype=tf.float32)
    logits = tf.layers.dense(lstm_output, 3)
    preds = tf.argmax(logits, axis=-1)

    # 2. Define the EstimatorSpec for PREDICT
    if mode == tf.estimator.ModeKeys.PREDICT:
        # Predictions is any nested object (a dict is convenient)
        predictions = {'logits': logits, 'preds': preds}
        # export_outputs is for serving (ignore this)
        export_outputs = {
            'predictions': tf.estimator.export.PredictOutput(predictions)}
        return tf.estimator.EstimatorSpec(mode, predictions=predictions,
                                          export_outputs=export_outputs)
    else:
        # 3. Define the loss and metrics
        # Define weights to account for padding
        weights = tf.sequence_mask(sequence_length)
        loss = tf.losses.sparse_softmax_cross_entropy(
            logits=logits, labels=labels, weights=weights)
        metrics = {
            'accuracy': tf.metrics.accuracy(labels=labels, predictions=preds),
        }
        # For TensorBoard
        for k, v in metrics.items():
            # v[1] is the update op of the metrics object
            tf.summary.scalar(k, v[1])

        # 4. Define the EstimatorSpec for EVAL
        # Having an eval mode and metrics in TensorFlow allows you to use
        # built-in early stopping (see later)
        if mode == tf.estimator.ModeKeys.EVAL:
            return tf.estimator.EstimatorSpec(mode, loss=loss,
                                              eval_metric_ops=metrics)

        # 5. Define the EstimatorSpec for TRAIN
        elif mode == tf.estimator.ModeKeys.TRAIN:
            global_step = tf.train.get_or_create_global_step()
            train_op = (tf.train.AdamOptimizer(learning_rate=0.1)
                        .minimize(loss, global_step=global_step))
            return tf.estimator.EstimatorSpec(mode, loss=loss,
                                              train_op=train_op)

# What do you think about this `model_fn`? It seems like we only wrote the things that matter (not a lot of boilerplate!).
#
# #### Instantiate and train your Estimator
#
# Now, let's define our estimator and train it!

# In[40]:

params = {
    'vocab_size': 4,
    'dim': 3
}
estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir='model',  # Will save the weights here automatically
    params=params)

# In[41]:

estimator.train(input_fn)

# #### TensorBoard, `train_and_evaluate`, `predict` etc.
#
# Now the `estimator` is trained, serialized to disk etc. You also have access to TensorBoard. (Lots of stuff for free, without having to write any boilerplate code!)
#
# To access TensorBoard:
#
# ```
# tensorboard --logdir model
# ```
#
# Check out `evaluate`, `train_and_evaluate` and friends in the [documentation](https://www.tensorflow.org/api_docs/python/tf/estimator/train_and_evaluate).
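# For instance, evaluating the trained estimator is a one-liner – a minimal sketch (it reuses the training `input_fn` here; in practice you would pass a dedicated evaluation `input_fn`):
#
# ```python
# metrics = estimator.evaluate(input_fn)
# print(metrics)  # e.g. contains 'accuracy', 'loss' and 'global_step'
# ```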
# Here is an example with early stopping, where we run evaluation every 2 minutes (`120` seconds):
#
# ```python
# hook = tf.contrib.estimator.stop_if_no_increase_hook(
#     estimator, 'accuracy', 500, min_steps=8000, run_every_secs=120)
# train_spec = tf.estimator.TrainSpec(input_fn=input_fn, hooks=[hook])
# eval_spec = tf.estimator.EvalSpec(input_fn=input_fn, throttle_secs=120)
# tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
# ```

# In[42]:

# Iterate over the first 2 elements of the (shuffled) dataset and yield predictions
# You need to write variants of your input_fn for eval / predict modes
for idx, predictions in enumerate(estimator.predict(input_fn)):
    print(predictions['preds'])
    if idx > 0:
        break

# ## A word about TensorFlow model serving

# Exporting an inference graph and the serving signature is "simple" (though the `serving_fn` interface could be improved). The cool thing is that once you have your `tf.estimator` and your `serving_input_fn`, you can just use `tensorflow_serving` and get a RESTful API serving your model!

# ### Serving interface

# In[43]:

def serving_input_fn():
    tokens = tf.placeholder(
        dtype=tf.int32, shape=[None, None], name="tokens")
    sequence_length = tf.size(tokens)
    features = {'tokens': tokens, 'sequence_length': sequence_length}
    return tf.estimator.export.ServingInputReceiver(
        features=features, receiver_tensors=tokens)

# In[44]:

estimator.export_savedmodel('export', serving_input_fn)

# ### Docker Image
#
# #### Pull existing image
#
# ```
# docker pull tensorflow/serving
# ```
#
# #### Run
#
# ```
# docker run -p 8501:8501 \
# --mount type=bind,\
# source=path_to_your_export_model,\
# target=/models/dummy \
# -e MODEL_NAME=dummy -t tensorflow/serving &
# ```
#
# #### Rest API POST with curl
#
# ```
# curl -d '{"instances": [[0, 1, 2],[0, 1, 3]]}' -X POST \
#     http://localhost:8501/v1/models/dummy:predict
# ```
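# The same request from Python – a minimal sketch, assuming the serving container above is running on `localhost:8501` with `MODEL_NAME=dummy`, and that the `requests` package is installed:
#
# ```python
# import requests
#
# resp = requests.post(
#     'http://localhost:8501/v1/models/dummy:predict',
#     json={'instances': [[0, 1, 2], [0, 1, 3]]})
# print(resp.json())
# ```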