CEM_MAFImageExplainer from AIX360 can be used to obtain contrastive explanations on data that have pre-defined high-level attributes, such as facial images that are annotated with features such as smile, high cheekbones, makeup, etc.
The goal of this tutorial is to demonstrate the use of CEM_MAFImageExplainer, which offers two-part explanations based on a pertinent positive and a pertinent negative. The pertinent positive explanation outputs the minimal set of high-level features that must be present in order for the classification of a sample to remain the same, so that if any one of the output features was missing from the sample, the classification would be different. The pertinent negative explanation outputs a set of features that would cause a change to the classification if they were added to the sample.
import tensorflow as tf
import sys
import os
import numpy as np
import random
import matplotlib.pyplot as plt
from zipfile import ZipFile
from aix360.algorithms.contrastive import CEM_MAFImageExplainer
from aix360.algorithms.contrastive import CELEBAModel
from aix360.algorithms.contrastive import KerasClassifier
from aix360.algorithms.contrastive.dwnld_CEM_MAF_celebA import dwnld_CEM_MAF_celebA
from aix360.datasets.celeba_dataset import CelebADataset
# Helper that downloads the pretrained CelebA model and data files used below
dwnld = dwnld_CEM_MAF_celebA()
# TF1-style interactive session shared by the model and the explainer
sess = tf.InteractiveSession()
# Fix seeds so the optimization-based explanations are reproducible
random.seed(120)
np.random.seed(1210)
sess.run(tf.global_variables_initializer())
# Download pretrained celebA model
local_path_models = '../../aix360/models/CEM_MAF'
celebA_model_file = dwnld.dwnld_celebA_model(local_path_models)
celebA model file downloaded: ['../../aix360/models/CEM_MAF/celebA']
# Load the downloaded celebA model
# use_softmax=False: the explainer operates on raw scores, not probabilities
model_file = '../../aix360/models/CEM_MAF/celebA'
loaded_model = CELEBAModel(restore=model_file, use_softmax=False).model
Load: ../../aix360/models/CEM_MAF/celebA
# Wrap the Keras model so the explainer can query predictions uniformly
mymodel = KerasClassifier(loaded_model)
# Download the sample image (plus its latent representation) to be explained
img_id = 15
local_path_img = '../../aix360/data/celeba_data'
img_files = dwnld.dwnld_celebA_data(local_path_img, [img_id])
Image files downloaded: ['../../aix360/data/celeba_data/15_img.npy', '../../aix360/data/celeba_data/15_img.png', '../../aix360/data/celeba_data/15_latent.npy']
dataset_obj = CelebADataset(local_path_img) # use the CelebA dataset class
input_img = dataset_obj.get_img(img_id)
# Latent representation of the image; used by the pertinent-negative search
input_latent = dataset_obj.get_latent(img_id)
# images are processed according to needs for model being explained
# (pixel values rescaled and clipped into [-0.5, 0.5])
input_img = np.clip(input_img/2, -0.5, 0.5)
plt.axis("off")
# shift back to [0, 1] for display
plt.imshow(input_img[0,:,:,:]+0.5)
plt.show()
# predict_long yields raw scores, the predicted class index, and a string form
orig_prob, orig_class, orig_prob_str = mymodel.predict_long(input_img)
# Decode the 3-bit class index into its (Male, Smile, Young) attribute bits.
young_flag = orig_class & 1
smile_flag = (orig_class >> 1) & 1
sex_flag = (orig_class >> 2) & 1
arg_img_name = os.path.join(local_path_img, "{}_img.png".format(img_id))
print("Image:{}, pred:{}".format(arg_img_name, orig_class))
print("Male:{}, Smile:{}, Young:{}".format(sex_flag, smile_flag, young_flag))
# Keep the original image and a one-hot target label for later use
orig_img = input_img
identity = np.eye(mymodel._nb_classes)
target_label = [identity[orig_class]]
Image:../../aix360/data/celeba_data/15_img.png, pred:4 Male:1, Smile:0, Young:0
# High-level facial attributes over which the explanations are expressed
attributes = ["Black_Hair", "Blond_Hair", "Brown_Hair", "Gray_Hair", "Wearing_Lipstick", "Heavy_Makeup",\
"High_Cheekbones", "Bangs", "Oval_Face", "Narrow_Eyes", "Bags_Under_Eyes", "Pointy_Nose"]
# Download attribute functions
# (one pretrained model per attribute: json architecture, h5 weights, ckpt)
attr_model_files = dwnld.dwnld_celebA_attributes(local_path_models, attributes)
Attribute files downloaded: ['../../aix360/models/CEM_MAF/simple_Black_Hair_model.json', '../../aix360/models/CEM_MAF/simple_Black_Hair_weights.h5', '../../aix360/models/CEM_MAF/simple_Black_Hair.ckpt', '../../aix360/models/CEM_MAF/simple_Blond_Hair.ckpt', '../../aix360/models/CEM_MAF/simple_Blond_Hair_weights.h5', '../../aix360/models/CEM_MAF/simple_Blond_Hair_model.json', '../../aix360/models/CEM_MAF/simple_Brown_Hair_weights.h5', '../../aix360/models/CEM_MAF/simple_Brown_Hair_model.json', '../../aix360/models/CEM_MAF/simple_Brown_Hair.ckpt', '../../aix360/models/CEM_MAF/simple_Gray_Hair.ckpt', '../../aix360/models/CEM_MAF/simple_Gray_Hair_weights.h5', '../../aix360/models/CEM_MAF/simple_Gray_Hair_model.json', '../../aix360/models/CEM_MAF/simple_Wearing_Lipstick_model.json', '../../aix360/models/CEM_MAF/simple_Wearing_Lipstick_weights.h5', '../../aix360/models/CEM_MAF/simple_Wearing_Lipstick.ckpt', '../../aix360/models/CEM_MAF/simple_Heavy_Makeup.ckpt', '../../aix360/models/CEM_MAF/simple_Heavy_Makeup_model.json', '../../aix360/models/CEM_MAF/simple_Heavy_Makeup_weights.h5', '../../aix360/models/CEM_MAF/simple_High_Cheekbones_weights.h5', '../../aix360/models/CEM_MAF/simple_High_Cheekbones.ckpt', '../../aix360/models/CEM_MAF/simple_High_Cheekbones_model.json', '../../aix360/models/CEM_MAF/simple_Bangs_model.json', '../../aix360/models/CEM_MAF/simple_Bangs_weights.h5', '../../aix360/models/CEM_MAF/simple_Bangs.ckpt', '../../aix360/models/CEM_MAF/simple_Oval_Face_weights.h5', '../../aix360/models/CEM_MAF/simple_Oval_Face.ckpt', '../../aix360/models/CEM_MAF/simple_Oval_Face_model.json', '../../aix360/models/CEM_MAF/simple_Narrow_Eyes_model.json', '../../aix360/models/CEM_MAF/simple_Narrow_Eyes.ckpt', '../../aix360/models/CEM_MAF/simple_Narrow_Eyes_weights.h5', '../../aix360/models/CEM_MAF/simple_Bags_Under_Eyes.ckpt', '../../aix360/models/CEM_MAF/simple_Bags_Under_Eyes_model.json', '../../aix360/models/CEM_MAF/simple_Bags_Under_Eyes_weights.h5', 
'../../aix360/models/CEM_MAF/simple_Pointy_Nose_model.json', '../../aix360/models/CEM_MAF/simple_Pointy_Nose_weights.h5', '../../aix360/models/CEM_MAF/simple_Pointy_Nose.ckpt', '../../aix360/models/CEM_MAF/attr_model/simple_Black_Hair_model.json', '../../aix360/models/CEM_MAF/attr_model/simple_Black_Hair_weights.h5', '../../aix360/models/CEM_MAF/attr_model/simple_Black_Hair.ckpt', '../../aix360/models/CEM_MAF/attr_model/simple_Blond_Hair.ckpt', '../../aix360/models/CEM_MAF/attr_model/simple_Blond_Hair_weights.h5', '../../aix360/models/CEM_MAF/attr_model/simple_Blond_Hair_model.json', '../../aix360/models/CEM_MAF/attr_model/simple_Gray_Hair.ckpt', '../../aix360/models/CEM_MAF/attr_model/simple_Gray_Hair_weights.h5', '../../aix360/models/CEM_MAF/attr_model/simple_Gray_Hair_model.json', '../../aix360/models/CEM_MAF/attr_model/simple_Wearing_Lipstick_model.json', '../../aix360/models/CEM_MAF/attr_model/simple_Wearing_Lipstick_weights.h5', '../../aix360/models/CEM_MAF/attr_model/simple_Wearing_Lipstick.ckpt', '../../aix360/models/CEM_MAF/attr_model/simple_Heavy_Makeup.ckpt', '../../aix360/models/CEM_MAF/attr_model/simple_Heavy_Makeup_model.json', '../../aix360/models/CEM_MAF/attr_model/simple_Heavy_Makeup_weights.h5', '../../aix360/models/CEM_MAF/attr_model/simple_Bangs_model.json', '../../aix360/models/CEM_MAF/attr_model/simple_Bangs_weights.h5', '../../aix360/models/CEM_MAF/attr_model/simple_Bangs.ckpt', '../../aix360/models/CEM_MAF/attr_model/simple_Oval_Face_weights.h5', '../../aix360/models/CEM_MAF/attr_model/simple_Oval_Face.ckpt', '../../aix360/models/CEM_MAF/attr_model/simple_Oval_Face_model.json', '../../aix360/models/CEM_MAF/attr_model/simple_Narrow_Eyes_model.json', '../../aix360/models/CEM_MAF/attr_model/simple_Narrow_Eyes.ckpt', '../../aix360/models/CEM_MAF/attr_model/simple_Narrow_Eyes_weights.h5', '../../aix360/models/CEM_MAF/attr_model/simple_Bags_Under_Eyes.ckpt', '../../aix360/models/CEM_MAF/attr_model/simple_Bags_Under_Eyes_model.json', 
'../../aix360/models/CEM_MAF/attr_model/simple_Bags_Under_Eyes_weights.h5', '../../aix360/models/CEM_MAF/attr_model/simple_Pointy_Nose_model.json', '../../aix360/models/CEM_MAF/attr_model/simple_Pointy_Nose_weights.h5', '../../aix360/models/CEM_MAF/attr_model/simple_Pointy_Nose.ckpt']
aix360_path = '../../aix360' # needed to find paths to attribute files
# Build the CEM-MAF explainer around the classifier and the attribute models
explainer = CEM_MAFImageExplainer(mymodel, attributes, aix360_path)
# parameter values for the pertinent negative
arg_mode = 'PN'  # 'PN' = pertinent negative
arg_kappa = 5  # confidence parameter for the attack-style loss
arg_gamma = 1
arg_binary_search_steps = 1  # outer search steps over the loss constant
arg_max_iterations = 250  # optimization iterations per search step
arg_initial_const = 10  # initial weight on the attack loss term
arg_attr_reg = 100.0  # regularization on attribute changes
arg_attr_penalty_reg = 100.0
arg_latent_square_loss_reg = 1.0  # keeps the PN close in latent space
(adv_pn, attr_pn, info_pn) = explainer.explain_instance(sess, input_img,
input_latent, arg_mode, arg_kappa, arg_binary_search_steps,
arg_max_iterations, arg_initial_const, arg_gamma, None,
arg_attr_reg, arg_attr_penalty_reg,
arg_latent_square_loss_reg)
print(info_pn)
Loaded model for Black_Hair from disk Loaded model for Blond_Hair from disk Loaded model for Brown_Hair from disk Loaded model for Gray_Hair from disk Loaded model for Wearing_Lipstick from disk Loaded model for Heavy_Makeup from disk Loaded model for High_Cheekbones from disk Loaded model for Bangs from disk Loaded model for Oval_Face from disk Loaded model for Narrow_Eyes from disk Loaded model for Bags_Under_Eyes from disk Loaded model for Pointy_Nose from disk # of attr models is 12 iter:0 const:[10.] Loss_Overall:7385.9272, Loss_Attack:0.0000, Loss_attr:1.2358 Loss_Latent_L2Dist:20.1870, Loss_Img_L2Dist:7021.6318 target_lab_score:-2.7435, max_nontarget_lab_score:5.0185 iter:10 const:[10.] Loss_Overall:5924.4990, Loss_Attack:0.0000, Loss_attr:0.8909 Loss_Latent_L2Dist:1159.0074, Loss_Img_L2Dist:4563.6479 target_lab_score:0.6903, max_nontarget_lab_score:6.1033 iter:20 const:[10.] Loss_Overall:3637.8708, Loss_Attack:0.0000, Loss_attr:0.5807 Loss_Latent_L2Dist:789.6542, Loss_Img_L2Dist:2642.2075 target_lab_score:-1.1012, max_nontarget_lab_score:6.5360 iter:30 const:[10.] Loss_Overall:3287.6770, Loss_Attack:0.0000, Loss_attr:0.9050 Loss_Latent_L2Dist:654.8574, Loss_Img_L2Dist:2436.0181 target_lab_score:-2.9968, max_nontarget_lab_score:6.3617 iter:40 const:[10.] Loss_Overall:2919.0195, Loss_Attack:0.0000, Loss_attr:0.8444 Loss_Latent_L2Dist:541.7762, Loss_Img_L2Dist:2199.4700 target_lab_score:-2.4757, max_nontarget_lab_score:5.8496 iter:50 const:[10.] Loss_Overall:3176.9578, Loss_Attack:0.0000, Loss_attr:0.8826 Loss_Latent_L2Dist:456.5536, Loss_Img_L2Dist:2545.3472 target_lab_score:-2.4155, max_nontarget_lab_score:6.0877 iter:60 const:[10.] Loss_Overall:3039.5088, Loss_Attack:0.0000, Loss_attr:0.8963 Loss_Latent_L2Dist:415.3625, Loss_Img_L2Dist:2443.7905 target_lab_score:-2.1841, max_nontarget_lab_score:6.0178 iter:70 const:[10.] 
Loss_Overall:2210.2634, Loss_Attack:0.0000, Loss_attr:0.8182 Loss_Latent_L2Dist:388.4205, Loss_Img_L2Dist:1652.5145 target_lab_score:0.4080, max_nontarget_lab_score:6.5039 iter:80 const:[10.] Loss_Overall:2101.1548, Loss_Attack:0.0000, Loss_attr:0.6353 Loss_Latent_L2Dist:371.8968, Loss_Img_L2Dist:1564.4413 target_lab_score:0.5502, max_nontarget_lab_score:6.1577 iter:90 const:[10.] Loss_Overall:3215.2903, Loss_Attack:0.0000, Loss_attr:0.9105 Loss_Latent_L2Dist:379.6212, Loss_Img_L2Dist:2653.0435 target_lab_score:-3.3713, max_nontarget_lab_score:6.9207 iter:100 const:[10.] Loss_Overall:2545.1436, Loss_Attack:69.2856, Loss_attr:1.0424 Loss_Latent_L2Dist:354.3192, Loss_Img_L2Dist:1930.1128 target_lab_score:4.2894, max_nontarget_lab_score:2.3609 iter:110 const:[10.] Loss_Overall:2443.0845, Loss_Attack:0.0000, Loss_attr:0.7599 Loss_Latent_L2Dist:346.3373, Loss_Img_L2Dist:1924.0270 target_lab_score:-1.3882, max_nontarget_lab_score:7.5135 iter:120 const:[10.] Loss_Overall:2219.9146, Loss_Attack:0.0000, Loss_attr:1.0280 Loss_Latent_L2Dist:340.4952, Loss_Img_L2Dist:1691.4609 target_lab_score:-1.6604, max_nontarget_lab_score:7.8813 iter:130 const:[10.] Loss_Overall:2027.2773, Loss_Attack:0.0000, Loss_attr:0.6175 Loss_Latent_L2Dist:344.2769, Loss_Img_L2Dist:1516.4824 target_lab_score:-1.8672, max_nontarget_lab_score:8.1930 iter:140 const:[10.] Loss_Overall:2343.8447, Loss_Attack:0.0000, Loss_attr:0.3613 Loss_Latent_L2Dist:337.6933, Loss_Img_L2Dist:1823.3728 target_lab_score:-0.9989, max_nontarget_lab_score:6.6355 iter:150 const:[10.] Loss_Overall:2120.2942, Loss_Attack:0.0000, Loss_attr:0.5692 Loss_Latent_L2Dist:339.6035, Loss_Img_L2Dist:1596.5542 target_lab_score:-0.5275, max_nontarget_lab_score:6.9749 iter:160 const:[10.] Loss_Overall:2194.8979, Loss_Attack:0.0000, Loss_attr:0.5159 Loss_Latent_L2Dist:346.0681, Loss_Img_L2Dist:1660.7866 target_lab_score:-0.7260, max_nontarget_lab_score:6.8917 iter:170 const:[10.] 
Loss_Overall:2069.8169, Loss_Attack:0.0000, Loss_attr:0.7354 Loss_Latent_L2Dist:341.7602, Loss_Img_L2Dist:1541.4486 target_lab_score:-0.8644, max_nontarget_lab_score:7.5602 iter:180 const:[10.] Loss_Overall:2006.3484, Loss_Attack:0.0000, Loss_attr:0.5715 Loss_Latent_L2Dist:333.2462, Loss_Img_L2Dist:1493.0411 target_lab_score:-2.0964, max_nontarget_lab_score:8.4303 iter:190 const:[10.] Loss_Overall:2656.1536, Loss_Attack:87.6424, Loss_attr:0.5912 Loss_Latent_L2Dist:316.9724, Loss_Img_L2Dist:2055.6011 target_lab_score:4.6402, max_nontarget_lab_score:0.8760 iter:200 const:[10.] Loss_Overall:1809.2095, Loss_Attack:18.6715, Loss_attr:0.6551 Loss_Latent_L2Dist:311.5544, Loss_Img_L2Dist:1295.4883 target_lab_score:2.3426, max_nontarget_lab_score:5.4755 iter:210 const:[10.] Loss_Overall:2163.3423, Loss_Attack:0.0000, Loss_attr:0.8297 Loss_Latent_L2Dist:309.7028, Loss_Img_L2Dist:1666.4037 target_lab_score:-0.6750, max_nontarget_lab_score:7.2690 iter:220 const:[10.] Loss_Overall:1676.7032, Loss_Attack:0.0000, Loss_attr:0.4670 Loss_Latent_L2Dist:293.3751, Loss_Img_L2Dist:1206.6770 target_lab_score:-1.7596, max_nontarget_lab_score:8.1373 iter:230 const:[10.] Loss_Overall:1574.7423, Loss_Attack:0.0000, Loss_attr:0.6542 Loss_Latent_L2Dist:279.5175, Loss_Img_L2Dist:1113.9025 target_lab_score:0.3418, max_nontarget_lab_score:7.1122 iter:240 const:[10.] Loss_Overall:2083.2930, Loss_Attack:0.0000, Loss_attr:0.4539 Loss_Latent_L2Dist:265.6353, Loss_Img_L2Dist:1643.6389 target_lab_score:-1.6940, max_nontarget_lab_score:7.6957 [INFO] Orig class:4, Adv class:6, Orig prob:[[-5.7961226 -6.4976497 -5.1008477 -6.8349266 4.7267528 -3.3094192 -3.0575497 -4.120827 ]], Adv prob:[[-6.0269275 -4.697468 -3.5946152 -4.4234085 -0.62827617 -5.6178136 7.752234 0.4052744 ]]
# Display the pertinent-negative image (shift from [-0.5, 0.5] back to [0, 1])
plt.axis("off")
plt.imshow(adv_pn[0,:,:,:]+0.5)
plt.show()
# Re-classify the pertinent negative and decode its (Male, Smile, Young) bits
adv_prob, adv_class, adv_prob_str = mymodel.predict_long(adv_pn)
sex_flag, smile_flag, young_flag = ((adv_class >> b) & 1 for b in (2, 1, 0))
print("Pertinent Negative pred:{}".format(adv_class))
print("Male:{}, Smile:{}, Young:{}".format(sex_flag, smile_flag, young_flag))
# attr_pn reports the attribute change(s) that produced the new class
print(attr_pn)
Pertinent Negative pred:6 Male:1, Smile:1, Young:0 Added High_Cheekbones
# Parameter values for the pertinent positive
# Note that the regularization parameters needed for the pertinent negative are not needed here
arg_mode = 'PP'  # 'PP' = pertinent positive
arg_kappa = 5
arg_gamma = 100.0
arg_beta = 0.1  # presumably weighs the L1 sparsity term (cf. Loss_L1Dist in the log)
arg_binary_search_steps = 1
arg_max_iterations = 100
arg_initial_const = 10
(adv_pp, __, __) = explainer.explain_instance(sess, input_img, None, arg_mode, arg_kappa,
arg_binary_search_steps, arg_max_iterations,
arg_initial_const, arg_gamma, arg_beta)
Creating a mask for pertinent positive Loaded model for Black_Hair from disk Loaded model for Blond_Hair from disk Loaded model for Brown_Hair from disk Loaded model for Gray_Hair from disk Loaded model for Wearing_Lipstick from disk Loaded model for Heavy_Makeup from disk Loaded model for High_Cheekbones from disk Loaded model for Bangs from disk Loaded model for Oval_Face from disk Loaded model for Narrow_Eyes from disk Loaded model for Bags_Under_Eyes from disk Loaded model for Pointy_Nose from disk # of attr models is 12 iter:0 const:[10.] Loss_Overall:4382.9561, Loss_Attack:0.0000, Loss_attr:0.0771 Loss_L2Dist:39100.1719, Loss_L1Dist:43718.0234, AE_loss:0.0 target_lab_score:4.3499, max_nontarget_lab_score:-7.6034 iter:10 const:[10.] Loss_Overall:436.1061, Loss_Attack:47.2062, Loss_attr:0.3453 Loss_L2Dist:2466.2612, Loss_L1Dist:3463.2842, AE_loss:0.0 target_lab_score:-2.1829, max_nontarget_lab_score:-2.4623 iter:20 const:[10.] Loss_Overall:698.1842, Loss_Attack:0.0000, Loss_attr:0.1899 Loss_L2Dist:5231.9268, Loss_L1Dist:6817.3809, AE_loss:0.0 target_lab_score:1.2739, max_nontarget_lab_score:-5.9582 iter:30 const:[10.] Loss_Overall:624.0765, Loss_Attack:0.0000, Loss_attr:0.3732 Loss_L2Dist:4226.5225, Loss_L1Dist:5768.5303, AE_loss:0.0 target_lab_score:2.6131, max_nontarget_lab_score:-9.5929 iter:40 const:[10.] Loss_Overall:715.3864, Loss_Attack:0.0000, Loss_attr:0.2153 Loss_L2Dist:5176.5283, Loss_L1Dist:6829.7847, AE_loss:0.0 target_lab_score:3.5829, max_nontarget_lab_score:-4.1920 iter:50 const:[10.] Loss_Overall:555.2651, Loss_Attack:0.0000, Loss_attr:0.5205 Loss_L2Dist:3711.5945, Loss_L1Dist:4872.4102, AE_loss:0.0 target_lab_score:2.4348, max_nontarget_lab_score:-2.5788 iter:60 const:[10.] Loss_Overall:369.6270, Loss_Attack:35.8543, Loss_attr:0.2620 Loss_L2Dist:2239.4399, Loss_L1Dist:3035.5767, AE_loss:0.0 target_lab_score:-0.5174, max_nontarget_lab_score:-1.9320 iter:70 const:[10.] 
Loss_Overall:330.4863, Loss_Attack:52.4312, Loss_attr:0.6744 Loss_L2Dist:1501.2345, Loss_L1Dist:2045.4265, AE_loss:0.0 target_lab_score:-0.3374, max_nontarget_lab_score:-0.0943 iter:80 const:[10.] Loss_Overall:444.5557, Loss_Attack:0.0000, Loss_attr:0.7794 Loss_L2Dist:2646.3767, Loss_L1Dist:3589.2808, AE_loss:0.0 target_lab_score:1.3040, max_nontarget_lab_score:-9.0034 iter:90 const:[10.] Loss_Overall:279.2845, Loss_Attack:68.8094, Loss_attr:0.4646 Loss_L2Dist:800.1670, Loss_L1Dist:1687.4364, AE_loss:0.0 target_lab_score:-4.7829, max_nontarget_lab_score:-2.9020 Generating the pertinent positive Start ranking: i:1, index:60, value:0.9562264680862427, class:5 i:2, index:85, value:0.9243737459182739, class:7 i:3, index:91, value:0.8686531782150269, class:5 i:4, index:54, value:0.8284841775894165, class:5 i:5, index:107, value:0.7462365031242371, class:5 i:6, index:86, value:0.6834268569946289, class:4
# Display the pertinent-positive image (shift back to [0, 1] for display)
plt.axis("off")
plt.imshow(adv_pp[0,:,:,:]+0.5)
plt.show()
# Classify the pertinent positive and peel off its attribute bits with divmod
adv_prob, adv_class, adv_prob_str = mymodel.predict_long(adv_pp)
rest, young_flag = divmod(adv_class, 2)
rest, smile_flag = divmod(rest, 2)
rest, sex_flag = divmod(rest, 2)
print("Pertinent positive pred:{}".format(adv_class))
print("Male:{}, Smile:{}, Young:{}".format(sex_flag, smile_flag, young_flag))
Pertinent positive pred:4 Male:1, Smile:0, Young:0