Here is a very quick implementation and walkthrough showing how to use TPUs with Keras in Colab.
If you have any questions or suggestions to make it better, please let me know.
import numpy as np
import tensorflow as tf
import time
import os
import tensorflow.keras
from tensorflow.keras.datasets import mnist, fashion_mnist
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Dropout, Flatten, Input
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras import backend as K
print(tf.__version__)
print(tf.keras.__version__)
1.11.0-rc2
2.1.6-tf
First, test whether you have a TPU set up.
Run the cell below.
If no TPU is found, open "Runtime" in the menu at the top, choose "Change runtime type", and set the hardware accelerator to TPU.
The TPU_ADDRESS variable will be needed to pass into the distribution strategy.
try:
    device_name = os.environ['COLAB_TPU_ADDR']
    TPU_ADDRESS = 'grpc://' + device_name
    print('Found TPU at: {}'.format(TPU_ADDRESS))
except KeyError:
    print('TPU not found')
Found TPU at: grpc://10.114.111.10:8470
batch_size = 1024
num_classes = 10
epochs = 5
learning_rate = 0.001
# input image dimensions
img_rows, img_cols = 28, 28
# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
# convert class vectors to binary class matrices
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)
tf.data
You need to make sure you set drop_remainder=True when batching, as TPUs need fixed shapes.
def train_input_fn(batch_size=1024):
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size, drop_remainder=True)
    # Return the dataset.
    return dataset

def test_input_fn(batch_size=1024):
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size, drop_remainder=True)
    # Return the dataset.
    return dataset
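As a quick sanity check (a minimal sketch, not part of the original demo), you can pull one batch in TF 1.x graph mode and confirm the fixed shape that drop_remainder=True guarantees:
# Peek at one batch; the static shapes should be fully defined for the TPU.
check_ds = train_input_fn()
features, labels = check_ds.make_one_shot_iterator().get_next()
print(features.shape, labels.shape)  # (1024, 28, 28, 1) (1024, 10)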
You must pass in an input shape and batch size as TPUs (and XLA) require fixed shapes.
The rest of the model is just a simple CNN.
Inp = tf.keras.Input(
    name='input', shape=input_shape, batch_size=batch_size, dtype=tf.float32)
x = Conv2D(32, kernel_size=(3, 3), activation='relu', name='Conv_01')(Inp)
x = MaxPooling2D(pool_size=(2, 2), name='MaxPool_01')(x)
x = Conv2D(64, (3, 3), activation='relu', name='Conv_02')(x)
x = MaxPooling2D(pool_size=(2, 2), name='MaxPool_02')(x)
x = Conv2D(64, (3, 3), activation='relu', name='Conv_03')(x)
x = Flatten(name='Flatten_01')(x)
x = Dense(64, activation='relu', name='Dense_01')(x)
x = Dropout(0.5, name='Dropout_02')(x)
output = Dense(num_classes, activation='softmax', name='Dense_02')(x)
model = tf.keras.Model(inputs=[Inp], outputs=[output])
# Use a tf optimizer rather than a Keras one for now.
opt = tf.train.AdamOptimizer(learning_rate)
model.compile(
    optimizer=opt,
    loss='categorical_crossentropy',
    metrics=['acc'])
tf.contrib.tpu.keras_to_tpu_model will eventually go away, and you will instead pass a distribution strategy into model.compile, but for TF 1.11 this works.
From the logs below we can see this is a TPU v2 with 8 cores.
For batching you want 128 examples per core, so 1024 overall. You could also use 128, 256, 512, etc.
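To make the batch arithmetic concrete, here is a minimal sketch (the core count is taken from the TPU system metadata logged below):
num_cores = 8          # TPU v2, as reported in the logs below
per_core_batch = 128   # suggested batch per core
assert per_core_batch * num_cores == batch_size  # 1024 overall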
tpu_model = tf.contrib.tpu.keras_to_tpu_model(
    model,
    strategy=tf.contrib.tpu.TPUDistributionStrategy(
        tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)))
INFO:tensorflow:Querying Tensorflow master (b'grpc://10.114.111.10:8470') for TPU system metadata.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 973931917537708864)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 8792028991883212283)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_GPU:0, XLA_GPU, 17179869184, 10595085297325393161)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 10139671714968909828)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 10491071598227653110)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 3213028352983874138)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 13713210220232872762)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, 16117693853034682668)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 3592681710289544177)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 12525050454546375568)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 17588780763802917777)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 17179869184, 1127662344348348349)
WARNING:tensorflow:tpu_model (from tensorflow.contrib.tpu.python.tpu.keras_support) is experimental and may change or be removed at any time, and without warning.
INFO:tensorflow:Connecting to: b'grpc://10.114.111.10:8470'
tpu_model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input (InputLayer)           (1024, 28, 28, 1)         0
_________________________________________________________________
Conv_01 (Conv2D)             (1024, 26, 26, 32)        320
_________________________________________________________________
MaxPool_01 (MaxPooling2D)    (1024, 13, 13, 32)        0
_________________________________________________________________
Conv_02 (Conv2D)             (1024, 11, 11, 64)        18496
_________________________________________________________________
MaxPool_02 (MaxPooling2D)    (1024, 5, 5, 64)          0
_________________________________________________________________
Conv_03 (Conv2D)             (1024, 3, 3, 64)          36928
_________________________________________________________________
Flatten_01 (Flatten)         (1024, 576)               0
_________________________________________________________________
Dense_01 (Dense)             (1024, 64)                36928
_________________________________________________________________
Dropout_02 (Dropout)         (1024, 64)                0
_________________________________________________________________
Dense_02 (Dense)             (1024, 10)                650
=================================================================
Total params: 93,322
Trainable params: 93,322
Non-trainable params: 0
_________________________________________________________________
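As a check on the Param # column, the counts follow the usual kernel_h * kernel_w * in_channels * out_channels + biases arithmetic, for example:
# Conv_01:  3*3*1*32  + 32 = 320
# Conv_02:  3*3*32*64 + 64 = 18,496
# Dense_01: 576*64    + 64 = 36,928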
tf.data pipeline
Obviously training MNIST on a TPU is a bit of overkill, and the TPU barely gets a chance to warm up. ^-^
tpu_model.fit(
    train_input_fn,
    steps_per_epoch=60,
    epochs=10,
)
Epoch 1/10
INFO:tensorflow:New input shapes; (re-)compiling: mode=train, [TensorSpec(shape=(1024, 28, 28, 1), dtype=tf.float32, name=None), TensorSpec(shape=(1024, 10), dtype=tf.float32, name=None)]
INFO:tensorflow:Overriding default placeholder.
INFO:tensorflow:Remapping placeholder for input
INFO:tensorflow:Started compiling
INFO:tensorflow:Finished compiling. Time elapsed: 1.6537692546844482 secs
INFO:tensorflow:Setting weights on TPU model.
60/60 [==============================] - 6s 104ms/step - loss: 0.9355 - acc: 0.7056
Epoch 2/10
60/60 [==============================] - 3s 44ms/step - loss: 0.2260 - acc: 0.9349
Epoch 3/10
60/60 [==============================] - 3s 46ms/step - loss: 0.1372 - acc: 0.9606
Epoch 4/10
60/60 [==============================] - 3s 48ms/step - loss: 0.1055 - acc: 0.9702
Epoch 5/10
60/60 [==============================] - 3s 48ms/step - loss: 0.0838 - acc: 0.9760
Epoch 6/10
60/60 [==============================] - 3s 48ms/step - loss: 0.0696 - acc: 0.9799
Epoch 7/10
60/60 [==============================] - 3s 44ms/step - loss: 0.0623 - acc: 0.9820
Epoch 8/10
60/60 [==============================] - 3s 44ms/step - loss: 0.0576 - acc: 0.9838
Epoch 9/10
60/60 [==============================] - 3s 43ms/step - loss: 0.0492 - acc: 0.9858
Epoch 10/10
60/60 [==============================] - 3s 43ms/step - loss: 0.0449 - acc: 0.9875
tpu_model.save_weights('./MNIST_TPU_1024.h5', overwrite=True)
INFO:tensorflow:Copying TPU weights to the CPU
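Since the weights have been copied back to the CPU, one way to reuse them off-TPU (a minimal sketch, not from the original demo) is to load the saved file into the CPU-side model defined earlier; note its input layer expects a fixed batch of 1024:
# Hypothetical follow-up: run a sanity-check prediction on the CPU model.
model.load_weights('./MNIST_TPU_1024.h5')
preds = model.predict(x_test[:batch_size], batch_size=batch_size)
print(preds.shape)  # (1024, 10)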
Evaluate the model.
tpu_model.evaluate(test_input_fn, steps=100)
INFO:tensorflow:New input shapes; (re-)compiling: mode=eval, [TensorSpec(shape=(1024, 28, 28, 1), dtype=tf.float32, name=None), TensorSpec(shape=(1024, 10), dtype=tf.float32, name=None)]
INFO:tensorflow:Overriding default placeholder.
INFO:tensorflow:Remapping placeholder for input
INFO:tensorflow:Started compiling
INFO:tensorflow:Finished compiling. Time elapsed: 0.9941656589508057 secs
100/100 [==============================] - 7s 65ms/step
[0.0268026649922831, 0.991123046875]
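The two numbers are the loss and the accuracy, in the order given by the model's metrics_names; a small sketch to label them (this re-runs the evaluation):
results = tpu_model.evaluate(test_input_fn, steps=100)
print(dict(zip(tpu_model.metrics_names, results)))  # e.g. {'loss': ..., 'acc': ...}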
Not using tf.data is much slower!
Passing the NumPy arrays straight to fit uses Keras's default batch size of 32, which is only 4 examples per TPU core (you can see the (4, 28, 28, 1) shape in the log below), so each step does very little work.
tpu_model.fit(x_train, y_train, epochs=1)
Epoch 1/1
INFO:tensorflow:New input shapes; (re-)compiling: mode=train, [TensorSpec(shape=(4, 28, 28, 1), dtype=tf.float32, name='input0'), TensorSpec(shape=(4, 10), dtype=tf.float32, name='Dense_02_target_10')]
INFO:tensorflow:Overriding default placeholder.
INFO:tensorflow:Remapping placeholder for input
INFO:tensorflow:Started compiling
INFO:tensorflow:Finished compiling. Time elapsed: 1.0541026592254639 secs
60000/60000 [==============================] - 58s 964us/step - loss: 0.0991 - acc: 0.9708
Note:
This notebook was adapted from the Jupyter notebook used for the demo during the talk "Get training in Keras on TPUs for free" at the Singapore TensorFlow and Deep Learning meetup on 2018-09-28 (GMT+8). Thanks to Sam Witteveen.
Slides: https://www.dropbox.com/s/jg7j07unw94wbom/TensorFlow%20Keras%20Colab%20TPUs.pdf?dl=0