Before diving into the attack, let's first prepare a classification model. We reuse a script from the `examples`
folder.
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPool2D
from art.attacks.evasion import FeatureAdversariesTensorFlowV2
from art.estimators.classification import TensorFlowV2Classifier
from art.utils import load_mnist
# Step 1: Load the MNIST dataset (load_mnist also reports the pixel value range)
(x_train, y_train), (x_test, y_test), min_pixel_value, max_pixel_value = load_mnist()

# Step 1a: Cast to np.float32 to match the Keras default float dtype
x_train = x_train.astype(np.float32)
x_test = x_test.astype(np.float32)

# Step 2: Create the model: two conv/max-pool stages followed by two dense layers.
# The final Dense(10) has no activation, so the model outputs logits, which is
# why the loss below is built with from_logits=True.
# NOTE(review): Dense(100) has no activation either, so the two dense layers
# compose into a single linear map. The attacks later in this notebook use
# layer=-2 (presumably this Dense(100) layer), so changing it would change the
# tutorial's results - left as-is deliberately.
model = Sequential()
model.add(Conv2D(filters=4, kernel_size=5, activation="relu"))
model.add(MaxPool2D(pool_size=(2, 2), padding="valid", data_format=None))
model.add(Conv2D(filters=10, kernel_size=5, activation="relu"))
model.add(MaxPool2D(pool_size=(2, 2), padding="valid", data_format=None))
model.add(Flatten())
model.add(Dense(100))
model.add(Dense(10))

# Step 2a: Define the loss function and optimizer
loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

# Step 3: Create the ART classifier.
# The clipping bounds come from the pixel range reported by load_mnist instead
# of being hard-coded, so they always agree with the data actually loaded.
classifier = TensorFlowV2Classifier(
    model=model,
    loss_object=loss_object,
    optimizer=optimizer,
    nb_classes=10,
    input_shape=(28, 28, 1),
    clip_values=(min_pixel_value, max_pixel_value),
)

# Step 4: Train the ART classifier
classifier.fit(x_train, y_train, batch_size=64, nb_epochs=3)

# Step 5: Evaluate the ART classifier on benign test examples
# (labels are one-hot, so compare class indices via argmax)
predictions = classifier.predict(x_test)
accuracy = np.mean(np.argmax(predictions, axis=1) == np.argmax(y_test, axis=1))
print("Accuracy on benign test examples: {}%".format(accuracy * 100))
Accuracy on benign test examples: 96.77%
# Step 5b: Prepare a batch of source and guide images.
# Test image i (source) is paired with test image i + 100 (guide). Only pairs
# whose labels differ are kept, so source and guide always come from different
# classes, and the batch is capped at 32 pairs.
test_labels = np.argmax(y_test, axis=1)
valid = test_labels[:100] != test_labels[100:200]
source = x_test[:100][valid][:32]
guide = x_test[100:200][valid][:32]
# Step 6: Generate adversarial test examples.
# With optimizer=None the attack advances by step_size for max_iter
# iterations (presumably ART's PGD variant - confirm against the ART docs).
# delta = 45/255 is the perturbation budget around each source image.
attack = FeatureAdversariesTensorFlowV2(
    classifier,
    delta=45 / 255,
    layer=-2,
    optimizer=None,
    max_iter=100,
    step_size=1 / 255,
)
x_test_adv = attack.generate(source, guide)
# Step 7: Evaluate the ART classifier on adversarial test examples.
# Accuracy is measured against the labels of the *source* images, so a low
# value means the attack succeeded.
source_labels = np.argmax(y_test[:100][valid][:32], axis=1)
predictions = classifier.predict(x_test_adv)
accuracy = np.mean(np.argmax(predictions, axis=1) == source_labels)
# Per-image L-inf distance (max absolute pixel change over all non-batch axes),
# averaged over the batch. Pixels are clipped to [0, 1], so this is an absolute
# pixel difference, not a percentage.
dim = tuple(range(1, source.ndim))
pert = np.mean(np.amax(np.abs(source - x_test_adv), axis=dim))
print("Accuracy on adversarial test batch: {}%".format(accuracy * 100))
print("Average L-inf perturbation: {:.4f}".format(pert))
Feature Adversaries TensorFlow v2: 0%| | 0/100 [00:00<?, ?it/s]
Accuracy on adversarial test batch: 0.0% Average perturbation: 0.17647060751914978%
# Step 8: Inspect results
# adversarial example 0: source digit 7, guide digit 6
plt.imshow(np.squeeze(x_test_adv[0]))
<matplotlib.image.AxesImage at 0x7fe400651d30>
# adversarial example 2: source digit 1, guide digit 5
plt.imshow(np.squeeze(x_test_adv[2]))
<matplotlib.image.AxesImage at 0x7fe40053a4c0>
# adversarial example 4: source digit 4, guide digit 9
plt.imshow(np.squeeze(x_test_adv[4]))
<matplotlib.image.AxesImage at 0x7fe400522700>
# adversarial example 6: source digit 4, guide digit 2
plt.imshow(np.squeeze(x_test_adv[6]))
<matplotlib.image.AxesImage at 0x7fe400488250>
# adversarial example 8: source digit 5, guide digit 9
plt.imshow(np.squeeze(x_test_adv[8]))
<matplotlib.image.AxesImage at 0x7fe400466850>
# adversarial example 10: source digit 0, guide digit 8
plt.imshow(np.squeeze(x_test_adv[10]))
<matplotlib.image.AxesImage at 0x7fe4003c1e50>
This variant approximates the paper's original hard-constraint problem. Any available TensorFlow v2 optimizer can be used; we use Adam as a sensible default.
# Step 6: Generate adversarial test examples, this time driving the search
# with a Keras optimizer (Adam, learning rate 0.01) instead of fixed steps.
# The optimizer is passed as a class; its constructor arguments go in
# optimizer_kwargs. NOTE(review): lambda_ weights the constraint term and
# random_start starts from a random perturbation - per ART docs, confirm.
attack = FeatureAdversariesTensorFlowV2(
    classifier,
    delta=45 / 255,
    layer=-2,
    optimizer=tf.keras.optimizers.Adam,
    optimizer_kwargs={"learning_rate": 0.01},
    lambda_=1.0,
    random_start=True,
    max_iter=100,
)
x_test_adv = attack.generate(source, guide)
# Step 7: Evaluate the ART classifier on adversarial test examples.
# Accuracy is measured against the labels of the *source* images, so a low
# value means the attack succeeded.
source_labels = np.argmax(y_test[:100][valid][:32], axis=1)
predictions = classifier.predict(x_test_adv)
accuracy = np.mean(np.argmax(predictions, axis=1) == source_labels)
# Per-image L-inf distance (max absolute pixel change over all non-batch axes),
# averaged over the batch. Pixels are clipped to [0, 1], so this is an absolute
# pixel difference, not a percentage.
dim = tuple(range(1, source.ndim))
pert = np.mean(np.amax(np.abs(source - x_test_adv), axis=dim))
print("Accuracy on adversarial test batch: {}%".format(accuracy * 100))
print("Average L-inf perturbation: {:.4f}".format(pert))
Feature Adversaries TensorFlow v2: 0%| | 0/100 [00:00<?, ?it/s]
Accuracy on adversarial test batch: 0.0% Average perturbation: 0.17647060751914978%
# Step 8: Inspect results
# adversarial example 0: source digit 7, guide digit 6
plt.imshow(np.squeeze(x_test_adv[0]))
<matplotlib.image.AxesImage at 0x7fe40034dc40>
# adversarial example 2: source digit 1, guide digit 5
plt.imshow(np.squeeze(x_test_adv[2]))
<matplotlib.image.AxesImage at 0x7fe4003307f0>
# adversarial example 4: source digit 4, guide digit 9
plt.imshow(np.squeeze(x_test_adv[4]))
<matplotlib.image.AxesImage at 0x7fe400294370>
# adversarial example 6: source digit 4, guide digit 2
plt.imshow(np.squeeze(x_test_adv[6]))
<matplotlib.image.AxesImage at 0x7fe40026dc70>
# adversarial example 8: source digit 5, guide digit 9
plt.imshow(np.squeeze(x_test_adv[8]))
<matplotlib.image.AxesImage at 0x7fe4001d53d0>
# adversarial example 10: source digit 0, guide digit 8
plt.imshow(np.squeeze(x_test_adv[10]))
<matplotlib.image.AxesImage at 0x7fe4001e7a30>