Multi-task learning TensorFlow with the Head API

This notebook is a companion to the following blog post: https://iaml.it/blog/multitask-learning-tensorflow/

Prerequisites

In [1]:
# Install all necessary packages
!pip install tensorflow tqdm requests --upgrade
Collecting tensorflow
  Downloading tensorflow-1.7.0-cp36-cp36m-manylinux1_x86_64.whl (48.0MB)
    100% |████████████████████████████████| 48.0MB 29kB/s 
Collecting tqdm
  Downloading tqdm-4.21.0-py2.py3-none-any.whl (42kB)
    100% |████████████████████████████████| 51kB 10.9MB/s 
Requirement already up-to-date: requests in /usr/local/lib/python3.6/dist-packages
Requirement already up-to-date: astor>=0.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow)
Requirement already up-to-date: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow)
Requirement already up-to-date: absl-py>=0.1.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow)
Collecting grpcio>=1.8.6 (from tensorflow)
  Downloading grpcio-1.10.1-cp36-cp36m-manylinux1_x86_64.whl (7.7MB)
    100% |████████████████████████████████| 7.7MB 182kB/s 
Requirement already up-to-date: termcolor>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow)
Collecting tensorboard<1.8.0,>=1.7.0 (from tensorflow)
  Downloading tensorboard-1.7.0-py3-none-any.whl (3.1MB)
    100% |████████████████████████████████| 3.1MB 441kB/s 
Collecting wheel>=0.26 (from tensorflow)
  Downloading wheel-0.31.0-py2.py3-none-any.whl (41kB)
    100% |████████████████████████████████| 51kB 11.1MB/s 
Collecting protobuf>=3.4.0 (from tensorflow)
  Downloading protobuf-3.5.2.post1-cp36-cp36m-manylinux1_x86_64.whl (6.4MB)
    100% |████████████████████████████████| 6.4MB 207kB/s 
Requirement already up-to-date: gast>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow)
Requirement already up-to-date: numpy>=1.13.3 in /usr/local/lib/python3.6/dist-packages (from tensorflow)
Requirement already up-to-date: idna<2.7,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests)
Requirement already up-to-date: urllib3<1.23,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests)
Requirement already up-to-date: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests)
Requirement already up-to-date: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests)
Requirement already up-to-date: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.8.0,>=1.7.0->tensorflow)
Requirement already up-to-date: html5lib==0.9999999 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.8.0,>=1.7.0->tensorflow)
Requirement already up-to-date: werkzeug>=0.11.10 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.8.0,>=1.7.0->tensorflow)
Requirement already up-to-date: bleach==1.5.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.8.0,>=1.7.0->tensorflow)
Collecting setuptools (from protobuf>=3.4.0->tensorflow)
  Downloading setuptools-39.0.1-py2.py3-none-any.whl (569kB)
    100% |████████████████████████████████| 573kB 2.3MB/s 
Installing collected packages: setuptools, protobuf, grpcio, wheel, tensorboard, tensorflow, tqdm
  Found existing installation: setuptools 36.2.7
    Not uninstalling setuptools at /usr/lib/python3/dist-packages, outside environment /usr
  Found existing installation: protobuf 3.5.2
    Uninstalling protobuf-3.5.2:
      Successfully uninstalled protobuf-3.5.2
  Found existing installation: grpcio 1.10.0
    Uninstalling grpcio-1.10.0:
      Successfully uninstalled grpcio-1.10.0
  Found existing installation: wheel 0.30.0
    Uninstalling wheel-0.30.0:
      Successfully uninstalled wheel-0.30.0
  Found existing installation: tensorboard 1.6.0
    Uninstalling tensorboard-1.6.0:
      Successfully uninstalled tensorboard-1.6.0
  Found existing installation: tensorflow 1.6.0
    Uninstalling tensorflow-1.6.0:
      Successfully uninstalled tensorflow-1.6.0
Successfully installed grpcio-1.10.1 protobuf-3.5.2.post1 setuptools-39.0.1 tensorboard-1.7.0 tensorflow-1.7.0 tqdm-4.21.0 wheel-0.31.0

Data download

In [2]:
# Dataset is taken from here: http://mmlab.ie.cuhk.edu.hk/projects/TCDCN.html
# Small code is taken from this StackOverflow thread: https://stackoverflow.com/questions/22676/how-do-i-download-a-file-over-http-using-python
# This could take a while!

from tqdm import tqdm
import requests

url = "http://mmlab.ie.cuhk.edu.hk/projects/TCDCN/data/MTFL.zip"
response = requests.get(url, stream=True)

with open("MTFL", "wb") as handle:
    for data in tqdm(response.iter_content(), unit=' KB'):
        handle.write(data)
150152960 KB [25:41, 97417.50 KB/s]
In [0]:
# Unzip all files
import zipfile
zip_ref = zipfile.ZipFile('MTFL', 'r')
zip_ref.extractall()
zip_ref.close()
In [4]:
!ls
AFLW  datalab  lfw_5590  MTFL  net_7876  readme.txt  testing.txt  training.txt

Data loading in Pandas

In [0]:
# Import dataset in Pandas
import pandas as pd
train_data = pd.read_csv('training.txt', sep=' ', header=None, skipinitialspace=True, nrows=10000)
test_data = pd.read_csv('testing.txt', sep=' ', header=None, skipinitialspace=True, nrows=2995)
In [7]:
train_data.iloc[0]
Out[7]:
0     lfw_5590\Aaron_Eckhart_0001.jpg
1                              107.25
2                              147.75
3                              126.25
4                              106.25
5                              140.75
6                              108.75
7                              113.25
8                              143.75
9                              158.75
10                             162.75
11                                  1
12                                  2
13                                  2
14                                  3
Name: 0, dtype: object
In [0]:
train_data.iloc[:, 0] = train_data.iloc[:, 0].apply(lambda s: s.replace('\\', '/')) # Needed for filename convention
test_data.iloc[:, 0] = test_data.iloc[:, 0].apply(lambda s: s.replace('\\', '/')) # Needed for filename convention
In [0]:
from sklearn import preprocessing
train_data.iloc[:, 1:11] = preprocessing.MinMaxScaler().fit_transform(train_data.iloc[:, 1:11])
test_data.iloc[:, 1:11] = preprocessing.MinMaxScaler().fit_transform(test_data.iloc[:, 1:11])

Data loading with tf.data

In [0]:
import numpy as np
import tensorflow as tf
In [23]:
# Example code for handling datasets

filenames = tf.constant(train_data.iloc[:, 0].tolist())
labels = tf.constant(train_data.iloc[:, 1:].values)

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))

it = dataset.batch(64).make_one_shot_iterator().get_next()

with tf.Session() as sess:
  (imgs, labels) = sess.run(it)
  print(imgs[0])
b'lfw_5590/Aaron_Eckhart_0001.jpg'
In [0]:
# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def _parse_function(filename, label):
  image_string = tf.read_file(filename) 
  image_decoded = tf.image.decode_jpeg(image_string, channels=3) # Channels needed because some test images are b/w
  image_resized = tf.image.resize_images(image_decoded, [40, 40])
  return {"x": image_resized}, label
In [0]:
# This snippet is adapted from here: https://www.tensorflow.org/programmers_guide/datasets

def input_fn(data, is_eval=False):

  # Path delle immagini
  filenames = tf.constant(data.iloc[:, 0].tolist())

  # Etichette delle immagini
  labels = tf.constant(data.iloc[:, 1:].values.astype(np.float32))

  # Costruisco il dataset
  dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
  dataset = dataset.map(_parse_function)

  # Logica di training / testing
  if is_eval:
    dataset = dataset.batch(64)
  else:
    dataset = dataset.repeat().shuffle(1000).batch(64)
    
  # Costruisco l'iteratore
  return dataset.make_one_shot_iterator().get_next()
  
  #for (filename, label) in tfe.Iterator(dataset):
  #  d = _parse_function(filename, label)
In [63]:
import matplotlib.pyplot as plt
with tf.Session() as sess:
  (imgs, labels) = sess.run(input_fn(train_data, True))
  plt.imshow(imgs["x"][0] / 255)
  print(labels[0])
[0.33482143 0.32603687 0.3471564  0.3612805  0.2852697  0.4357639
 0.47532895 0.41169155 0.35       0.36334747 1.         2.
 2.         3.        ]

Standard classical estimator (single-task only!)

In [0]:
def extract_features(features):
  # Input Layer
  input_layer = tf.reshape(features["x"], [-1, 40, 40, 3])

  # Primo layer convolutivo
  conv1 = tf.layers.conv2d(inputs=input_layer, filters=16, kernel_size=[5, 5], padding="same", activation=tf.nn.relu)
  pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

  # Secondo layer convolutivo
  conv2 = tf.layers.conv2d(inputs=pool1, filters=48, kernel_size=[3, 3], padding="same", activation=tf.nn.relu)
  pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

  # Terzo layer convolutivo
  conv3 = tf.layers.conv2d(inputs=pool2, filters=64, kernel_size=[3, 3], padding="same", activation=tf.nn.relu)
  pool3 = tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2)
  
  # Quarto layer convolutivo
  conv4 = tf.layers.conv2d(inputs=pool3, filters=64, kernel_size=[2, 2], padding="same", activation=tf.nn.relu)
  
  # Dense Layer
  flat = tf.reshape(conv4, [-1, 5 * 5 * 64])
  dense = tf.layers.dense(inputs=flat, units=100, activation=tf.nn.relu)
  
  return dense
In [0]:
# Adapted from here: https://www.tensorflow.org/tutorials/layers

def single_task_cnn_model_fn(features, labels, mode):
  
  dense = extract_features(features)
  
  # Predizioni
  predictions = tf.layers.dense(inputs=dense, units=2)

  outputs = {
      "predictions": predictions
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=outputs)

  # Funzione costo (errore quadratico medio)
  loss = tf.losses.mean_squared_error(labels=labels[:, 2:8:5], predictions=predictions)

  # Ottimizzazione
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdamOptimizer()
    train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

  # Valutazione del modello
  eval_metric_ops = {
      "rmse": tf.metrics.root_mean_squared_error(
          labels=labels[:, 2:8:5], predictions=outputs["predictions"])}
  return tf.estimator.EstimatorSpec(
      mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
In [189]:
# Create the Estimator
single_task_classifier = tf.estimator.Estimator(
    model_fn=single_task_cnn_model_fn, model_dir="/tmp/cnn_nose")
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/cnn_nose', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fadecf4d278>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
In [190]:
# Train the model
single_task_classifier.train(input_fn=lambda: input_fn(train_data), steps=1000)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 1 into /tmp/cnn_nose/model.ckpt.
INFO:tensorflow:loss = 168.66652, step = 1
INFO:tensorflow:global_step/sec: 3.21662
INFO:tensorflow:loss = 0.07332276, step = 101 (31.090 sec)
INFO:tensorflow:global_step/sec: 3.4082
INFO:tensorflow:loss = 0.029060293, step = 201 (29.342 sec)
INFO:tensorflow:global_step/sec: 3.24658
INFO:tensorflow:loss = 0.04251621, step = 301 (30.803 sec)
INFO:tensorflow:global_step/sec: 3.27862
INFO:tensorflow:loss = 0.042185806, step = 401 (30.498 sec)
INFO:tensorflow:global_step/sec: 3.36051
INFO:tensorflow:loss = 0.017348655, step = 501 (29.760 sec)
INFO:tensorflow:global_step/sec: 3.23404
INFO:tensorflow:loss = 0.04518965, step = 601 (30.920 sec)
INFO:tensorflow:global_step/sec: 3.31534
INFO:tensorflow:loss = 0.043051306, step = 701 (30.163 sec)
INFO:tensorflow:global_step/sec: 3.31508
INFO:tensorflow:loss = 0.014098421, step = 801 (30.163 sec)
INFO:tensorflow:global_step/sec: 3.2097
INFO:tensorflow:loss = 0.051925942, step = 901 (31.157 sec)
INFO:tensorflow:Saving checkpoints for 1000 into /tmp/cnn_nose/model.ckpt.
INFO:tensorflow:Loss for final step: 0.039091345.
Out[190]:
<tensorflow.python.estimator.estimator.Estimator at 0x7fade6fb1978>
In [191]:
single_task_classifier.evaluate(input_fn=lambda: input_fn(test_data, is_eval=True))
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2018-04-10-13:32:48
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/cnn_nose/model.ckpt-1000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Finished evaluation at 2018-04-10-13:32:54
INFO:tensorflow:Saving dict for global step 1000: global_step = 1000, loss = 0.031888474, rmse = 0.17859857
Out[191]:
{'global_step': 1000, 'loss': 0.031888474, 'rmse': 0.17859857}
In [0]:
# Funzione di input per predizioni (hack, si potrebbe migliorare!)
def input_fn_predict(data):

  # Path delle immagini
  filenames = tf.constant(data.iloc[:, 0].tolist())

  # Etichette delle immagini
  labels = tf.constant(data.iloc[:, 1:].values)

  # Costruisco il dataset
  dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
  dataset = dataset.map(_parse_function)
  dataset = dataset.batch(64)
    
  # Costruisco l'iteratore
  f, _ = dataset.make_one_shot_iterator().get_next()
  return f
  
  #for (filename, label) in tfe.Iterator(dataset):
  #  d = _parse_function(filename, label)
In [198]:
# Valutiamo una singola predizione
p = list(single_task_classifier.predict(lambda: input_fn_predict(test_data)))

with tf.Session() as sess:
  imgs = sess.run(input_fn_predict(test_data))
  img_idx = 2
  plt.imshow(imgs["x"][img_idx] / 255)
  plt.scatter(p[img_idx]['predictions'][0] * 40, p[img_idx]['predictions'][1] * 40, 500, marker='x', color='red', linewidth=5)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/cnn_nose/model.ckpt-1000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.

Simplify the code with the Head API

In [0]:
# Check the code here: https://www.tensorflow.org/api_docs/python/tf/contrib/estimator/regression_head

def single_head_cnn_model_fn(features, labels, mode):
  
  dense = extract_features(features)
  
  # Predizioni
  predictions = tf.layers.dense(inputs=dense, units=2)

  # Ottimizzatore
  optimizer = tf.train.AdamOptimizer()
  
  # Modello finale
  regression_head = tf.contrib.estimator.regression_head(label_dimension=2)
  return regression_head.create_estimator_spec(features, mode, predictions, labels[:, 2:8:5], lambda x: optimizer.minimize(x, global_step = tf.train.get_or_create_global_step()))
In [82]:
# Create the Estimator
cnn_classifier = tf.estimator.Estimator(
    model_fn=single_head_cnn_model_fn, model_dir="/tmp/cnn_single_head")
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/cnn_single_head', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fae002b7dd8>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
In [83]:
# Train the model
cnn_classifier.train(input_fn=lambda: input_fn(train_data), steps=1000)
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 1 into /tmp/cnn_single_head/model.ckpt.
INFO:tensorflow:loss = 19280.736, step = 1
INFO:tensorflow:Saving checkpoints for 100 into /tmp/cnn_single_head/model.ckpt.
INFO:tensorflow:Loss for final step: 14.961566.
Out[83]:
<tensorflow.python.estimator.estimator.Estimator at 0x7fae007316a0>

Multi-task learning with the Head API

In [0]:
def multihead_input_fn(data, is_eval=False):
  features, labels = input_fn(data, is_eval=is_eval)
  return features, {'head_nose': labels[:, 2:8:5], 'head_pose': tf.cast(labels[:, -1] - 1.0, tf.int32)}
In [0]:
def multi_head_cnn_model_fn(features, labels, mode):
  
  dense = extract_features(features)
  
  # Predizioni della rete (per ciascun task)
  predictions_nose = tf.layers.dense(inputs=dense, units=2)
  predictions_pose = tf.layers.dense(inputs=dense, units=5)
  logits = {'head_nose': predictions_nose, 'head_pose': predictions_pose}
  
  # Ottimizzatore
  optimizer = tf.train.AdamOptimizer()
  
  # Definiamo le due head
  regression_head = tf.contrib.estimator.regression_head(name='head_nose', label_dimension=2)
  classification_head = tf.contrib.estimator.multi_class_head(name='head_pose', n_classes=5)
  
  multi_head = tf.contrib.estimator.multi_head([regression_head, classification_head])
  
  return multi_head.create_estimator_spec(features, mode, logits, labels, lambda x: optimizer.minimize(x, global_step = tf.train.get_or_create_global_step()))
In [212]:
# Create the Estimator
multitask_classifier = tf.estimator.Estimator(
    model_fn=multi_head_cnn_model_fn, model_dir="/tmp/cnn_tmp")
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/cnn_tmp', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7faded15b1d0>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
In [213]:
# Train the model
multitask_classifier.train(input_fn=lambda: multihead_input_fn(train_data), steps=1000)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 1 into /tmp/cnn_tmp/model.ckpt.
INFO:tensorflow:loss = 116749.75, step = 1
INFO:tensorflow:global_step/sec: 3.18767
INFO:tensorflow:loss = 77.51071, step = 101 (31.373 sec)
INFO:tensorflow:global_step/sec: 3.34664
INFO:tensorflow:loss = 51.99946, step = 201 (29.883 sec)
INFO:tensorflow:global_step/sec: 3.1784
INFO:tensorflow:loss = 89.16292, step = 301 (31.460 sec)
INFO:tensorflow:global_step/sec: 3.21144
INFO:tensorflow:loss = 40.797005, step = 401 (31.138 sec)
INFO:tensorflow:global_step/sec: 3.3237
INFO:tensorflow:loss = 39.19834, step = 501 (30.088 sec)
INFO:tensorflow:global_step/sec: 3.19385
INFO:tensorflow:loss = 62.40456, step = 601 (31.310 sec)
INFO:tensorflow:global_step/sec: 3.27872
INFO:tensorflow:loss = 44.629475, step = 701 (30.499 sec)
INFO:tensorflow:global_step/sec: 3.28512
INFO:tensorflow:loss = 23.16358, step = 801 (30.443 sec)
INFO:tensorflow:global_step/sec: 3.21363
INFO:tensorflow:loss = 47.13872, step = 901 (31.115 sec)
INFO:tensorflow:Saving checkpoints for 1000 into /tmp/cnn_tmp/model.ckpt.
INFO:tensorflow:Loss for final step: 29.839733.
Out[213]:
<tensorflow.python.estimator.estimator.Estimator at 0x7faded15bb70>
In [214]:
  multitask_classifier.evaluate(input_fn=lambda: multihead_input_fn(test_data, is_eval=True))
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2018-04-10-13:45:26
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/cnn_tmp/model.ckpt-1000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Finished evaluation at 2018-04-10-13:45:32
INFO:tensorflow:Saving dict for global step 1000: accuracy/head_pose = 0.6290484, average_loss/head_nose = 0.0408225, average_loss/head_pose = 1.1491545, global_step = 1000, loss = 78.43075, loss/head_nose = 5.2026973, loss/head_pose = 73.22804
Out[214]:
{'accuracy/head_pose': 0.6290484,
 'average_loss/head_nose': 0.0408225,
 'average_loss/head_pose': 1.1491545,
 'global_step': 1000,
 'loss': 78.43075,
 'loss/head_nose': 5.2026973,
 'loss/head_pose': 73.22804}
In [216]:
p = list(multitask_classifier.predict(lambda: input_fn_predict(test_data)))
print(p[0])
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/cnn_tmp/model.ckpt-1000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
{('head_nose', 'predictions'): array([0.3372714 , 0.45653588], dtype=float32), ('head_pose', 'logits'): array([-2.9481742 ,  1.0806224 ,  2.5040638 ,  0.86770856, -1.2734869 ],
      dtype=float32), ('head_pose', 'probabilities'): array([0.0029306 , 0.16468   , 0.6836497 , 0.13309863, 0.01564099],
      dtype=float32), ('head_pose', 'class_ids'): array([2]), ('head_pose', 'classes'): array([b'2'], dtype=object)}
In [227]:
with tf.Session() as sess:
  imgs = sess.run(input_fn_predict(test_data))
  
  font = {'family': 'serif',
        'color':  'white',
        'weight': 'bold',
        'size': 16,
        }

  img_idx = 8

  prediction_eye = p[img_idx][(('head_nose', 'predictions'))]
  prediction_pose = p[img_idx][(('head_pose', 'class_ids'))]
  
  plt.imshow(imgs["x"][img_idx] / 255)
  plt.scatter(prediction_eye[0] * 40, prediction_eye[1] * 40, 500, marker='x', color='red', linewidth=5)
  plt.text(5, 3, 'Predicted pose: {}'.format(prediction_pose), fontdict=font)