#!/usr/bin/env python
# coding: utf-8

# # Introduction to BNNs with Larq
#
# This tutorial demonstrates how to train a simple binarized Convolutional Neural Network (CNN) to classify MNIST digits. This simple network will achieve approximately 98% accuracy on the MNIST test set. This tutorial uses Larq and the [Keras Sequential API](https://www.tensorflow.org/guide/keras), so creating and training our model will require only a few lines of code.

# In[ ]:

pip install larq

# In[1]:

import tensorflow as tf
import larq as lq

# ### Download and prepare the MNIST dataset

# In[2]:

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))

# Normalize pixel values to be between -1 and 1
train_images, test_images = train_images / 127.5 - 1, test_images / 127.5 - 1

# ### Create the model
#
# The following will create a simple binarized CNN.
#
# The quantization function
# $$
# q(x) = \begin{cases}
#     -1 & x < 0 \\
#     1 & x \geq 0
# \end{cases}
# $$
# is used in the forward pass to binarize the activations and the latent full-precision weights. The gradient of this function is zero almost everywhere, which prevents the model from learning.
#
# To be able to train the model, the gradient is instead estimated using the Straight-Through Estimator (STE)
# (the binarization is essentially replaced by a clipped identity on the backward pass):
# $$
# \frac{\partial q(x)}{\partial x} = \begin{cases}
#     1 & \left|x\right| \leq 1 \\
#     0 & \left|x\right| > 1
# \end{cases}
# $$
#
# In Larq this can be done by using `input_quantizer="ste_sign"` and `kernel_quantizer="ste_sign"`.
# Additionally, the latent full-precision weights are clipped between -1 and 1 using `kernel_constraint="weight_clip"`.
# A short sanity-check cell just before training illustrates this forward behaviour on a few example values.

# In[3]:

# All quantized layers except the first will use the same options
kwargs = dict(input_quantizer="ste_sign",
              kernel_quantizer="ste_sign",
              kernel_constraint="weight_clip")

model = tf.keras.models.Sequential()

# In the first layer we only quantize the weights and not the input
model.add(lq.layers.QuantConv2D(32, (3, 3),
                                kernel_quantizer="ste_sign",
                                kernel_constraint="weight_clip",
                                use_bias=False,
                                input_shape=(28, 28, 1)))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.BatchNormalization(scale=False))

model.add(lq.layers.QuantConv2D(64, (3, 3), use_bias=False, **kwargs))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.BatchNormalization(scale=False))

model.add(lq.layers.QuantConv2D(64, (3, 3), use_bias=False, **kwargs))
model.add(tf.keras.layers.BatchNormalization(scale=False))
model.add(tf.keras.layers.Flatten())

model.add(lq.layers.QuantDense(64, use_bias=False, **kwargs))
model.add(tf.keras.layers.BatchNormalization(scale=False))

model.add(lq.layers.QuantDense(10, use_bias=False, **kwargs))
model.add(tf.keras.layers.BatchNormalization(scale=False))
model.add(tf.keras.layers.Activation("softmax"))

# Almost all parameters in the network are binarized, so each is either -1 or 1. This makes the network extremely fast when deployed on custom BNN hardware.
#
# Here is the complete architecture of our model:

# In[4]:

lq.models.summary(model)

# ### Compile and train the model
#
# Note: This may take a few minutes depending on your system.
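# Before training, it can help to see what the binarization described above actually does in the forward pass. The cell below is a minimal sketch that assumes the `ste_sign` quantizer is also exposed as the `SteSign` class under `lq.quantizers` (as in recent Larq releases): values below zero map to -1 and values above zero map to +1.

# In[ ]:

# Sketch: apply the ste_sign quantizer to a few example values.
# Assumes lq.quantizers.SteSign is available (recent Larq versions).
x = tf.constant([-1.5, -0.3, 0.7, 2.0])
print(lq.quantizers.SteSign()(x).numpy())  # expected: [-1. -1.  1.  1.]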
# In[10]:

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(train_images, train_labels, batch_size=64, epochs=6)

test_loss, test_acc = model.evaluate(test_images, test_labels)

# ### Evaluate the model

# In[11]:

print(f"Test accuracy {test_acc * 100:.2f} %")

# As you can see, our simple binarized CNN has achieved a test accuracy of around 98%. Not bad for a few lines of code!
#
# For information on converting Larq models to an optimized format and using or benchmarking them on Android or ARM devices, have a look at [this guide](https://docs.larq.dev/compute-engine/end_to_end/).
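# As a minimal sketch of that conversion step, assuming the `larq-compute-engine` package is installed and provides its documented `convert_keras_model` converter, the trained model can be turned into a TensorFlow Lite flatbuffer (the output file name below is arbitrary):

# In[ ]:

# Sketch: convert the trained Keras model with Larq Compute Engine.
# Assumes `pip install larq-compute-engine` and its convert_keras_model API.
import larq_compute_engine as lce

flatbuffer_bytes = lce.convert_keras_model(model)
with open("mnist_bnn.tflite", "wb") as f:
    f.write(flatbuffer_bytes)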