!pip install -U -q imgaug --user
import tensorflow as tf
tf.random.set_seed(42)
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import layers
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
from imgaug import augmenters as iaa
import imgaug as ia
ia.seed(42)
from tensorflow.keras import mixed_precision
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)
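As a quick sanity check, the policy exposes the dtypes it will use: layer computations run in float16 while variables are kept in float32.
# Under mixed precision, computations run in float16 while variables
# stay in float32 for numerical stability.
print(policy.compute_dtype)   # float16
print(policy.variable_dtype)  # float32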
For this example, we will be using the CIFAR-10 dataset.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
print(f"Total training examples: {len(x_train)}")
print(f"Total test examples: {len(x_test)}")
AUTO = tf.data.AUTOTUNE
BATCH_SIZE = 512
EPOCHS = 100
Initialize `RandAugment` object
Here, `n` is the number of augmentation transformations applied sequentially to each image, and `m` is the shared magnitude of those transformations.
rand_aug = iaa.RandAugment(n=3, m=7)
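As a quick check (the dummy batch below is purely illustrative), `RandAugment` consumes uint8 NumPy arrays of shape (N, H, W, C) and returns augmented arrays of the same shape:
# Illustrative dummy batch of random uint8 "images".
dummy_batch = np.random.randint(0, 256, size=(2, 72, 72, 3), dtype=np.uint8)
print(rand_aug(images=dummy_batch).shape)  # (2, 72, 72, 3)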
def augment(images):
# Input to `augment()` is a TensorFlow tensor which
# is not supported by `imgaug`. This is why we first
# convert it to its `numpy` variant.
images = tf.cast(images, tf.uint8)
return rand_aug(images=images.numpy())
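For illustration, `augment()` bridges the two worlds: it accepts a TensorFlow tensor batch and hands `imgaug` the uint8 NumPy batch it expects (the random batch below is a stand-in for real images):
# Illustrative only: a random float batch in place of real images.
example_batch = tf.random.uniform((2, 72, 72, 3), maxval=255)
print(augment(example_batch).shape)  # (2, 72, 72, 3), returned as a NumPy array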
Create TensorFlow `Dataset` objects
train_ds_rand = (
tf.data.Dataset.from_tensor_slices((x_train, y_train))
.shuffle(BATCH_SIZE * 100)
.batch(BATCH_SIZE)
.map(
lambda x, y: (tf.image.resize(x, (72, 72)), y),
num_parallel_calls=AUTO,
)
    # `tf.py_function` returns its outputs as a list (one tensor per
    # entry in `Tout`), so we index out the single augmented batch.
.map(
lambda x, y: (tf.py_function(augment, [x], [tf.float32])[0], y),
num_parallel_calls=AUTO,
)
.prefetch(AUTO)
)
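To see why the `[0]` indexing above is needed: `tf.py_function` returns a list with one tensor per dtype listed in `Tout` (a random batch is used here purely to demonstrate the shapes):
# In eager mode, the wrapped call returns a list of tensors; here
# `Tout` has a single entry, so the list has a single element.
example = tf.random.uniform((2, 72, 72, 3), maxval=255)
outputs = tf.py_function(augment, [example], [tf.float32])
print(len(outputs), outputs[0].shape)  # 1 (2, 72, 72, 3)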
test_ds = (
tf.data.Dataset.from_tensor_slices((x_test, y_test))
.batch(BATCH_SIZE)
.map(lambda x, y: (tf.image.resize(x, (72, 72)), y),
num_parallel_calls=AUTO)
.prefetch(AUTO)
)
For comparison purposes, let's also define a simple augmentation pipeline consisting of random flips, random rotations, and random zooms.
simple_aug = tf.keras.Sequential(
[
layers.experimental.preprocessing.Resizing(72, 72),
layers.experimental.preprocessing.RandomFlip("horizontal"),
layers.experimental.preprocessing.RandomRotation(factor=0.02),
layers.experimental.preprocessing.RandomZoom(
height_factor=0.2, width_factor=0.2
),
],
name="data_augmentation",
)
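A quick shape check: the random layers only transform their inputs when called with `training=True`, while the `Resizing` layer applies either way.
# The pipeline resizes 32x32 CIFAR-10 images up to 72x72.
print(simple_aug(x_train[:2], training=True).shape)  # (2, 72, 72, 3)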
# Now, map the augmentation pipeline to our training dataset
train_ds_simple = (
tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(BATCH_SIZE * 100)
.batch(BATCH_SIZE)
.map(lambda x, y: (simple_aug(x), y),
num_parallel_calls=AUTO)
.prefetch(AUTO)
)
Visualize the dataset augmented with RandAugment
sample_images, _ = next(iter(train_ds_rand))
plt.figure(figsize=(10, 10))
for i, image in enumerate(sample_images[:9]):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(image.numpy().astype("int"))
plt.axis("off")
Visualize the dataset augmented with `simple_aug`
sample_images, _ = next(iter(train_ds_simple))
plt.figure(figsize=(10, 10))
for i, image in enumerate(sample_images[:9]):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(image.numpy().astype("int"))
plt.axis("off")
def get_training_model():
resnet50_v2 = tf.keras.applications.ResNet50V2(
weights=None, include_top=True, input_shape=(72, 72, 3),
classes=10
)
    model = tf.keras.Sequential([
        layers.Input((72, 72, 3)),
        # Scale the pixel values from [0, 255] to [-1, 1].
        layers.experimental.preprocessing.Rescaling(scale=1./127.5, offset=-1),
        resnet50_v2,
        # With mixed precision enabled, cast the final outputs back
        # to float32 for numerical stability.
        layers.Activation("linear", dtype="float32")
    ])
return model
get_training_model().summary()
# For reproducibility, we first serialize the initial weights so that
# both models start from exactly the same initialization
initial_model = get_training_model()
initial_model.save_weights("initial_weights.h5")
# We also set up an early stopping callback to prevent the models
# from overfitting
es = tf.keras.callbacks.EarlyStopping(
monitor="val_loss", patience=10, restore_best_weights=True
)
Train model with RandAugment
rand_aug_model = get_training_model()
rand_aug_model.load_weights("initial_weights.h5")
rand_aug_model.compile(loss="sparse_categorical_crossentropy",
optimizer="adam",
metrics=["accuracy"])
rand_aug_model.fit(train_ds_rand,
validation_data=test_ds,
epochs=EPOCHS,
callbacks=[es])
_, test_acc = rand_aug_model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc*100))
rand_aug_model.save("rand_aug_model")
Train model with `simple_aug`
simple_aug_model = get_training_model()
simple_aug_model.load_weights("initial_weights.h5")
simple_aug_model.compile(loss="sparse_categorical_crossentropy",
optimizer="adam",
metrics=["accuracy"])
simple_aug_model.fit(train_ds_simple,
validation_data=test_ds,
epochs=EPOCHS,
callbacks=[es])
_, test_acc = simple_aug_model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc*100))
simple_aug_model.save("simple_aug_model")
# Load and prepare the CIFAR-10-C dataset
# (if it's not already cached, the download takes ~10 minutes)
cifar_10_c = tfds.load("cifar10_corrupted/saturate_5", split="test",
as_supervised=True)
cifar_10_c = (
cifar_10_c
.batch(BATCH_SIZE)
.map(lambda x, y: (tf.image.resize(x, (72, 72)), y),
num_parallel_calls=AUTO))
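Note that `saturate_5` is just one of the many corruption/severity configs shipped with `cifar10_corrupted`; the available config names can be listed without downloading anything:
# Each config follows the "<corruption>_<severity>" naming pattern.
print(list(tfds.builder("cifar10_corrupted").builder_configs)[:5])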
# Evaluate `rand_aug_model`
_, test_acc = rand_aug_model.evaluate(cifar_10_c, verbose=0)
print("Accuracy with RandAugment on CIFAR-10-C (saturate_5): {:.2f}%".format(test_acc*100))
# Evaluate `simple_aug_model`
_, test_acc = simple_aug_model.evaluate(cifar_10_c, verbose=0)
print("Accuracy with simple_aug on CIFAR-10-C (saturate_5): {:.2f}%".format(test_acc*100))