Setup

In [ ]:
!pip install -U -q imgaug --user
In [ ]:
import tensorflow as tf
tf.random.set_seed(42)

import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import layers 

import tensorflow_datasets as tfds
tfds.disable_progress_bar()

from imgaug import augmenters as iaa
import imgaug as ia
ia.seed(42)

from tensorflow.keras import mixed_precision
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)
INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: Tesla V100-SXM2-16GB, compute capability 7.0

Load the CIFAR10 dataset

For this example, we will be using the CIFAR10 dataset.

In [ ]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
print(f"Total training examples: {len(x_train)}")
print(f"Total test examples: {len(x_test)}")
Total training examples: 50000
Total test examples: 10000

Define hyperparameters

In [ ]:
AUTO = tf.data.AUTOTUNE
BATCH_SIZE = 512 
EPOCHS = 100

Initialize RandAugment object

In [ ]:
rand_aug = iaa.RandAugment(n=3, m=7)

def augment(images):
    # Input to `augment()` is a TensorFlow tensor which
    # is not supported by `imgaug`. This is why we first
    # convert it to its `numpy` variant.
    images = tf.cast(images, tf.uint8)
    return rand_aug(images=images.numpy())

Create TensorFlow Dataset objects

In [ ]:
train_ds_rand = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
    .map(
        lambda x, y: (tf.image.resize(x, (72, 72)), y),
        num_parallel_calls=AUTO,
    )
    # The returned output of `tf.py_function` contains an unncessary axis of
    # 1-D and we need to remove it.
    .map(
        lambda x, y: (tf.py_function(augment, [x], [tf.float32])[0], y),
        num_parallel_calls=AUTO,
    )
    .prefetch(AUTO)
)

test_ds = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(BATCH_SIZE)
    .map(lambda x, y: (tf.image.resize(x, (72, 72)), y),
        num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

For comparison purposes, let's also define a simple augmentation pipeline consisting of random flips, random rotations, and random zoomings.

In [ ]:
simple_aug = tf.keras.Sequential(
    [
        layers.experimental.preprocessing.Resizing(72, 72),
        layers.experimental.preprocessing.RandomFlip("horizontal"),
        layers.experimental.preprocessing.RandomRotation(factor=0.02),
        layers.experimental.preprocessing.RandomZoom(
            height_factor=0.2, width_factor=0.2
        ),
    ],
    name="data_augmentation",
)

# Now, map the augmentation pipeline to our training dataset
train_ds_simple = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(BATCH_SIZE*100)
    .batch(BATCH_SIZE)
    .map(lambda x, y: (simple_aug(x), y),
                            num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

Visualize the dataset augmented with RandAugment

In [ ]:
sample_images, _ = next(iter(train_ds_rand))
plt.figure(figsize=(10, 10))
for i, image in enumerate(sample_images[:9]):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image.numpy().astype("int"))
    plt.axis("off")

Visualize the dataset augmented with simple_aug

In [ ]:
sample_images, _ = next(iter(train_ds_simple))
plt.figure(figsize=(10, 10))
for i, image in enumerate(sample_images[:9]):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image.numpy().astype("int"))
    plt.axis("off")

Define a model building utility function

In [ ]:
def get_training_model():
    resnet50_v2 = tf.keras.applications.ResNet50V2(
        weights=None, include_top=True, input_shape=(72, 72, 3),
        classes=10
    )
    model = tf.keras.Sequential([
        layers.Input((72, 72, 3)),
        layers.experimental.preprocessing.Rescaling(scale=1./127.5, offset=-1),
        resnet50_v2,
        layers.Activation("linear", dtype="float32")
    ])
    return model

get_training_model().summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
rescaling (Rescaling)        (None, 72, 72, 3)         0         
_________________________________________________________________
resnet50v2 (Functional)      (None, 10)                23585290  
_________________________________________________________________
activation (Activation)      (None, 10)                0         
=================================================================
Total params: 23,585,290
Trainable params: 23,539,850
Non-trainable params: 45,440
_________________________________________________________________
In [ ]:
# For reproducibility, we first serialize the initialize weights
initial_model = get_training_model()
initial_model.save_weights("initial_weights.h5")
In [ ]:
# We also set up an early stopping callback to prevent the models
# from overfitting
es = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=10, restore_best_weights=True
)

1. Train model with RandAugment

In [ ]:
rand_aug_model = get_training_model()
rand_aug_model.load_weights("initial_weights.h5")
rand_aug_model.compile(loss="sparse_categorical_crossentropy", 
                      optimizer="adam", 
                      metrics=["accuracy"])
rand_aug_model.fit(train_ds_rand,
        validation_data=test_ds,
        epochs=EPOCHS,
        callbacks=[es])
_, test_acc = rand_aug_model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc*100))
Epoch 1/100
98/98 [==============================] - 85s 705ms/step - loss: 3.1592 - accuracy: 0.1834 - val_loss: nan - val_accuracy: 0.1045
Epoch 2/100
98/98 [==============================] - 69s 653ms/step - loss: 1.7934 - accuracy: 0.3600 - val_loss: 7.4934 - val_accuracy: 0.1491
Epoch 3/100
98/98 [==============================] - 68s 646ms/step - loss: 1.8542 - accuracy: 0.3434 - val_loss: nan - val_accuracy: 0.1157
Epoch 4/100
98/98 [==============================] - 68s 640ms/step - loss: 1.7277 - accuracy: 0.3930 - val_loss: 4.5487 - val_accuracy: 0.2414
Epoch 5/100
98/98 [==============================] - 71s 626ms/step - loss: 1.6077 - accuracy: 0.4401 - val_loss: 1.5698 - val_accuracy: 0.4525
Epoch 6/100
98/98 [==============================] - 68s 643ms/step - loss: 1.5469 - accuracy: 0.4623 - val_loss: 2.3275 - val_accuracy: 0.2820
Epoch 7/100
98/98 [==============================] - 69s 645ms/step - loss: 1.5825 - accuracy: 0.4494 - val_loss: 1.4070 - val_accuracy: 0.5059
Epoch 8/100
98/98 [==============================] - 68s 636ms/step - loss: 1.3531 - accuracy: 0.5279 - val_loss: 1.5508 - val_accuracy: 0.4852
Epoch 9/100
98/98 [==============================] - 69s 644ms/step - loss: 1.3057 - accuracy: 0.5373 - val_loss: 1.2720 - val_accuracy: 0.5548
Epoch 10/100
98/98 [==============================] - 67s 634ms/step - loss: 1.2089 - accuracy: 0.5755 - val_loss: 1.2662 - val_accuracy: 0.5588
Epoch 11/100
98/98 [==============================] - 68s 645ms/step - loss: 1.1647 - accuracy: 0.5924 - val_loss: 1.0406 - val_accuracy: 0.6329
Epoch 12/100
98/98 [==============================] - 67s 632ms/step - loss: 1.0877 - accuracy: 0.6202 - val_loss: 0.9969 - val_accuracy: 0.6526
Epoch 13/100
98/98 [==============================] - 66s 622ms/step - loss: 1.0120 - accuracy: 0.6445 - val_loss: 1.0054 - val_accuracy: 0.6457
Epoch 14/100
98/98 [==============================] - 67s 629ms/step - loss: 0.9721 - accuracy: 0.6620 - val_loss: 0.9890 - val_accuracy: 0.6531
Epoch 15/100
98/98 [==============================] - 67s 627ms/step - loss: 1.0118 - accuracy: 0.6485 - val_loss: 0.8240 - val_accuracy: 0.7164
Epoch 16/100
98/98 [==============================] - 67s 634ms/step - loss: 0.8661 - accuracy: 0.6961 - val_loss: 0.8105 - val_accuracy: 0.7148
Epoch 17/100
98/98 [==============================] - 67s 631ms/step - loss: 0.8365 - accuracy: 0.7076 - val_loss: 0.8673 - val_accuracy: 0.7106
Epoch 18/100
98/98 [==============================] - 67s 638ms/step - loss: 0.7939 - accuracy: 0.7200 - val_loss: 0.9348 - val_accuracy: 0.7002
Epoch 19/100
98/98 [==============================] - 70s 659ms/step - loss: 0.7548 - accuracy: 0.7371 - val_loss: 0.9441 - val_accuracy: 0.7137
Epoch 20/100
98/98 [==============================] - 72s 684ms/step - loss: 0.7254 - accuracy: 0.7493 - val_loss: 0.7852 - val_accuracy: 0.7308
Epoch 21/100
98/98 [==============================] - 67s 637ms/step - loss: 0.6884 - accuracy: 0.7604 - val_loss: 0.7827 - val_accuracy: 0.7539
Epoch 22/100
98/98 [==============================] - 66s 627ms/step - loss: 0.6703 - accuracy: 0.7661 - val_loss: 0.6435 - val_accuracy: 0.7755
Epoch 23/100
98/98 [==============================] - 67s 637ms/step - loss: 0.6349 - accuracy: 0.7790 - val_loss: 0.6341 - val_accuracy: 0.7806
Epoch 24/100
98/98 [==============================] - 67s 633ms/step - loss: 0.6040 - accuracy: 0.7882 - val_loss: 1.1893 - val_accuracy: 0.7360
Epoch 25/100
98/98 [==============================] - 71s 678ms/step - loss: 0.5804 - accuracy: 0.7971 - val_loss: 0.6380 - val_accuracy: 0.7992
Epoch 26/100
98/98 [==============================] - 70s 662ms/step - loss: 0.5789 - accuracy: 0.7970 - val_loss: 0.5683 - val_accuracy: 0.8024
Epoch 27/100
98/98 [==============================] - 67s 630ms/step - loss: 0.5501 - accuracy: 0.8078 - val_loss: 0.6145 - val_accuracy: 0.7967
Epoch 28/100
98/98 [==============================] - 66s 622ms/step - loss: 0.5329 - accuracy: 0.8135 - val_loss: 0.5557 - val_accuracy: 0.8100
Epoch 29/100
98/98 [==============================] - 68s 641ms/step - loss: 0.5112 - accuracy: 0.8227 - val_loss: 0.5435 - val_accuracy: 0.8206
Epoch 30/100
98/98 [==============================] - 66s 627ms/step - loss: 0.4964 - accuracy: 0.8272 - val_loss: 0.7384 - val_accuracy: 0.7746
Epoch 31/100
98/98 [==============================] - 67s 629ms/step - loss: 0.4890 - accuracy: 0.8289 - val_loss: 0.6565 - val_accuracy: 0.7866
Epoch 32/100
98/98 [==============================] - 68s 641ms/step - loss: 0.4700 - accuracy: 0.8383 - val_loss: 0.7538 - val_accuracy: 0.7881
Epoch 33/100
98/98 [==============================] - 67s 633ms/step - loss: 0.4668 - accuracy: 0.8368 - val_loss: 0.5000 - val_accuracy: 0.8297
Epoch 34/100
98/98 [==============================] - 67s 637ms/step - loss: 0.4425 - accuracy: 0.8480 - val_loss: 0.5721 - val_accuracy: 0.8093
Epoch 35/100
98/98 [==============================] - 68s 641ms/step - loss: 0.4215 - accuracy: 0.8531 - val_loss: 0.6007 - val_accuracy: 0.8047
Epoch 36/100
98/98 [==============================] - 67s 635ms/step - loss: 0.4156 - accuracy: 0.8553 - val_loss: 0.5727 - val_accuracy: 0.8144
Epoch 37/100
98/98 [==============================] - 67s 635ms/step - loss: 0.4091 - accuracy: 0.8592 - val_loss: 0.5009 - val_accuracy: 0.8340
Epoch 38/100
98/98 [==============================] - 67s 635ms/step - loss: 0.4002 - accuracy: 0.8595 - val_loss: 0.5630 - val_accuracy: 0.8236
Epoch 39/100
98/98 [==============================] - 68s 641ms/step - loss: 0.3822 - accuracy: 0.8673 - val_loss: 0.4742 - val_accuracy: 0.8419
Epoch 40/100
98/98 [==============================] - 70s 664ms/step - loss: 0.3588 - accuracy: 0.8766 - val_loss: 0.4799 - val_accuracy: 0.8422
Epoch 41/100
98/98 [==============================] - 67s 633ms/step - loss: 0.3549 - accuracy: 0.8762 - val_loss: 0.4908 - val_accuracy: 0.8396
Epoch 42/100
98/98 [==============================] - 70s 673ms/step - loss: 0.3447 - accuracy: 0.8787 - val_loss: 0.5453 - val_accuracy: 0.8311
Epoch 43/100
98/98 [==============================] - 68s 644ms/step - loss: 0.3497 - accuracy: 0.8782 - val_loss: 0.5162 - val_accuracy: 0.8385
Epoch 44/100
98/98 [==============================] - 67s 638ms/step - loss: 0.3273 - accuracy: 0.8856 - val_loss: 0.5166 - val_accuracy: 0.8308
Epoch 45/100
98/98 [==============================] - 68s 645ms/step - loss: 0.3143 - accuracy: 0.8917 - val_loss: 0.4778 - val_accuracy: 0.8448
Epoch 46/100
98/98 [==============================] - 67s 633ms/step - loss: 0.3052 - accuracy: 0.8931 - val_loss: 0.5501 - val_accuracy: 0.8382
Epoch 47/100
98/98 [==============================] - 67s 634ms/step - loss: 0.3007 - accuracy: 0.8947 - val_loss: 0.4818 - val_accuracy: 0.8509
Epoch 48/100
98/98 [==============================] - 68s 643ms/step - loss: 0.2832 - accuracy: 0.8985 - val_loss: 0.5162 - val_accuracy: 0.8439
Epoch 49/100
98/98 [==============================] - 68s 639ms/step - loss: 0.2729 - accuracy: 0.9048 - val_loss: 0.5436 - val_accuracy: 0.8367
20/20 [==============================] - 1s 30ms/step - loss: 0.4742 - accuracy: 0.8419
Test accuracy: 84.19%
In [ ]:
rand_aug_model.save("rand_aug_model")
INFO:tensorflow:Assets written to: rand_aug_model/assets

2. Train model with simple_aug

In [ ]:
simple_aug_model = get_training_model()
simple_aug_model.load_weights("initial_weights.h5")
simple_aug_model.compile(loss="sparse_categorical_crossentropy", 
                      optimizer="adam", 
                      metrics=["accuracy"])
simple_aug_model.fit(train_ds_simple,
        validation_data=test_ds,
        epochs=EPOCHS,
        callbacks=[es])
_, test_acc = simple_aug_model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc*100))
Epoch 1/100
98/98 [==============================] - 28s 202ms/step - loss: 2.3730 - accuracy: 0.2644 - val_loss: 7.0031 - val_accuracy: 0.1286
Epoch 2/100
98/98 [==============================] - 19s 183ms/step - loss: 1.2847 - accuracy: 0.5453 - val_loss: 1.7744 - val_accuracy: 0.4162
Epoch 3/100
98/98 [==============================] - 19s 181ms/step - loss: 1.0809 - accuracy: 0.6203 - val_loss: nan - val_accuracy: 0.0737
Epoch 4/100
98/98 [==============================] - 19s 183ms/step - loss: 0.9696 - accuracy: 0.6610 - val_loss: 1.0999 - val_accuracy: 0.6323
Epoch 5/100
98/98 [==============================] - 19s 183ms/step - loss: 0.9832 - accuracy: 0.6644 - val_loss: nan - val_accuracy: 0.1006
Epoch 6/100
98/98 [==============================] - 19s 183ms/step - loss: 1.0794 - accuracy: 0.6355 - val_loss: 12.5741 - val_accuracy: 0.1205
Epoch 7/100
98/98 [==============================] - 19s 182ms/step - loss: 0.8914 - accuracy: 0.6890 - val_loss: 0.8805 - val_accuracy: 0.7018
Epoch 8/100
98/98 [==============================] - 19s 181ms/step - loss: 0.6830 - accuracy: 0.7598 - val_loss: 0.7834 - val_accuracy: 0.7346
Epoch 9/100
98/98 [==============================] - 19s 182ms/step - loss: 0.5906 - accuracy: 0.7949 - val_loss: 0.7656 - val_accuracy: 0.7445
Epoch 10/100
98/98 [==============================] - 19s 184ms/step - loss: 0.5317 - accuracy: 0.8146 - val_loss: 0.7136 - val_accuracy: 0.7549
Epoch 11/100
98/98 [==============================] - 19s 182ms/step - loss: 0.4830 - accuracy: 0.8303 - val_loss: 0.7174 - val_accuracy: 0.7580
Epoch 12/100
98/98 [==============================] - 19s 183ms/step - loss: 0.4508 - accuracy: 0.8427 - val_loss: 0.6619 - val_accuracy: 0.7824
Epoch 13/100
98/98 [==============================] - 19s 181ms/step - loss: 0.4086 - accuracy: 0.8557 - val_loss: 0.7537 - val_accuracy: 0.7533
Epoch 14/100
98/98 [==============================] - 19s 183ms/step - loss: 0.3770 - accuracy: 0.8678 - val_loss: 0.6286 - val_accuracy: 0.7898
Epoch 15/100
98/98 [==============================] - 19s 182ms/step - loss: 0.3477 - accuracy: 0.8779 - val_loss: 0.6000 - val_accuracy: 0.8014
Epoch 16/100
98/98 [==============================] - 19s 184ms/step - loss: 0.3211 - accuracy: 0.8884 - val_loss: 0.6156 - val_accuracy: 0.8045
Epoch 17/100
98/98 [==============================] - 19s 183ms/step - loss: 0.2923 - accuracy: 0.8966 - val_loss: 0.8128 - val_accuracy: 0.7648
Epoch 18/100
98/98 [==============================] - 19s 182ms/step - loss: 0.2739 - accuracy: 0.9036 - val_loss: 0.6538 - val_accuracy: 0.7948
Epoch 19/100
98/98 [==============================] - 19s 185ms/step - loss: 0.2517 - accuracy: 0.9121 - val_loss: 0.6547 - val_accuracy: 0.8092
Epoch 20/100
98/98 [==============================] - 20s 186ms/step - loss: 0.2416 - accuracy: 0.9174 - val_loss: 0.6659 - val_accuracy: 0.8075
Epoch 21/100
98/98 [==============================] - 19s 184ms/step - loss: 0.2173 - accuracy: 0.9243 - val_loss: 0.6265 - val_accuracy: 0.8131
Epoch 22/100
98/98 [==============================] - 19s 184ms/step - loss: 0.2059 - accuracy: 0.9285 - val_loss: 0.6124 - val_accuracy: 0.8186
Epoch 23/100
98/98 [==============================] - 19s 185ms/step - loss: 0.1905 - accuracy: 0.9334 - val_loss: 0.6885 - val_accuracy: 0.8131
Epoch 24/100
98/98 [==============================] - 19s 185ms/step - loss: 0.1830 - accuracy: 0.9359 - val_loss: 0.6368 - val_accuracy: 0.8256
Epoch 25/100
98/98 [==============================] - 19s 182ms/step - loss: 0.1661 - accuracy: 0.9421 - val_loss: 0.7623 - val_accuracy: 0.8005
20/20 [==============================] - 1s 31ms/step - loss: 0.6000 - accuracy: 0.8014
Test accuracy: 80.14%
In [ ]:
simple_aug_model.save("simple_aug_model")
INFO:tensorflow:Assets written to: simple_aug_model/assets

Load the CIFAR-10-C dataset and evaluate performance

In [ ]:
# Load and prepare the CIFAR-10-C dataset
# (If it's not already downloaded, it takes ~10 minutes of time to download)
cifar_10_c = tfds.load("cifar10_corrupted/saturate_5", split="test", 
                      as_supervised=True)
cifar_10_c = (
    cifar_10_c
    .batch(BATCH_SIZE)
    .map(lambda x, y: (tf.image.resize(x, (72, 72)), y),
        num_parallel_calls=AUTO))
Downloading and preparing dataset 2.72 GiB (download: 2.72 GiB, generated: Unknown size, total: 2.72 GiB) to /home/jupyter/tensorflow_datasets/cifar10_corrupted/saturate_5/1.0.0...
Dataset cifar10_corrupted downloaded and prepared to /home/jupyter/tensorflow_datasets/cifar10_corrupted/saturate_5/1.0.0. Subsequent calls will reuse this data.
In [ ]:
# Evaluate `rand_aug_model`
_, test_acc = rand_aug_model.evaluate(cifar_10_c, verbose=0)
print("Accuracy with RandAugment on CIFAR-10-C (saturate_5): {:.2f}%".format(test_acc*100))

# Evaluate `simple_aug_model`
_, test_acc = simple_aug_model.evaluate(cifar_10_c, verbose=0)
print("Accuracy with simple_aug on CIFAR-10-C (saturate_5): {:.2f}%".format(test_acc*100))
Accuracy with RandAugment on CIFAR-10-C (saturate_5): 76.64%
Accuracy with simple_aug on CIFAR-10-C (saturate_5): 64.80%