import os
import sys
from keras.models import model_from_json
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
from aix360.algorithms.contrastive import CEMExplainer, KerasClassifier
from aix360.datasets import MNISTDataset
Using TensorFlow backend.
# load MNIST data and normalize it in the range [-0.5, 0.5]
data = MNISTDataset()
# print the shape of train and test data
print("MNIST train data range :", "(", np.min(data.train_data), ",", np.max(data.train_data), ")")
print("MNIST test data range :", "(", np.min(data.train_data), ",", np.max(data.train_data), ")")
print("MNIST train data shape :", data.train_data.shape)
print("MNIST test data shape :", data.test_data.shape)
print("MNIST train labels shape:", data.test_labels.shape)
print("MNIST test labels shape :", data.test_labels.shape)
MNIST train data range : ( -0.5 , 0.5 ) MNIST test data range : ( -0.5 , 0.5 ) MNIST train data shape : (55000, 28, 28, 1) MNIST test data shape : (10000, 28, 28, 1) MNIST train labels shape: (10000, 10) MNIST test labels shape : (10000, 10)
# path to mnist related models
model_path = '../../aix360/models/CEM'
def load_model(model_json_file, model_wt_file):
# read model json file
with open(model_json_file, 'r') as f:
model = model_from_json(f.read())
# read model weights file
model.load_weights(model_wt_file)
return model
# load MNIST model using its json and wt files
mnist_model = load_model(os.path.join(model_path, 'mnist.json'), os.path.join(model_path, 'mnist'))
# print model summary
mnist_model.summary()
WARNING: Logging before flag parsing goes to stderr. W0803 10:50:01.064849 140678942381888 deprecation_wrapper.py:119] From /home/vijay/anaconda3/envs/aix360/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead. W0803 10:50:01.068910 140678942381888 deprecation_wrapper.py:119] From /home/vijay/anaconda3/envs/aix360/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead. W0803 10:50:01.106316 140678942381888 deprecation_wrapper.py:119] From /home/vijay/anaconda3/envs/aix360/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:3976: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead. W0803 10:50:01.249321 140678942381888 deprecation_wrapper.py:119] From /home/vijay/anaconda3/envs/aix360/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:174: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead. W0803 10:50:01.250549 140678942381888 deprecation_wrapper.py:119] From /home/vijay/anaconda3/envs/aix360/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:181: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead. W0803 10:50:01.251479 140678942381888 deprecation_wrapper.py:119] From /home/vijay/anaconda3/envs/aix360/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:186: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d_5 (Conv2D) (None, 26, 26, 32) 320 _________________________________________________________________ activation_7 (Activation) (None, 26, 26, 32) 0 _________________________________________________________________ conv2d_6 (Conv2D) (None, 24, 24, 32) 9248 _________________________________________________________________ activation_8 (Activation) (None, 24, 24, 32) 0 _________________________________________________________________ max_pooling2d_3 (MaxPooling2 (None, 12, 12, 32) 0 _________________________________________________________________ conv2d_7 (Conv2D) (None, 10, 10, 64) 18496 _________________________________________________________________ activation_9 (Activation) (None, 10, 10, 64) 0 _________________________________________________________________ conv2d_8 (Conv2D) (None, 8, 8, 64) 36928 _________________________________________________________________ activation_10 (Activation) (None, 8, 8, 64) 0 _________________________________________________________________ max_pooling2d_4 (MaxPooling2 (None, 4, 4, 64) 0 _________________________________________________________________ flatten_2 (Flatten) (None, 1024) 0 _________________________________________________________________ dense_4 (Dense) (None, 200) 205000 _________________________________________________________________ activation_11 (Activation) (None, 200) 0 _________________________________________________________________ dense_5 (Dense) (None, 200) 40200 _________________________________________________________________ activation_12 (Activation) (None, 200) 0 _________________________________________________________________ dense_6 (Dense) (None, 10) 2010 ================================================================= Total params: 312,202 Trainable params: 312,202 Non-trainable params: 0 _________________________________________________________________
# load the trained convolutional autoencoder model
ae_model = load_model(os.path.join(model_path, 'mnist_AE_1_decoder.json'),
os.path.join(model_path, 'mnist_AE_1_decoder.h5'))
# print model summary
ae_model.summary()
W0803 10:50:01.503559 140678942381888 deprecation_wrapper.py:119] From /home/vijay/anaconda3/envs/aix360/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:2018: The name tf.image.resize_nearest_neighbor is deprecated. Please use tf.compat.v1.image.resize_nearest_neighbor instead.
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= sequential_2 (Sequential) (None, 14, 14, 1) 2625 _________________________________________________________________ conv2d_4 (Conv2D) (None, 14, 14, 16) 160 _________________________________________________________________ activation_3 (Activation) (None, 14, 14, 16) 0 _________________________________________________________________ up_sampling2d_1 (UpSampling2 (None, 28, 28, 16) 0 _________________________________________________________________ conv2d_5 (Conv2D) (None, 28, 28, 16) 2320 _________________________________________________________________ activation_4 (Activation) (None, 28, 28, 16) 0 _________________________________________________________________ conv2d_6 (Conv2D) (None, 28, 28, 1) 145 ================================================================= Total params: 5,250 Trainable params: 5,250 Non-trainable params: 0 _________________________________________________________________
# wrap mnist_model into a framework independent class structure
mymodel = KerasClassifier(mnist_model)
# initialize explainer object
explainer = CEMExplainer(mymodel)
# choose an input image
image_id = 340
input_image = data.test_data[image_id]
# rescale values from [-0.5, 0.5] to [0, 255] for plotting
plt.imshow((input_image[:,:,0] + 0.5)*255, cmap="gray")
# check model prediction
print("Predicted class:", mymodel.predict_classes(np.expand_dims(input_image, axis=0)))
print("Predicted logits:", mymodel.predict(np.expand_dims(input_image, axis=0)))
Predicted class: [3] Predicted logits: [[-11.279337 0.7362492 -9.008647 19.396713 -8.286124 14.442825 -1.3170446 -11.587322 -0.9921855 1.0182219]]
Although the above image is classified as digit 3 by the model, it could have been classified as digit 5 as well since it has similarities to the digit 5. We now employ the CEMExplainer from AIX360 to compute pertinent positive and pertinent negative explanations, which help us understand why the image was classified as digit 3 by the model and not as digit 5.
arg_mode = "PN" # Find pertinent negative
arg_max_iter = 1000 # Maximum number of iterations to search for the optimal PN for given parameter settings
arg_init_const = 10.0 # Initial coefficient value for main loss term that encourages class change
arg_b = 9 # No. of updates to the coefficient of the main loss term
arg_kappa = 10 # Minimum confidence gap between the PNs (changed) class probability and original class' probability
arg_beta = 1e-1 # Controls sparsity of the solution (L1 loss)
arg_gamma = 100 # Controls how much to adhere to a (optionally trained) autoencoder
(adv_pn, delta_pn, info_pn) = explainer.explain_instance(np.expand_dims(input_image, axis=0), arg_mode, ae_model, arg_kappa, arg_b,
arg_max_iter, arg_init_const, arg_beta, arg_gamma)
W0803 10:50:02.457590 140678942381888 deprecation_wrapper.py:119] From /home/vijay/anaconda3/envs/aix360/lib/python3.6/site-packages/aix360/algorithms/contrastive/CEM_aen.py:150: The name tf.train.polynomial_decay is deprecated. Please use tf.compat.v1.train.polynomial_decay instead. W0803 10:50:02.467044 140678942381888 deprecation.py:323] From /home/vijay/anaconda3/envs/aix360/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py:409: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version. Instructions for updating: Deprecated in favor of operator or tf.math.divide. W0803 10:50:02.474853 140678942381888 deprecation_wrapper.py:119] From /home/vijay/anaconda3/envs/aix360/lib/python3.6/site-packages/aix360/algorithms/contrastive/CEM_aen.py:153: The name tf.train.GradientDescentOptimizer is deprecated. Please use tf.compat.v1.train.GradientDescentOptimizer instead. W0803 10:50:02.534118 140678942381888 deprecation.py:323] From /home/vijay/anaconda3/envs/aix360/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version. Instructions for updating: Use tf.where in 2.0, which has the same broadcast rule as np.where
iter:0 const:[10.] Loss_Overall:117.3156, Loss_Attack:65.3037 Loss_L2Dist:1.1484, Loss_L1Dist:6.2333, AE_loss:50.240135192871094 target_lab_score:15.8465, max_nontarget_lab_score:19.3161 iter:500 const:[10.] Loss_Overall:5731.0386, Loss_Attack:0.0000 Loss_L2Dist:295.1370, Loss_L1Dist:372.2983, AE_loss:5398.67138671875 target_lab_score:-4.0691, max_nontarget_lab_score:9.3513 iter:0 const:[5.] Loss_Overall:100.2937, Loss_Attack:59.5759 Loss_L2Dist:0.6385, Loss_L1Dist:4.1246, AE_loss:39.666893005371094 target_lab_score:18.3216, max_nontarget_lab_score:16.4064 iter:500 const:[5.] Loss_Overall:6569.7109, Loss_Attack:16.6774 Loss_L2Dist:305.4800, Loss_L1Dist:381.6533, AE_loss:6209.38818359375 target_lab_score:1.2297, max_nontarget_lab_score:7.8942 iter:0 const:[2.5] Loss_Overall:74.2821, Loss_Attack:33.0332 Loss_L2Dist:0.5359, Loss_L1Dist:3.6111, AE_loss:40.35193634033203 target_lab_score:18.9114, max_nontarget_lab_score:15.6981 iter:500 const:[2.5] Loss_Overall:6238.6172, Loss_Attack:23.4504 Loss_L2Dist:281.3679, Loss_L1Dist:360.9008, AE_loss:5897.708984375 target_lab_score:2.0889, max_nontarget_lab_score:2.7088 iter:0 const:[1.25] Loss_Overall:59.0087, Loss_Attack:17.0873 Loss_L2Dist:0.5046, Loss_L1Dist:3.5153, AE_loss:41.065223693847656 target_lab_score:19.1260, max_nontarget_lab_score:15.4562 iter:500 const:[1.25] Loss_Overall:6604.8657, Loss_Attack:14.3013 Loss_L2Dist:281.1414, Loss_L1Dist:361.9403, AE_loss:6273.22900390625 target_lab_score:4.8682, max_nontarget_lab_score:3.4272 iter:0 const:[0.625] Loss_Overall:51.0226, Loss_Attack:8.6989 Loss_L2Dist:0.4934, Loss_L1Dist:3.4979, AE_loss:41.48050308227539 target_lab_score:19.2539, max_nontarget_lab_score:15.3357 iter:500 const:[0.625] Loss_Overall:5880.4536, Loss_Attack:4.2246 Loss_L2Dist:307.5166, Loss_L1Dist:392.9896, AE_loss:5529.41357421875 target_lab_score:0.8791, max_nontarget_lab_score:4.1198 iter:0 const:[0.3125] Loss_Overall:47.0586, Loss_Attack:4.3939 Loss_L2Dist:0.4888, Loss_L1Dist:3.4828, AE_loss:41.827728271484375 target_lab_score:19.3268, max_nontarget_lab_score:15.2665 iter:500 const:[0.3125] Loss_Overall:6150.1323, Loss_Attack:2.1761 Loss_L2Dist:291.2777, Loss_L1Dist:368.6633, AE_loss:5819.8125 target_lab_score:0.9448, max_nontarget_lab_score:3.9814 iter:0 const:[0.15625] Loss_Overall:45.0654, Loss_Attack:2.2082 Loss_L2Dist:0.4868, Loss_L1Dist:3.4763, AE_loss:42.02281951904297 target_lab_score:19.3643, max_nontarget_lab_score:15.2319 iter:500 const:[0.15625] Loss_Overall:6265.9688, Loss_Attack:0.4966 Loss_L2Dist:286.9196, Loss_L1Dist:359.0903, AE_loss:5942.6435546875 target_lab_score:-0.7716, max_nontarget_lab_score:6.0504 iter:0 const:[0.078125] Loss_Overall:44.0672, Loss_Attack:1.1069 Loss_L2Dist:0.4858, Loss_L1Dist:3.4732, AE_loss:42.12709426879883 target_lab_score:19.3832, max_nontarget_lab_score:15.2146 iter:500 const:[0.078125] Loss_Overall:6456.3413, Loss_Attack:0.7018 Loss_L2Dist:317.7498, Loss_L1Dist:398.9077, AE_loss:6097.9990234375 target_lab_score:1.5749, max_nontarget_lab_score:2.5923 iter:0 const:[0.0390625] Loss_Overall:43.5620, Loss_Attack:0.5541 Loss_L2Dist:0.4854, Loss_L1Dist:3.4722, AE_loss:42.17523956298828 target_lab_score:19.3921, max_nontarget_lab_score:15.2067 iter:500 const:[0.0390625] Loss_Overall:6328.7046, Loss_Attack:0.3000 Loss_L2Dist:294.4442, Loss_L1Dist:373.3748, AE_loss:5996.623046875 target_lab_score:4.6621, max_nontarget_lab_score:6.9825
print(info_pn)
[INFO]kappa:10, Orig class:3, Perturbed class:5, Delta class: 8, Orig prob:[[-11.279337 0.7362492 -9.008647 19.396713 -8.286124 14.442825 -1.3170446 -11.587322 -0.9921855 1.0182219]], Perturbed prob:[[ -9.823904 -3.0258849 -12.282337 13.914494 -9.164177 24.124342 0.3995837 -12.929418 -0.087002 2.4819136]], Delta prob:[[-0.50663096 -0.84834933 0.7846163 0.59887445 -1.0884676 -0.28783607 -2.5196645 1.3976169 2.129994 0.20926595]]
arg_mode = "PP" # Find pertinent positive
(adv_pp, delta_pp, info_pp) = explainer.explain_instance(np.expand_dims(input_image, axis=0), arg_mode, ae_model, arg_kappa, arg_b,
arg_max_iter, arg_init_const, arg_beta, arg_gamma)
iter:0 const:[10.] Loss_Overall:1271.8158, Loss_Attack:105.5809 Loss_L2Dist:0.0000, Loss_L1Dist:0.0000, AE_loss:1166.23486328125 target_lab_score:-0.1728, max_nontarget_lab_score:0.3853 iter:500 const:[10.] Loss_Overall:1271.8158, Loss_Attack:105.5809 Loss_L2Dist:0.0000, Loss_L1Dist:0.0000, AE_loss:1166.23486328125 target_lab_score:-0.1728, max_nontarget_lab_score:0.3853 iter:0 const:[100.] Loss_Overall:2112.5354, Loss_Attack:859.8530 Loss_L2Dist:3.8316, Loss_L1Dist:8.9619, AE_loss:1247.9547119140625 target_lab_score:2.8555, max_nontarget_lab_score:1.4540 iter:500 const:[100.] Loss_Overall:2813.5281, Loss_Attack:817.9916 Loss_L2Dist:21.7097, Loss_L1Dist:32.3632, AE_loss:1970.5904541015625 target_lab_score:6.2211, max_nontarget_lab_score:4.4010 iter:0 const:[55.] Loss_Overall:1771.9391, Loss_Attack:612.6125 Loss_L2Dist:0.6849, Loss_L1Dist:2.7694, AE_loss:1158.36474609375 target_lab_score:0.2389, max_nontarget_lab_score:1.3773 iter:500 const:[55.] Loss_Overall:2663.4966, Loss_Attack:633.3018 Loss_L2Dist:20.6486, Loss_L1Dist:31.4669, AE_loss:2006.3995361328125 target_lab_score:5.9188, max_nontarget_lab_score:7.4334 iter:0 const:[32.5] Loss_Overall:1496.7809, Loss_Attack:345.4348 Loss_L2Dist:0.0753, Loss_L1Dist:0.5746, AE_loss:1151.2132568359375 target_lab_score:0.0645, max_nontarget_lab_score:0.6933 iter:500 const:[32.5] Loss_Overall:1815.2236, Loss_Attack:264.7164 Loss_L2Dist:12.4209, Loss_L1Dist:25.7885, AE_loss:1535.5074462890625 target_lab_score:8.3142, max_nontarget_lab_score:6.4593 iter:0 const:[21.25] Loss_Overall:1387.5613, Loss_Attack:224.0321 Loss_L2Dist:0.0015, Loss_L1Dist:0.0518, AE_loss:1163.5225830078125 target_lab_score:-0.1502, max_nontarget_lab_score:0.3925 iter:500 const:[21.25] Loss_Overall:1555.0825, Loss_Attack:74.3874 Loss_L2Dist:9.9007, Loss_L1Dist:19.1401, AE_loss:1468.88037109375 target_lab_score:9.0457, max_nontarget_lab_score:2.5462 iter:0 const:[15.625] Loss_Overall:1331.2051, Loss_Attack:164.9702 Loss_L2Dist:0.0000, Loss_L1Dist:0.0000, AE_loss:1166.23486328125 target_lab_score:-0.1728, max_nontarget_lab_score:0.3853 iter:500 const:[15.625] Loss_Overall:1331.2051, Loss_Attack:164.9702 Loss_L2Dist:0.0000, Loss_L1Dist:0.0000, AE_loss:1166.23486328125 target_lab_score:-0.1728, max_nontarget_lab_score:0.3853 iter:0 const:[18.4375] Loss_Overall:1360.8997, Loss_Attack:194.6648 Loss_L2Dist:0.0000, Loss_L1Dist:0.0000, AE_loss:1166.23486328125 target_lab_score:-0.1728, max_nontarget_lab_score:0.3853 iter:500 const:[18.4375] Loss_Overall:1360.8997, Loss_Attack:194.6648 Loss_L2Dist:0.0000, Loss_L1Dist:0.0000, AE_loss:1166.23486328125 target_lab_score:-0.1728, max_nontarget_lab_score:0.3853 iter:0 const:[19.84375] Loss_Overall:1374.3131, Loss_Attack:209.2894 Loss_L2Dist:0.0003, Loss_L1Dist:0.0183, AE_loss:1165.021484375 target_lab_score:-0.1648, max_nontarget_lab_score:0.3821 iter:500 const:[19.84375] Loss_Overall:1758.1250, Loss_Attack:171.5620 Loss_L2Dist:8.9867, Loss_L1Dist:16.8015, AE_loss:1575.8961181640625 target_lab_score:6.1731, max_nontarget_lab_score:4.8187 iter:0 const:[19.140625] Loss_Overall:1367.6649, Loss_Attack:202.0110 Loss_L2Dist:0.0001, Loss_L1Dist:0.0081, AE_loss:1165.653076171875 target_lab_score:-0.1694, max_nontarget_lab_score:0.3846 iter:500 const:[19.140625] Loss_Overall:1673.5901, Loss_Attack:135.8866 Loss_L2Dist:7.5158, Loss_L1Dist:15.7103, AE_loss:1528.61669921875 target_lab_score:6.8402, max_nontarget_lab_score:3.9396
print(info_pp)
[INFO]kappa:10, Orig class:3, Perturbed class:5, Delta class: 3, Orig prob:[[-11.279337 0.7362492 -9.008647 19.396713 -8.286124 14.442825 -1.3170446 -11.587322 -0.9921855 1.0182219]], Perturbed prob:[[-10.512523 2.5331008 -7.8333354 13.738868 -4.697951 13.911682 -0.17094916 -9.634063 -2.869906 -0.1969316 ]], Delta prob:[[-7.0725145 -1.732577 0.35496587 12.47732 -4.7352133 2.2513726 -5.1536884 0.21767989 1.7942489 -1.3421019 ]]
# rescale values from [-0.5, 0.5] to [0, 255] for plotting
fig0 = (input_image[:,:,0] + 0.5)*255
fig1 = (adv_pn[0,:,:,0] + 0.5) * 255
fig2 = (fig1 - fig0) #rescaled delta_pn
fig3 = (adv_pp[0,:,:,0] + 0.5) * 255
fig4 = (fig0 - fig3) #rescaled delta_pp
f, axarr = plt.subplots(1, 5, figsize=(10,10))
axarr[0].set_title("Original" + "(" + str(mymodel.predict_classes(np.expand_dims(input_image, axis=0))[0]) + ")")
axarr[1].set_title("Original + PN" + "(" + str(mymodel.predict_classes(adv_pn)[0]) + ")")
axarr[2].set_title("PN")
axarr[3].set_title("Original + PP")
axarr[4].set_title("PP" + "(" + str(mymodel.predict_classes(delta_pp)[0]) + ")")
axarr[0].imshow(fig0, cmap="gray")
axarr[1].imshow(fig1, cmap="gray")
axarr[2].imshow(fig2, cmap="gray")
axarr[3].imshow(fig3, cmap="gray")
axarr[4].imshow(fig4, cmap="gray")
plt.show()