import warnings
# Silence Python-level warnings for a cleaner notebook.
# NOTE(review): TensorFlow's own deprecation messages are still printed
# (they go through its logger), as the captured cell output shows.
warnings.filterwarnings('ignore')
import random  # NOTE(review): not used in any visible cell
import numpy as np
from matplotlib import pyplot as plt
# Default figure size (inches) for every plot in this notebook.
plt.rcParams['figure.figsize'] = [10, 10]
import imagenet_stubs
from imagenet_stubs.imagenet_2012_labels import name_to_label
import tensorflow as tf
# TF1 graph mode: an InteractiveSession registers itself as the default
# session; it is also passed explicitly to the ART classifier.
sess = tf.InteractiveSession()
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions  # preprocess_input is unused in the visible cells
from tensorflow.keras.preprocessing import image
from art.estimators.classification import TensorFlowClassifier
from art.attacks.evasion import AdversarialPatch
# ---- Target and model geometry -------------------------------------------
target_name = 'toaster'        # ImageNet class the patch should force
image_shape = (224, 224, 3)    # model input: (height, width, channels)
clip_values = (0, 255)         # valid range of the raw pixel values
nb_classes = 1000              # size of the ImageNet label space

# ---- AdversarialPatch optimisation settings -----------------------------
batch_size = 16
scale_min, scale_max = 0.4, 1.0   # bounds on the random patch scale
rotation_max = 22.5               # max random rotation of the patch
learning_rate = 5000.
max_iter = 500
%%capture
_image_input = tf.keras.Input(shape=image_shape)
_target_ys = tf.placeholder(tf.float32, shape=(None, nb_classes))
model = tf.keras.applications.resnet50.ResNet50(input_tensor=_image_input, weights='imagenet')
_logits = model.outputs[0].op.inputs[0]
target_loss = tf.nn.softmax_cross_entropy_with_logits(labels=_target_ys, logits=_logits)
mean_b = 103.939
mean_g = 116.779
mean_r = 123.680
tfc = TensorFlowClassifier(clip_values=clip_values, input_ph=_image_input, labels_ph=_target_ys,
output=_logits, sess=sess, loss=target_loss,
preprocessing=([mean_b, mean_g, mean_r], 1))
WARNING:tensorflow:From /home/beat/codes/anaconda3/envs/TF1_15/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version. Instructions for updating: If using Keras pass *_constraint arguments to layers. WARNING:tensorflow:From <ipython-input-3-bf2dd6c9f1b3>:5: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version. Instructions for updating: Future major versions of TensorFlow will allow gradients to flow into the labels input on backprop by default. See `tf.nn.softmax_cross_entropy_with_logits_v2`. WARNING:tensorflow:From /home/beat/codes/anaconda3/envs/TF1_15/lib/python3.6/site-packages/art/estimators/classification/tensorflow.py:399: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.
# Load the bundled ImageNet sample images into a single float32 BGR batch of
# shape (N, H, W, 3), resized to the model's input resolution.
images_list = list()
for image_path in imagenet_stubs.get_image_paths():
    # Derive the resize target from image_shape instead of hard-coding it,
    # so the loader stays consistent with the model input definition.
    im = image.load_img(image_path, target_size=image_shape[:2])
    im = image.img_to_array(im)
    # RGB to BGR: the classifier subtracts BGR-ordered channel means.
    im = im[:, :, ::-1].astype(np.float32)
    im = np.expand_dims(im, axis=0)  # add batch axis -> (1, H, W, 3)
    images_list.append(im)
images = np.vstack(images_list)
def bgr_to_rgb(x):
    """Reverse the channel order of a channels-last image array.

    Swaps BGR <-> RGB; the operation is its own inverse. Only the last axis
    is reversed, so this works for a single (H, W, C) image as well as a
    batch (N, H, W, C). The original version indexed ``x[:, :, ::-1]``,
    which flipped the width axis of a batch instead of its channels.

    Args:
        x: NumPy array whose last axis holds the colour channels.

    Returns:
        A NumPy view of ``x`` with the channel axis reversed. The input is
        not copied or modified.
    """
    return x[..., ::-1]
# Configure the adversarial-patch attack against the wrapped ResNet50.
ap = AdversarialPatch(
    classifier=tfc,
    rotation_max=rotation_max,
    scale_min=scale_min,
    scale_max=scale_max,
    learning_rate=learning_rate,
    max_iter=max_iter,
    batch_size=batch_size,
)

# Every image gets the same one-hot target: the class the patch should force.
label = name_to_label(target_name)
y_one_hot = np.zeros(nb_classes)
y_one_hot[label] = 1.0
y_target = np.repeat(y_one_hot[np.newaxis, :], len(images), axis=0)

patch, patch_mask = ap.generate(x=images, y=y_target)

# Show only the masked patch region, converted back to RGB for matplotlib.
masked_patch = bgr_to_rgb(patch) * patch_mask
plt.imshow(masked_patch.astype(np.uint8))
<matplotlib.image.AxesImage at 0x7f55258e5550>
# Paste the trained patch onto every image in the batch; scale=0.5 sets the
# patch size relative to the image (per ART's apply_patch `scale` argument).
patched_images = ap.apply_patch(images, scale=0.5)
def predict_model(classifier, image, top=5, apply_softmax=True):
    """Show ``image`` and print the classifier's top-``top`` ImageNet classes.

    Args:
        classifier: object with a ``predict(batch)`` method returning a
            (batch, 1000) array of class scores accepted by
            ``decode_predictions``.
        image: a single BGR image of shape (H, W, 3) in 0-255 pixel units.
            The name shadows the Keras ``image`` module inside this function
            only; it is kept so existing keyword callers keep working.
        top: number of classes to print (default 5).
        apply_softmax: if True (default), treat the classifier output as
            logits and convert it to probabilities before printing. The
            previous version printed raw logits under the name
            "probability". Pass False if the classifier already returns
            probabilities.

    Returns:
        None. Output is a matplotlib figure plus printed text.
    """
    plt.imshow(bgr_to_rgb(image.astype(np.uint8)))
    plt.show()

    # predict() expects a batch. Copy defensively so the classifier cannot
    # modify the caller's array through the expand_dims view.
    batch = np.expand_dims(np.copy(image), axis=0)
    scores = classifier.predict(batch)

    if apply_softmax:
        # Numerically stable softmax: subtracting the row max prevents exp()
        # overflow and leaves the result unchanged. Softmax is monotonic, so
        # the top-k ranking is the same as for the raw logits.
        shifted = scores - np.max(scores, axis=-1, keepdims=True)
        exp_scores = np.exp(shifted)
        scores = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)

    # Each entry is a (class_id, class_name, score) tuple.
    decoded = decode_predictions(scores, top=top)[0]
    print('Predictions:')
    if not decoded:  # e.g. top=0: nothing to align or print
        return
    width = max(len(name) for _, name, _ in decoded)
    for _, name, score in decoded:
        print("{} {:.4f}".format(name.ljust(width), score))
predict_model(classifier=tfc, image=patched_images[0])  # first patched sample
Predictions: toaster 38.39 bagel 28.03 piggy_bank 17.30 bakery 15.63 pretzel 14.71
predict_model(classifier=tfc, image=patched_images[1])  # second patched sample
Predictions: toaster 24.77 beagle 17.96 Walker_hound 13.95 English_foxhound 12.82 piggy_bank 12.43
predict_model(classifier=tfc, image=patched_images[2])  # third patched sample
Predictions: toaster 30.13 piggy_bank 13.52 pencil_sharpener 10.79 radio 9.54 teapot 9.39