Demonstrate detection of adversarial samples using ART
import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
from tensorflow.keras.models import load_model
# The rest of this script runs with eager execution switched off when a
# TensorFlow 2.x release is installed.
if tf.__version__.startswith("2"):
    tf.compat.v1.disable_eager_execution()
from art import config
from art.utils import load_dataset, get_file
from art.estimators.classification import KerasClassifier
from art.attacks.evasion import FastGradientMethod
from art.defences.detector.evasion import BinaryInputDetector
import numpy as np
import matplotlib.pyplot as plt
Load the CIFAR10 data set and class descriptions:
# Load CIFAR-10 through ART; min_ and max_ are passed on later as the
# classifier's clip_values.
(x_train, y_train), (x_test, y_test), min_, max_ = load_dataset('cifar10')

# Keep only a small leading slice of each split so the demo runs quickly.
num_samples_train = 100
num_samples_test = 100
x_train, y_train = x_train[:num_samples_train], y_train[:num_samples_train]
x_test, y_test = x_test[:num_samples_test], y_test[:num_samples_test]

# Human-readable class names, indexed by class id.
class_descr = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]
Load the pre-trained classifier (a ResNet architecture):
# Fetch the pre-trained CIFAR-10 ResNet weights via ART's get_file helper,
# storing them under ART's data directory, then load them with Keras.
classifier_weights_url = 'https://www.dropbox.com/s/ta75pl4krya5djj/cifar_resnet.h5?dl=1'
path = get_file('cifar_resnet.h5', extract=False, path=config.ART_DATA_PATH, url=classifier_weights_url)
classifier_model = load_model(path)

# Wrap the Keras model as an ART classifier. The model outputs probabilities
# (use_logits=False). preprocessing=(0.5, 1) is presumably ART's
# (subtrahend, divisor) pair -- confirm against the ART estimator docs.
classifier = KerasClassifier(
    model=classifier_model,
    clip_values=(min_, max_),
    use_logits=False,
    preprocessing=(0.5, 1),
)
WARNING:tensorflow:From /usr/local/anaconda3/envs/art/lib/python3.8/site-packages/keras/layers/normalization/batch_normalization.py:532: _colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version. Instructions for updating: Colocations handled automatically by placer.
2023-01-17 15:00:34.867373: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
# Print the layer-by-layer architecture of the loaded classifier model.
classifier_model.summary()
Model: "model_1" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_1 (InputLayer) [(None, 32, 32, 3)] 0 [] conv2d_1 (Conv2D) (None, 32, 32, 16) 448 ['input_1[0][0]'] batch_normalization_1 (BatchNo (None, 32, 32, 16) 64 ['conv2d_1[0][0]'] rmalization) activation_1 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_1[0][0]'] conv2d_2 (Conv2D) (None, 32, 32, 16) 2320 ['activation_1[0][0]'] batch_normalization_2 (BatchNo (None, 32, 32, 16) 64 ['conv2d_2[0][0]'] rmalization) activation_2 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_2[0][0]'] conv2d_3 (Conv2D) (None, 32, 32, 16) 2320 ['activation_2[0][0]'] add_1 (Add) (None, 32, 32, 16) 0 ['activation_1[0][0]', 'conv2d_3[0][0]'] batch_normalization_3 (BatchNo (None, 32, 32, 16) 64 ['add_1[0][0]'] rmalization) activation_3 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_3[0][0]'] conv2d_4 (Conv2D) (None, 32, 32, 16) 2320 ['activation_3[0][0]'] batch_normalization_4 (BatchNo (None, 32, 32, 16) 64 ['conv2d_4[0][0]'] rmalization) activation_4 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_4[0][0]'] conv2d_5 (Conv2D) (None, 32, 32, 16) 2320 ['activation_4[0][0]'] add_2 (Add) (None, 32, 32, 16) 0 ['add_1[0][0]', 'conv2d_5[0][0]'] batch_normalization_5 (BatchNo (None, 32, 32, 16) 64 ['add_2[0][0]'] rmalization) activation_5 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_5[0][0]'] conv2d_6 (Conv2D) (None, 32, 32, 16) 2320 ['activation_5[0][0]'] batch_normalization_6 (BatchNo (None, 32, 32, 16) 64 ['conv2d_6[0][0]'] rmalization) activation_6 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_6[0][0]'] conv2d_7 (Conv2D) (None, 32, 32, 16) 2320 ['activation_6[0][0]'] add_3 (Add) (None, 32, 32, 16) 0 ['add_2[0][0]', 'conv2d_7[0][0]'] batch_normalization_7 (BatchNo (None, 32, 32, 16) 64 
['add_3[0][0]'] rmalization) activation_7 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_7[0][0]'] conv2d_8 (Conv2D) (None, 32, 32, 16) 2320 ['activation_7[0][0]'] batch_normalization_8 (BatchNo (None, 32, 32, 16) 64 ['conv2d_8[0][0]'] rmalization) activation_8 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_8[0][0]'] conv2d_9 (Conv2D) (None, 32, 32, 16) 2320 ['activation_8[0][0]'] add_4 (Add) (None, 32, 32, 16) 0 ['add_3[0][0]', 'conv2d_9[0][0]'] batch_normalization_9 (BatchNo (None, 32, 32, 16) 64 ['add_4[0][0]'] rmalization) activation_9 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_9[0][0]'] conv2d_10 (Conv2D) (None, 32, 32, 16) 2320 ['activation_9[0][0]'] batch_normalization_10 (BatchN (None, 32, 32, 16) 64 ['conv2d_10[0][0]'] ormalization) activation_10 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_10[0][0]'] conv2d_11 (Conv2D) (None, 32, 32, 16) 2320 ['activation_10[0][0]'] add_5 (Add) (None, 32, 32, 16) 0 ['add_4[0][0]', 'conv2d_11[0][0]'] batch_normalization_11 (BatchN (None, 32, 32, 16) 64 ['add_5[0][0]'] ormalization) activation_11 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_11[0][0]'] conv2d_12 (Conv2D) (None, 16, 16, 32) 4640 ['activation_11[0][0]'] batch_normalization_12 (BatchN (None, 16, 16, 32) 128 ['conv2d_12[0][0]'] ormalization) activation_12 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_12[0][0]'] conv2d_14 (Conv2D) (None, 16, 16, 32) 544 ['add_5[0][0]'] conv2d_13 (Conv2D) (None, 16, 16, 32) 9248 ['activation_12[0][0]'] add_6 (Add) (None, 16, 16, 32) 0 ['conv2d_14[0][0]', 'conv2d_13[0][0]'] batch_normalization_13 (BatchN (None, 16, 16, 32) 128 ['add_6[0][0]'] ormalization) activation_13 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_13[0][0]'] conv2d_15 (Conv2D) (None, 16, 16, 32) 9248 ['activation_13[0][0]'] batch_normalization_14 (BatchN (None, 16, 16, 32) 128 ['conv2d_15[0][0]'] ormalization) activation_14 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_14[0][0]'] 
conv2d_16 (Conv2D) (None, 16, 16, 32) 9248 ['activation_14[0][0]'] add_7 (Add) (None, 16, 16, 32) 0 ['add_6[0][0]', 'conv2d_16[0][0]'] batch_normalization_15 (BatchN (None, 16, 16, 32) 128 ['add_7[0][0]'] ormalization) activation_15 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_15[0][0]'] conv2d_17 (Conv2D) (None, 16, 16, 32) 9248 ['activation_15[0][0]'] batch_normalization_16 (BatchN (None, 16, 16, 32) 128 ['conv2d_17[0][0]'] ormalization) activation_16 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_16[0][0]'] conv2d_18 (Conv2D) (None, 16, 16, 32) 9248 ['activation_16[0][0]'] add_8 (Add) (None, 16, 16, 32) 0 ['add_7[0][0]', 'conv2d_18[0][0]'] batch_normalization_17 (BatchN (None, 16, 16, 32) 128 ['add_8[0][0]'] ormalization) activation_17 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_17[0][0]'] conv2d_19 (Conv2D) (None, 16, 16, 32) 9248 ['activation_17[0][0]'] batch_normalization_18 (BatchN (None, 16, 16, 32) 128 ['conv2d_19[0][0]'] ormalization) activation_18 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_18[0][0]'] conv2d_20 (Conv2D) (None, 16, 16, 32) 9248 ['activation_18[0][0]'] add_9 (Add) (None, 16, 16, 32) 0 ['add_8[0][0]', 'conv2d_20[0][0]'] batch_normalization_19 (BatchN (None, 16, 16, 32) 128 ['add_9[0][0]'] ormalization) activation_19 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_19[0][0]'] conv2d_21 (Conv2D) (None, 16, 16, 32) 9248 ['activation_19[0][0]'] batch_normalization_20 (BatchN (None, 16, 16, 32) 128 ['conv2d_21[0][0]'] ormalization) activation_20 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_20[0][0]'] conv2d_22 (Conv2D) (None, 16, 16, 32) 9248 ['activation_20[0][0]'] add_10 (Add) (None, 16, 16, 32) 0 ['add_9[0][0]', 'conv2d_22[0][0]'] batch_normalization_21 (BatchN (None, 16, 16, 32) 128 ['add_10[0][0]'] ormalization) activation_21 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_21[0][0]'] conv2d_23 (Conv2D) (None, 8, 8, 64) 18496 ['activation_21[0][0]'] batch_normalization_22 
(BatchN (None, 8, 8, 64) 256 ['conv2d_23[0][0]'] ormalization) activation_22 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_22[0][0]'] conv2d_25 (Conv2D) (None, 8, 8, 64) 2112 ['add_10[0][0]'] conv2d_24 (Conv2D) (None, 8, 8, 64) 36928 ['activation_22[0][0]'] add_11 (Add) (None, 8, 8, 64) 0 ['conv2d_25[0][0]', 'conv2d_24[0][0]'] batch_normalization_23 (BatchN (None, 8, 8, 64) 256 ['add_11[0][0]'] ormalization) activation_23 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_23[0][0]'] conv2d_26 (Conv2D) (None, 8, 8, 64) 36928 ['activation_23[0][0]'] batch_normalization_24 (BatchN (None, 8, 8, 64) 256 ['conv2d_26[0][0]'] ormalization) activation_24 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_24[0][0]'] conv2d_27 (Conv2D) (None, 8, 8, 64) 36928 ['activation_24[0][0]'] add_12 (Add) (None, 8, 8, 64) 0 ['add_11[0][0]', 'conv2d_27[0][0]'] batch_normalization_25 (BatchN (None, 8, 8, 64) 256 ['add_12[0][0]'] ormalization) activation_25 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_25[0][0]'] conv2d_28 (Conv2D) (None, 8, 8, 64) 36928 ['activation_25[0][0]'] batch_normalization_26 (BatchN (None, 8, 8, 64) 256 ['conv2d_28[0][0]'] ormalization) activation_26 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_26[0][0]'] conv2d_29 (Conv2D) (None, 8, 8, 64) 36928 ['activation_26[0][0]'] add_13 (Add) (None, 8, 8, 64) 0 ['add_12[0][0]', 'conv2d_29[0][0]'] batch_normalization_27 (BatchN (None, 8, 8, 64) 256 ['add_13[0][0]'] ormalization) activation_27 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_27[0][0]'] conv2d_30 (Conv2D) (None, 8, 8, 64) 36928 ['activation_27[0][0]'] batch_normalization_28 (BatchN (None, 8, 8, 64) 256 ['conv2d_30[0][0]'] ormalization) activation_28 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_28[0][0]'] conv2d_31 (Conv2D) (None, 8, 8, 64) 36928 ['activation_28[0][0]'] add_14 (Add) (None, 8, 8, 64) 0 ['add_13[0][0]', 'conv2d_31[0][0]'] batch_normalization_29 (BatchN (None, 8, 8, 64) 256 ['add_14[0][0]'] 
ormalization) activation_29 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_29[0][0]'] conv2d_32 (Conv2D) (None, 8, 8, 64) 36928 ['activation_29[0][0]'] batch_normalization_30 (BatchN (None, 8, 8, 64) 256 ['conv2d_32[0][0]'] ormalization) activation_30 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_30[0][0]'] conv2d_33 (Conv2D) (None, 8, 8, 64) 36928 ['activation_30[0][0]'] add_15 (Add) (None, 8, 8, 64) 0 ['add_14[0][0]', 'conv2d_33[0][0]'] batch_normalization_31 (BatchN (None, 8, 8, 64) 256 ['add_15[0][0]'] ormalization) activation_31 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_31[0][0]'] dropout_1 (Dropout) (None, 8, 8, 64) 0 ['activation_31[0][0]'] average_pooling2d_1 (AveragePo (None, 1, 1, 64) 0 ['dropout_1[0][0]'] oling2D) flatten_1 (Flatten) (None, 64) 0 ['average_pooling2d_1[0][0]'] classifier (Dense) (None, 10) 650 ['flatten_1[0][0]'] ================================================================================================== Total params: 470,218 Trainable params: 467,946 Non-trainable params: 2,272 __________________________________________________________________________________________________
Evaluate the classifier on the first 100 test images:
# Predict on (up to) the first 100 test images; the predicted class is the
# arg-max of each output row.
x_test_pred = np.argmax(classifier.predict(x_test[:100]), axis=1)

# Derive the sample count from the predictions instead of a literal 100, so
# the "incorrect" count stays right if fewer than 100 test images are loaded.
nb_evaluated = len(x_test_pred)
nb_correct_pred = np.sum(x_test_pred == np.argmax(y_test[:nb_evaluated], axis=1))

print("Original test data (first {} images):".format(nb_evaluated))
print("Correctly classified: {}".format(nb_correct_pred))
print("Incorrectly classified: {}".format(nb_evaluated - nb_correct_pred))
Original test data (first 100 images): Correctly classified: 98 Incorrectly classified: 2
For illustration purposes, look at the first 9 images. (In brackets: true labels.)
# Show the first 9 test images in a 3x3 grid, each captioned with
# "<predicted label> (<true label>)".
plt.figure(figsize=(10, 10))
for i in range(9):
    pred_label = class_descr[x_test_pred[i]]
    true_label = class_descr[np.argmax(y_test[i])]
    plt.subplot(3, 3, i + 1)
    fig = plt.imshow(x_test[i])
    panel = fig.axes
    # Hide both axes; only the image and its caption are of interest.
    panel.get_xaxis().set_visible(False)
    panel.get_yaxis().set_visible(False)
    # Caption is centred just below the image, in axes coordinates.
    panel.text(
        0.5,
        -0.1,
        "{} ({})".format(pred_label, true_label),
        fontsize=12,
        transform=panel.transAxes,
        horizontalalignment='center',
    )
Generate some adversarial samples:
# Set up a Fast Gradient Method attack against the classifier with eps=0.05.
attacker = FastGradientMethod(estimator=classifier, eps=0.05)

# Craft adversarial versions of the first 100 test images
# (this takes about two minutes).
x_test_adv = attacker.generate(x=x_test[:100])
Evaluate the classifier on 100 adversarial samples:
# Predict on the adversarial images; the predicted class is the arg-max of
# each output row.
x_test_adv_pred = np.argmax(classifier.predict(x_test_adv), axis=1)

# Derive the sample count from the predictions instead of a literal 100, so
# the label slice and the "incorrect" count always match the data evaluated.
nb_adv_evaluated = len(x_test_adv_pred)
nb_correct_adv_pred = np.sum(x_test_adv_pred == np.argmax(y_test[:nb_adv_evaluated], axis=1))

print("Adversarial test data (first {} images):".format(nb_adv_evaluated))
print("Correctly classified: {}".format(nb_correct_adv_pred))
print("Incorrectly classified: {}".format(nb_adv_evaluated - nb_correct_adv_pred))
Adversarial test data (first 100 images): Correctly classified: 20 Incorrectly classified: 80
Now plot the adversarial images and their predicted labels (in brackets: true labels).
# Show the first 9 adversarial images in a 3x3 grid, each captioned with
# "<predicted label> (<true label>)".
plt.figure(figsize=(10, 10))
for i in range(9):
    pred_label = class_descr[x_test_adv_pred[i]]
    true_label = class_descr[np.argmax(y_test[i])]
    plt.subplot(3, 3, i + 1)
    fig = plt.imshow(x_test_adv[i])
    panel = fig.axes
    # Hide both axes; only the image and its caption are of interest.
    panel.get_xaxis().set_visible(False)
    panel.get_yaxis().set_visible(False)
    # Caption is centred just below the image, in axes coordinates.
    panel.text(
        0.5,
        -0.1,
        "{} ({})".format(pred_label, true_label),
        fontsize=12,
        transform=panel.transAxes,
        horizontalalignment='center',
    )
Load the detector model (which also uses a ResNet architecture):
# Fetch the pre-trained binary detector weights via ART's get_file helper,
# storing them under ART's data directory, then load them with Keras.
detector_weights_url = 'https://www.dropbox.com/s/cbyfk65497wwbtn/BID_eps%3D0.05.h5?dl=1'
path = get_file('BID_eps=0.05.h5', extract=False, path=config.ART_DATA_PATH, url=detector_weights_url)
detector_model = load_model(path)

# NOTE(review): unlike the main classifier, no `preprocessing` is passed here
# and clip_values are (-0.5, 0.5), yet the same unshifted images are fed to
# the detector below -- confirm the input range this model expects.
detector_classifier = KerasClassifier(
    model=detector_model,
    clip_values=(-0.5, 0.5),
    use_logits=False,
)
detector = BinaryInputDetector(detector_classifier)

# Print the layer-by-layer architecture of the detector model.
detector_model.summary()
Model: "model_1" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_1 (InputLayer) [(None, 32, 32, 3)] 0 [] conv2d_1 (Conv2D) (None, 32, 32, 16) 448 ['input_1[0][0]'] batch_normalization_1 (BatchNo (None, 32, 32, 16) 64 ['conv2d_1[0][0]'] rmalization) activation_1 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_1[0][0]'] conv2d_2 (Conv2D) (None, 32, 32, 16) 2320 ['activation_1[0][0]'] batch_normalization_2 (BatchNo (None, 32, 32, 16) 64 ['conv2d_2[0][0]'] rmalization) activation_2 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_2[0][0]'] conv2d_3 (Conv2D) (None, 32, 32, 16) 2320 ['activation_2[0][0]'] add_1 (Add) (None, 32, 32, 16) 0 ['activation_1[0][0]', 'conv2d_3[0][0]'] batch_normalization_3 (BatchNo (None, 32, 32, 16) 64 ['add_1[0][0]'] rmalization) activation_3 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_3[0][0]'] conv2d_4 (Conv2D) (None, 32, 32, 16) 2320 ['activation_3[0][0]'] batch_normalization_4 (BatchNo (None, 32, 32, 16) 64 ['conv2d_4[0][0]'] rmalization) activation_4 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_4[0][0]'] conv2d_5 (Conv2D) (None, 32, 32, 16) 2320 ['activation_4[0][0]'] add_2 (Add) (None, 32, 32, 16) 0 ['add_1[0][0]', 'conv2d_5[0][0]'] batch_normalization_5 (BatchNo (None, 32, 32, 16) 64 ['add_2[0][0]'] rmalization) activation_5 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_5[0][0]'] conv2d_6 (Conv2D) (None, 32, 32, 16) 2320 ['activation_5[0][0]'] batch_normalization_6 (BatchNo (None, 32, 32, 16) 64 ['conv2d_6[0][0]'] rmalization) activation_6 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_6[0][0]'] conv2d_7 (Conv2D) (None, 32, 32, 16) 2320 ['activation_6[0][0]'] add_3 (Add) (None, 32, 32, 16) 0 ['add_2[0][0]', 'conv2d_7[0][0]'] batch_normalization_7 (BatchNo (None, 32, 32, 16) 64 
['add_3[0][0]'] rmalization) activation_7 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_7[0][0]'] conv2d_8 (Conv2D) (None, 32, 32, 16) 2320 ['activation_7[0][0]'] batch_normalization_8 (BatchNo (None, 32, 32, 16) 64 ['conv2d_8[0][0]'] rmalization) activation_8 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_8[0][0]'] conv2d_9 (Conv2D) (None, 32, 32, 16) 2320 ['activation_8[0][0]'] add_4 (Add) (None, 32, 32, 16) 0 ['add_3[0][0]', 'conv2d_9[0][0]'] batch_normalization_9 (BatchNo (None, 32, 32, 16) 64 ['add_4[0][0]'] rmalization) activation_9 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_9[0][0]'] conv2d_10 (Conv2D) (None, 32, 32, 16) 2320 ['activation_9[0][0]'] batch_normalization_10 (BatchN (None, 32, 32, 16) 64 ['conv2d_10[0][0]'] ormalization) activation_10 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_10[0][0]'] conv2d_11 (Conv2D) (None, 32, 32, 16) 2320 ['activation_10[0][0]'] add_5 (Add) (None, 32, 32, 16) 0 ['add_4[0][0]', 'conv2d_11[0][0]'] batch_normalization_11 (BatchN (None, 32, 32, 16) 64 ['add_5[0][0]'] ormalization) activation_11 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_11[0][0]'] conv2d_12 (Conv2D) (None, 16, 16, 32) 4640 ['activation_11[0][0]'] batch_normalization_12 (BatchN (None, 16, 16, 32) 128 ['conv2d_12[0][0]'] ormalization) activation_12 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_12[0][0]'] conv2d_14 (Conv2D) (None, 16, 16, 32) 544 ['add_5[0][0]'] conv2d_13 (Conv2D) (None, 16, 16, 32) 9248 ['activation_12[0][0]'] add_6 (Add) (None, 16, 16, 32) 0 ['conv2d_14[0][0]', 'conv2d_13[0][0]'] batch_normalization_13 (BatchN (None, 16, 16, 32) 128 ['add_6[0][0]'] ormalization) activation_13 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_13[0][0]'] conv2d_15 (Conv2D) (None, 16, 16, 32) 9248 ['activation_13[0][0]'] batch_normalization_14 (BatchN (None, 16, 16, 32) 128 ['conv2d_15[0][0]'] ormalization) activation_14 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_14[0][0]'] 
conv2d_16 (Conv2D) (None, 16, 16, 32) 9248 ['activation_14[0][0]'] add_7 (Add) (None, 16, 16, 32) 0 ['add_6[0][0]', 'conv2d_16[0][0]'] batch_normalization_15 (BatchN (None, 16, 16, 32) 128 ['add_7[0][0]'] ormalization) activation_15 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_15[0][0]'] conv2d_17 (Conv2D) (None, 16, 16, 32) 9248 ['activation_15[0][0]'] batch_normalization_16 (BatchN (None, 16, 16, 32) 128 ['conv2d_17[0][0]'] ormalization) activation_16 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_16[0][0]'] conv2d_18 (Conv2D) (None, 16, 16, 32) 9248 ['activation_16[0][0]'] add_8 (Add) (None, 16, 16, 32) 0 ['add_7[0][0]', 'conv2d_18[0][0]'] batch_normalization_17 (BatchN (None, 16, 16, 32) 128 ['add_8[0][0]'] ormalization) activation_17 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_17[0][0]'] conv2d_19 (Conv2D) (None, 16, 16, 32) 9248 ['activation_17[0][0]'] batch_normalization_18 (BatchN (None, 16, 16, 32) 128 ['conv2d_19[0][0]'] ormalization) activation_18 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_18[0][0]'] conv2d_20 (Conv2D) (None, 16, 16, 32) 9248 ['activation_18[0][0]'] add_9 (Add) (None, 16, 16, 32) 0 ['add_8[0][0]', 'conv2d_20[0][0]'] batch_normalization_19 (BatchN (None, 16, 16, 32) 128 ['add_9[0][0]'] ormalization) activation_19 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_19[0][0]'] conv2d_21 (Conv2D) (None, 16, 16, 32) 9248 ['activation_19[0][0]'] batch_normalization_20 (BatchN (None, 16, 16, 32) 128 ['conv2d_21[0][0]'] ormalization) activation_20 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_20[0][0]'] conv2d_22 (Conv2D) (None, 16, 16, 32) 9248 ['activation_20[0][0]'] add_10 (Add) (None, 16, 16, 32) 0 ['add_9[0][0]', 'conv2d_22[0][0]'] batch_normalization_21 (BatchN (None, 16, 16, 32) 128 ['add_10[0][0]'] ormalization) activation_21 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_21[0][0]'] conv2d_23 (Conv2D) (None, 8, 8, 64) 18496 ['activation_21[0][0]'] batch_normalization_22 
(BatchN (None, 8, 8, 64) 256 ['conv2d_23[0][0]'] ormalization) activation_22 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_22[0][0]'] conv2d_25 (Conv2D) (None, 8, 8, 64) 2112 ['add_10[0][0]'] conv2d_24 (Conv2D) (None, 8, 8, 64) 36928 ['activation_22[0][0]'] add_11 (Add) (None, 8, 8, 64) 0 ['conv2d_25[0][0]', 'conv2d_24[0][0]'] batch_normalization_23 (BatchN (None, 8, 8, 64) 256 ['add_11[0][0]'] ormalization) activation_23 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_23[0][0]'] conv2d_26 (Conv2D) (None, 8, 8, 64) 36928 ['activation_23[0][0]'] batch_normalization_24 (BatchN (None, 8, 8, 64) 256 ['conv2d_26[0][0]'] ormalization) activation_24 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_24[0][0]'] conv2d_27 (Conv2D) (None, 8, 8, 64) 36928 ['activation_24[0][0]'] add_12 (Add) (None, 8, 8, 64) 0 ['add_11[0][0]', 'conv2d_27[0][0]'] batch_normalization_25 (BatchN (None, 8, 8, 64) 256 ['add_12[0][0]'] ormalization) activation_25 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_25[0][0]'] conv2d_28 (Conv2D) (None, 8, 8, 64) 36928 ['activation_25[0][0]'] batch_normalization_26 (BatchN (None, 8, 8, 64) 256 ['conv2d_28[0][0]'] ormalization) activation_26 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_26[0][0]'] conv2d_29 (Conv2D) (None, 8, 8, 64) 36928 ['activation_26[0][0]'] add_13 (Add) (None, 8, 8, 64) 0 ['add_12[0][0]', 'conv2d_29[0][0]'] batch_normalization_27 (BatchN (None, 8, 8, 64) 256 ['add_13[0][0]'] ormalization) activation_27 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_27[0][0]'] conv2d_30 (Conv2D) (None, 8, 8, 64) 36928 ['activation_27[0][0]'] batch_normalization_28 (BatchN (None, 8, 8, 64) 256 ['conv2d_30[0][0]'] ormalization) activation_28 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_28[0][0]'] conv2d_31 (Conv2D) (None, 8, 8, 64) 36928 ['activation_28[0][0]'] add_14 (Add) (None, 8, 8, 64) 0 ['add_13[0][0]', 'conv2d_31[0][0]'] batch_normalization_29 (BatchN (None, 8, 8, 64) 256 ['add_14[0][0]'] 
ormalization) activation_29 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_29[0][0]'] conv2d_32 (Conv2D) (None, 8, 8, 64) 36928 ['activation_29[0][0]'] batch_normalization_30 (BatchN (None, 8, 8, 64) 256 ['conv2d_32[0][0]'] ormalization) activation_30 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_30[0][0]'] conv2d_33 (Conv2D) (None, 8, 8, 64) 36928 ['activation_30[0][0]'] add_15 (Add) (None, 8, 8, 64) 0 ['add_14[0][0]', 'conv2d_33[0][0]'] batch_normalization_31 (BatchN (None, 8, 8, 64) 256 ['add_15[0][0]'] ormalization) activation_31 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_31[0][0]'] dropout_1 (Dropout) (None, 8, 8, 64) 0 ['activation_31[0][0]'] average_pooling2d_1 (AveragePo (None, 1, 1, 64) 0 ['dropout_1[0][0]'] oling2D) flatten_1 (Flatten) (None, 64) 0 ['average_pooling2d_1[0][0]'] classifier (Dense) (None, 2) 130 ['flatten_1[0][0]'] ================================================================================================== Total params: 469,698 Trainable params: 467,426 Non-trainable params: 2,272 __________________________________________________________________________________________________
To train the detector:
# Craft an adversarial counterpart for every training image.
x_train_adv = attacker.generate(x_train)
nb_train = x_train.shape[0]

# Detector inputs: all clean images first, then all adversarial ones.
x_train_detector = np.concatenate((x_train, x_train_adv), axis=0)

# Matching two-column targets: [1, 0] for each clean row, [0, 1] for each
# adversarial row, stacked in the same order as the inputs.
clean_targets = np.tile([1, 0], (nb_train, 1))
adversarial_targets = np.tile([0, 1], (nb_train, 1))
y_train_detector = np.vstack((clean_targets, adversarial_targets))
Perform the training:
# Train the detector on the combined clean + adversarial set
# (20 epochs, mini-batches of 20 samples).
detector.fit(x_train_detector, y_train_detector, nb_epochs=20, batch_size=20)
Train on 200 samples Epoch 1/20 200/200 [==============================] - 4s 22ms/sample - loss: 0.0062 - accuracy: 1.0000 Epoch 2/20 200/200 [==============================] - 1s 7ms/sample - loss: 0.0057 - accuracy: 1.0000 Epoch 3/20 200/200 [==============================] - 2s 8ms/sample - loss: 0.0056 - accuracy: 1.0000 Epoch 4/20 200/200 [==============================] - 2s 9ms/sample - loss: 0.0076 - accuracy: 1.0000 Epoch 5/20 200/200 [==============================] - 2s 11ms/sample - loss: 0.0062 - accuracy: 1.0000 Epoch 6/20 200/200 [==============================] - 2s 9ms/sample - loss: 0.0052 - accuracy: 1.0000 Epoch 7/20 200/200 [==============================] - 2s 8ms/sample - loss: 0.0051 - accuracy: 1.0000 Epoch 8/20 200/200 [==============================] - 2s 8ms/sample - loss: 0.0145 - accuracy: 0.9950 Epoch 9/20 200/200 [==============================] - 2s 9ms/sample - loss: 0.0334 - accuracy: 0.9950 Epoch 10/20 200/200 [==============================] - 2s 9ms/sample - loss: 0.0053 - accuracy: 1.0000 Epoch 11/20 200/200 [==============================] - 2s 9ms/sample - loss: 0.0053 - accuracy: 1.0000 Epoch 12/20 200/200 [==============================] - 2s 9ms/sample - loss: 0.0049 - accuracy: 1.0000 Epoch 13/20 200/200 [==============================] - 2s 9ms/sample - loss: 0.0048 - accuracy: 1.0000 Epoch 14/20 200/200 [==============================] - 2s 9ms/sample - loss: 0.0049 - accuracy: 1.0000 Epoch 15/20 200/200 [==============================] - 2s 9ms/sample - loss: 0.0051 - accuracy: 1.0000 Epoch 16/20 200/200 [==============================] - 2s 10ms/sample - loss: 0.0047 - accuracy: 1.0000 Epoch 17/20 200/200 [==============================] - 2s 10ms/sample - loss: 0.0045 - accuracy: 1.0000 Epoch 18/20 200/200 [==============================] - 2s 9ms/sample - loss: 0.0043 - accuracy: 1.0000 Epoch 19/20 200/200 [==============================] - 2s 10ms/sample - loss: 0.0058 - accuracy: 1.0000 Epoch 20/20 200/200 
[==============================] - 2s 10ms/sample - loss: 0.0043 - accuracy: 1.0000
Apply the detector to the adversarial test data:
# detect() returns a pair; the second element holds the per-sample flags,
# which are summed to count the flagged images.
_, is_adversarial = detector.detect(x_test_adv)
flag_adv = np.sum(is_adversarial)

# Derive the sample count from the detector output instead of a literal 100,
# so "Not flagged" stays right if the number of adversarial samples changes.
nb_checked_adv = len(is_adversarial)
print("Adversarial test data (first {} images):".format(nb_checked_adv))
print("Flagged: {}".format(flag_adv))
print("Not flagged: {}".format(nb_checked_adv - flag_adv))
Adversarial test data (first 100 images): Flagged: 100 Not flagged: 0
Apply the detector to the first 100 original test images:
# Run the detector on (up to) the first 100 clean test images; the second
# element of the returned pair holds the per-sample flags.
# NOTE(review): the recorded run reports all 100 clean images as flagged,
# i.e. the detector does not separate clean from adversarial inputs there --
# investigate before relying on these numbers.
_, is_adversarial = detector.detect(x_test[:100])
flag_original = np.sum(is_adversarial)

# Derive the sample count from the detector output instead of a literal 100,
# so "Not flagged" stays right if fewer than 100 test images are loaded.
nb_checked_original = len(is_adversarial)
print("Original test data (first {} images):".format(nb_checked_original))
print("Flagged: {}".format(flag_original))
print("Not flagged: {}".format(nb_checked_original - flag_original))
Original test data (first 100 images): Flagged: 100 Not flagged: 0
Evaluate the detector for different attack strengths eps.
(Note: eps=0.05 was used to train the detector.)
# Attack strengths to sweep.
eps_range = [0.01, 0.02, 0.03, 0.04, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
nb_flag_adv = []
nb_missclass = []

# True class ids of the evaluated test images (loop-invariant, so computed once).
true_labels = np.argmax(y_test[:100], axis=1)

# For each eps: regenerate the adversarial samples, then count how many the
# detector flags and how many the classifier gets wrong.
# Note: this leaves `attacker` configured with the last eps in the list and
# overwrites `x_test_adv` with the samples for that eps.
for eps in eps_range:
    attacker.set_params(eps=eps)
    x_test_adv = attacker.generate(x_test[:100])
    nb_flag_adv.append(np.sum(detector.detect(x_test_adv)[1]))
    adv_labels = np.argmax(classifier.predict(x_test_adv), axis=1)
    nb_missclass.append(np.sum(adv_labels != true_labels))

# Prepend the eps=0 baseline (unperturbed images). The classifier error count
# is computed here rather than hard-coded from an earlier recorded output, so
# it stays correct if the model or the data subset changes.
clean_labels = np.argmax(classifier.predict(x_test[:100]), axis=1)
eps_range = [0] + eps_range
nb_flag_adv = [flag_original] + nb_flag_adv
nb_missclass = [np.sum(clean_labels != true_labels)] + nb_missclass

# Plot only the first 8 points (eps from 0 up to 0.2).
nb_plotted = 8
fig, ax = plt.subplots()
ax.plot(np.array(eps_range)[:nb_plotted], np.array(nb_flag_adv)[:nb_plotted], 'b--', label='Detector flags')
ax.plot(np.array(eps_range)[:nb_plotted], np.array(nb_missclass)[:nb_plotted], 'r--', label='Classifier errors')

legend = ax.legend(loc='center right', shadow=True, fontsize='large')
legend.get_frame().set_facecolor('#00FFCC')

plt.xlabel('Attack strength (eps)')
plt.ylabel('Per 100 adversarial samples')
plt.show()