%%HTML
%tensorflow_version 1.x
import numpy as np
import keras
from keras.applications.imagenet_utils import decode_predictions
import skimage.io
import skimage.segmentation
import copy
import sklearn
import sklearn.metrics
from sklearn.linear_model import LinearRegression
import warnings
# Report which Keras version this run is using.
print('Notebook running: keras ', keras.__version__)
# Seed NumPy's global RNG so the random draws made later in the script are
# repeatable from run to run.
np.random.seed(222)
# Suppress every warning for the rest of the session (no category filter given).
warnings.filterwarnings('ignore')
# Build InceptionV3 with the constructor's default arguments.
inceptionV3_model = keras.applications.inception_v3.InceptionV3() #Load pretrained model
# `skimage.transform` is used below but the file only imports `skimage.io`
# and `skimage.segmentation`; import the submodule explicitly instead of
# relying on it being bound as a side effect of those other imports.
import skimage.transform

# Download the example image (requires network access).
Xi = skimage.io.imread("https://arteagac.github.io/blog/lime_image/img/cat-and-dog.jpg")
# Resize to 299x299, the spatial size fed to the model below.
Xi = skimage.transform.resize(Xi, (299,299))
# Shift/scale the resized image; `Xi/2+0.5` is the exact inverse and is used
# wherever the image has to be displayed.
Xi = (Xi - 0.5)*2 #Inception pre-processing
skimage.io.imshow(Xi/2+0.5) # Show image before inception preprocessing
# Re-seed the global NumPy RNG (same seed as at start-up) before the steps
# below, so the binomial draw further down starts from a known state.
np.random.seed(222)
# The model is given a batch of one: a leading axis is added to the image.
preds = inceptionV3_model.predict(Xi[np.newaxis,:,:,:])
# Bare expression: shown as cell output in a notebook, otherwise discarded.
decode_predictions(preds)[0] #Top 5 classes
# argsort is ascending, so take the last five indices and reverse them to get
# the class indices ordered from highest to lowest score.
top_pred_classes = preds[0].argsort()[-5:][::-1]
top_pred_classes #Index of top 5 classes
# Segment the pre-processed image with quickshift; the result is used below as
# a per-pixel label map (compared against label ids in perturb_image).
superpixels = skimage.segmentation.quickshift(Xi, kernel_size=4,max_dist=200, ratio=0.2)
# Count of distinct labels present in the segmentation.
num_superpixels = np.unique(superpixels).shape[0]
num_superpixels
# Draw the segment boundaries over the display-range version of the image.
skimage.io.imshow(skimage.segmentation.mark_boundaries(Xi/2+0.5, superpixels))
# Number of random perturbations to generate.
num_perturb = 150
# One row per perturbation, one 0/1 entry per superpixel, each entry drawn
# independently with probability 0.5 of being 1.
perturbations = np.random.binomial(1, 0.5, size=(num_perturb, num_superpixels))
perturbations[0] #Show example of perturbation
def perturb_image(img, perturbation, segments):
    """Return a copy of *img* in which switched-off superpixels are zeroed.

    Args:
        img: Image array with two leading spatial axes matching ``segments``
            and one trailing channel axis (the mask is broadcast over it).
        perturbation: 1-D on/off vector; position ``i`` equal to 1 keeps the
            pixels whose label in ``segments`` is ``i``.
        segments: 2-D label map assigning each pixel to a superpixel id.

    Returns:
        A new array equal to ``img`` multiplied by a float mask that is 1.0
        on kept superpixels and 0.0 elsewhere. ``img`` itself is not
        modified.
    """
    active_pixels = np.where(np.asarray(perturbation) == 1)[0]
    # One vectorised membership test replaces the per-label Python loop.
    # The mask is cast to float so the product has the same dtype as it
    # would when multiplying by a mask built with np.zeros.
    mask = np.isin(segments, active_pixels).astype(float)
    # The multiplication already allocates a fresh array, so no explicit
    # copy of the input is needed.
    return img * mask[:, :, np.newaxis]
# Preview the first perturbation applied to the display-range image.
skimage.io.imshow(perturb_image(Xi/2+0.5,perturbations[0],superpixels))
# Run the model once per perturbation, each time as a batch of one, and
# collect the raw outputs.
predictions = []
for pert in perturbations:
  perturbed_img = perturb_image(Xi,pert,superpixels)
  pred = inceptionV3_model.predict(perturbed_img[np.newaxis,:,:,:])
  predictions.append(pred)
# Stack into a single array; it is indexed with three axes further down
# (perturbation, batch-of-one, class).
predictions = np.array(predictions)
predictions.shape
original_image = np.ones(num_superpixels)[np.newaxis,:] #Perturbation with all superpixels enabled
# Cosine distance from every perturbation vector to the all-ones vector,
# flattened to one value per perturbation.
distances = sklearn.metrics.pairwise_distances(perturbations,original_image, metric='cosine').ravel()
distances.shape
# Width of the kernel that converts distances into sample weights.
kernel_width = 0.25
# Smaller distance -> larger weight.
weights = np.sqrt(np.exp(-(distances**2)/kernel_width**2)) #Kernel function
weights.shape
# Explain the highest-scoring class found for the unperturbed image.
class_to_explain = top_pred_classes[0]
# Weighted linear model: features are the on/off superpixel indicators, the
# target is the model's output for the chosen class.
simpler_model = LinearRegression()
# The target slice keeps the middle axis, so y is passed as a 2-D array.
simpler_model.fit(X=perturbations, y=predictions[:,:,class_to_explain], sample_weight=weights)
# Take the first (only) row of the coefficient array: one value per superpixel.
coeff = simpler_model.coef_[0]
coeff
# How many superpixels to keep in the final explanation.
num_top_features = 4
# argsort is ascending, so the last entries are the largest coefficients.
top_features = np.argsort(coeff)[-num_top_features:]
top_features
# Build an on/off vector with only the selected superpixels switched on
# (True is stored as 1.0 in this float array, which perturb_image tests with == 1).
mask = np.zeros(num_superpixels)
mask[top_features]= True #Activate top superpixels
# Show only the selected superpixels of the display-range image.
skimage.io.imshow(perturb_image(Xi/2+0.5,mask,superpixels) )