#!/usr/bin/env python
# coding: utf-8
# # Introduction to BNNs with Larq
#
#
#
# This tutorial demonstrates how to train a simple binarized Convolutional Neural Network (CNN) to classify MNIST digits. This simple network will achieve approximately 98% accuracy on the MNIST test set. This tutorial uses Larq and the [Keras Sequential API](https://www.tensorflow.org/guide/keras), so creating and training our model will require only a few lines of code.
# In[ ]:
get_ipython().system('pip install larq')
# In[1]:
import tensorflow as tf
import larq as lq
# ### Download and prepare the MNIST dataset
# In[2]:
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))
# Normalize pixel values to be between -1 and 1
train_images, test_images = train_images / 127.5 - 1, test_images / 127.5 - 1
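# A quick sanity check: after rescaling, the pixel values should lie in [-1, 1].
# In[ ]:
print(train_images.min(), train_images.max())  # expected: -1.0 1.0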
# ### Create the model
#
# The following will create a simple binarized CNN.
#
# The quantization function
# $$
# q(x) = \begin{cases}
# -1 & x < 0 \\
# 1 & x \geq 0
# \end{cases}
# $$
# is used in the forward pass to binarize the activations and the latent full-precision weights. The gradient of this function is zero almost everywhere, which prevents the model from learning.
#
# To be able to train the model, the gradient is instead estimated using the Straight-Through Estimator (STE)
# (the binarization is essentially replaced by a clipped identity on the backward pass):
# $$
# \frac{\partial q(x)}{\partial x} = \begin{cases}
# 1 & \left|x\right| \leq 1 \\
# 0 & \left|x\right| > 1
# \end{cases}
# $$
#
# In Larq this can be done by using `input_quantizer="ste_sign"` and `kernel_quantizer="ste_sign"`.
# Additionally, the latent full-precision weights are clipped between -1 and 1 using `kernel_constraint="weight_clip"`.
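#
# To make the forward/backward behavior concrete, here is a minimal sketch of
# such an STE sign function, implemented with `tf.custom_gradient` purely for
# illustration (in practice you would use Larq's built-in `"ste_sign"` as shown
# below):
# In[ ]:
@tf.custom_gradient
def ste_sign_demo(x):
    # Forward pass: binarize to -1 / +1 (q(0) = 1, matching the definition above)
    y = tf.where(x >= 0, tf.ones_like(x), -tf.ones_like(x))
    def grad(dy):
        # Backward pass: clipped identity, i.e. zero gradient where |x| > 1
        return dy * tf.cast(tf.abs(x) <= 1, x.dtype)
    return y, grad

x = tf.Variable([-1.5, -0.3, 0.0, 0.7, 2.0])
with tf.GradientTape() as tape:
    y = ste_sign_demo(x)
print(y.numpy())                    # [-1. -1.  1.  1.  1.]
print(tape.gradient(y, x).numpy())  # [0. 1. 1. 1. 0.]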
# In[3]:
# All quantized layers except the first will use the same options
kwargs = dict(input_quantizer="ste_sign",
              kernel_quantizer="ste_sign",
              kernel_constraint="weight_clip")
model = tf.keras.models.Sequential()
# In the first layer we only quantize the weights and not the input
model.add(lq.layers.QuantConv2D(32, (3, 3),
                                kernel_quantizer="ste_sign",
                                kernel_constraint="weight_clip",
                                use_bias=False,
                                input_shape=(28, 28, 1)))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.BatchNormalization(scale=False))
model.add(lq.layers.QuantConv2D(64, (3, 3), use_bias=False, **kwargs))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.BatchNormalization(scale=False))
model.add(lq.layers.QuantConv2D(64, (3, 3), use_bias=False, **kwargs))
model.add(tf.keras.layers.BatchNormalization(scale=False))
model.add(tf.keras.layers.Flatten())
model.add(lq.layers.QuantDense(64, use_bias=False, **kwargs))
model.add(tf.keras.layers.BatchNormalization(scale=False))
model.add(lq.layers.QuantDense(10, use_bias=False, **kwargs))
model.add(tf.keras.layers.BatchNormalization(scale=False))
model.add(tf.keras.layers.Activation("softmax"))
# Almost all parameters in the network are binarized, i.e. either -1 or 1. This makes the network extremely fast when deployed on custom BNN hardware.
#
# Here is the complete architecture of our model:
# In[4]:
lq.models.summary(model)
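#
# If you want to verify that the kernels really are binary, Larq provides
# `larq.context.quantized_scope` (assuming a recent Larq version): inside the
# scope, `get_weights()` returns the quantized values instead of the latent
# full-precision ones.
# In[ ]:
import numpy as np

with lq.context.quantized_scope(True):
    binary_kernel = model.layers[0].get_weights()[0]
print(np.unique(binary_kernel))  # expected: [-1.  1.]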
# ### Compile and train the model
#
# Note: This may take a few minutes depending on your system.
# In[10]:
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(train_images, train_labels, batch_size=64, epochs=6)
# ### Evaluate the model
# In[11]:
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f"Test accuracy {test_acc * 100:.2f} %")
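#
# As a quick usage example, you can also run inference on a single test image
# (a minimal sketch; `predict` returns the softmax probabilities per class):
# In[ ]:
import numpy as np

probs = model.predict(test_images[:1])
print("predicted:", np.argmax(probs, axis=-1)[0], "actual:", test_labels[0])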
# As you can see, our simple binarized CNN has achieved a test accuracy of around 98%. Not bad for a few lines of code!
#
# For information on converting Larq models to an optimized format and using or benchmarking them on Android or ARM devices, have a look at [this guide](https://docs.larq.dev/compute-engine/end_to_end/).