$\newcommand{\xv}{\mathbf{x}} \newcommand{\Xv}{\mathbf{X}} \newcommand{\yv}{\mathbf{y}} \newcommand{\zv}{\mathbf{z}} \newcommand{\av}{\mathbf{a}} \newcommand{\Wv}{\mathbf{W}} \newcommand{\wv}{\mathbf{w}} \newcommand{\tv}{\mathbf{t}} \newcommand{\Tv}{\mathbf{T}} \newcommand{\muv}{\boldsymbol{\mu}} \newcommand{\sigmav}{\boldsymbol{\sigma}} \newcommand{\phiv}{\boldsymbol{\phi}} \newcommand{\Phiv}{\boldsymbol{\Phi}} \newcommand{\Sigmav}{\boldsymbol{\Sigma}} \newcommand{\Lambdav}{\boldsymbol{\Lambda}} \newcommand{\half}{\frac{1}{2}} \newcommand{\argmax}[1]{\underset{#1}{\operatorname{argmax}}} \newcommand{\argmin}[1]{\underset{#1}{\operatorname{argmin}}}$

**Interpreting What a Neural Network Has Learned**

Arrieta et al., "Explainable Artificial Intelligence (XAI): Concepts, taxonomies, opportunities and challenges toward responsible AI," *Information Fusion*, vol. 58, June 2020, pp. 82-115.

"Given a certain audience, explainability refers to the details and reasons a model gives to make its functioning clear or easy to understand."

Here we will examine what the hidden units in a neural network have learned. This is most intuitive for classification problems involving images.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import pandas as pd
import os
In [2]:
from A6mysolution import *
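
train_for_classification and use are imported from the assignment solution and are not shown here. For readers without that file, here is a minimal sketch consistent with the network structure and outputs seen below; the choice of Adam and the exact handling of target labels are assumptions, not the actual A6mysolution code.

In [ ]:
def train_for_classification(X, T, hidden_layers=[10], n_epochs=500, learning_rate=0.01):
    # Build a fully-connected net: Linear -> Tanh per hidden layer,
    # then Linear -> LogSoftmax over the classes.
    classes = np.unique(T)
    layers = []
    ni = X.shape[1]
    for nh in hidden_layers:
        layers.extend([torch.nn.Linear(ni, nh), torch.nn.Tanh()])
        ni = nh
    layers.extend([torch.nn.Linear(ni, len(classes)), torch.nn.LogSoftmax(dim=1)])
    nnet = torch.nn.Sequential(*layers)

    Xt = torch.from_numpy(X).float()
    # Convert target labels to class indices 0, 1, ... for NLLLoss.
    Tt = torch.from_numpy(np.searchsorted(classes, T).reshape(-1))

    loss_f = torch.nn.NLLLoss()  # expects log probabilities, which LogSoftmax provides
    optimizer = torch.optim.Adam(nnet.parameters(), lr=learning_rate)

    learning_curve = []
    for epoch in range(n_epochs):
        optimizer.zero_grad()
        loss = loss_f(nnet(Xt), Tt)
        loss.backward()
        optimizer.step()
        learning_curve.append(loss.item())
    return nnet, learning_curve

def use(nnet, X):
    # Forward pass with gradients disabled; returns log class probabilities.
    with torch.no_grad():
        return nnet(torch.from_numpy(X).float()).numpy()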
In [3]:
# for regression problem
def rmse(a, b):
    return np.sqrt(np.mean((a - b)**2))

# for classification problem
def percent_correct(a, b):
    return 100 * np.mean(a == b)

# for classification problem
def confusion_matrix(Y_classes, T):
    class_names = np.unique(T)
    table = []
    for true_class in class_names:
        row = []
        for Y_class in class_names:
            row.append(100 * np.mean(Y_classes[T == true_class] == Y_class))
        table.append(row)
    conf_matrix = pd.DataFrame(table, index=class_names, columns=class_names)
    # A pandas Styler must be displayed to have any effect, so just print the
    # percent correct and return the plain DataFrame.
    print(f'Percent Correct is {percent_correct(Y_classes, T):.2f}')
    return conf_matrix
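
A quick sanity check of confusion_matrix on made-up predictions and targets (the arrays here are purely illustrative):

In [ ]:
Y_demo = np.array([0, 0, 1, 1, 1, 0]).reshape(-1, 1)
T_demo = np.array([0, 0, 1, 1, 0, 0]).reshape(-1, 1)
confusion_matrix(Y_demo, T_demo)  # rows are true classes, columns are predicted, entries in percent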
In [4]:
def makeImages(nEach):
    images = np.zeros((nEach * 2, 1, 20, 20))  # nSamples, nChannels, rows, columns
    radii = 3 + np.random.randint(5, size=(nEach * 2, 1))  # radii 3 through 7
    centers = np.zeros((nEach * 2, 2))
    T = np.zeros((nEach * 2, 1))  # first nEach samples are squares (0), rest diamonds (1)
    T[nEach:] = 1
    for i in range(nEach * 2):
        r = radii[i, 0]
        # Pick a center far enough from the edges for the shape to fit.
        centers[i, :] = r + 1 + np.random.randint(18 - 2 * r, size=(1, 2))
        x = int(centers[i, 0])
        y = int(centers[i, 1])
        if i < nEach:
            # squares: four edges, each including both corner pixels
            images[i, 0, x - r:x + r + 1, y + r] = 1.0
            images[i, 0, x - r:x + r + 1, y - r] = 1.0
            images[i, 0, x - r, y - r:y + r + 1] = 1.0
            images[i, 0, x + r, y - r:y + r + 1] = 1.0
        else:
            # diamonds: four diagonal edges
            images[i, 0, range(x - r, x), range(y, y + r)] = 1.0
            images[i, 0, range(x - r, x), range(y, y - r, -1)] = 1.0
            images[i, 0, range(x, x + r + 1), range(y + r, y - 1, -1)] = 1.0
            images[i, 0, range(x, x + r), range(y - r, y)] = 1.0
    # images += np.random.randn(*images.shape) * 0.5  # optionally corrupt with noise
    return images, T

nEach = 1000
X, T = makeImages(nEach)
X = X.reshape(X.shape[0], -1)
print(X.shape, T.shape)

Xtest, Ttest = makeImages(nEach)
Xtest = Xtest.reshape(Xtest.shape[0], -1)

plt.plot(T);
(2000, 400) (2000, 1)
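
The commented-out line in makeImages suggests corrupting the images with Gaussian noise to make the task harder. A minimal version of that idea, applied to the already-flattened X (the 0.2 noise level is arbitrary):

In [ ]:
Xnoisy = X + np.random.randn(*X.shape) * 0.2
plt.imshow(-Xnoisy[0, :].reshape(20, 20), cmap='gray');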
In [5]:
plt.imshow(-X[-1, :].reshape(20, 20), cmap='gray')
plt.xticks([])
plt.yticks([])
Out[5]:
([], [])
In [6]:
plt.figure(figsize=(10, 3))

for i in range(10):
    plt.subplot(2, 10, i + 1)
    plt.imshow(-X[i, :].reshape(20,20), cmap='gray')
    plt.xticks([])
    plt.yticks([])

    plt.subplot(2, 10, i + 11)
    plt.imshow(-X[-(i + 1), :].reshape(20,20), cmap='gray')
    plt.xticks([])
    plt.yticks([])
In [7]:
nnet, learning_curve = train_for_classification(X, T, hidden_layers=[10],
                                                n_epochs=500, learning_rate=0.01)
plt.plot(learning_curve);
In [8]:
nnet
Out[8]:
Sequential(
  (0): Linear(in_features=400, out_features=10, bias=True)
  (1): Tanh()
  (2): Linear(in_features=10, out_features=2, bias=True)
  (3): LogSoftmax()
)
In [9]:
Y = use(nnet, X)
Ytest = use(nnet, Xtest)
Y.shape
Out[9]:
(2000, 2)
In [10]:
plt.subplot(2, 1, 1)
plt.plot(Y)
plt.subplot(2, 1, 2)
plt.plot(Ytest)
Out[10]:
[<matplotlib.lines.Line2D at 0x7fa89861d370>,
 <matplotlib.lines.Line2D at 0x7fa89861d460>]
Since the network ends with LogSoftmax, use returns log probabilities; np.exp converts them to class probabilities that sum to one.

In [11]:
plt.plot(np.exp(Y))
Out[11]:
[<matplotlib.lines.Line2D at 0x7fa898582b50>,
 <matplotlib.lines.Line2D at 0x7fa898582c40>]
In [12]:
Y_classes = np.argmax(Y, axis=1).reshape(-1, 1)  # To keep 2-dimensional shape
plt.plot(Y_classes, 'o', label='Predicted')
plt.plot(T + 0.1, 'o', label='Target')
plt.legend();
In [13]:
Y_classes_test = np.argmax(Ytest, axis=1).reshape(-1, 1)  # To keep 2-dimensional shape
plt.plot(Y_classes_test, 'o', label='Predicted')
plt.plot(Ttest + 0.1, 'o', label='Target')
plt.legend();
In [14]:
Y.shape
Out[14]:
(2000, 2)
In [15]:
confusion_matrix(Y_classes_test, Ttest)
Percent Correct is 99.05
Out[15]:
      0.0   1.0
0.0  99.1   0.9
1.0   1.0  99.0
In [16]:
def forward_all_layers(nnet, X):
    # Run X through the network one layer at a time, saving every
    # intermediate result, starting with the input itself.
    X = torch.from_numpy(X).float()
    Ys = [X]
    for layer in nnet:
        Ys.append(layer(Ys[-1]))
    Ys = [Y.detach().numpy() for Y in Ys]
    return Ys
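
forward_all_layers works because nnet is a Sequential we can iterate over. An alternative that also works for models that are not plain Sequentials is to register forward hooks; a sketch (the hook mechanism is standard PyTorch, but this helper is not part of the assignment):

In [ ]:
def forward_all_layers_hooks(nnet, X):
    # Capture every submodule's output via forward hooks during one forward pass.
    outputs = []
    def save_output(module, inputs, output):
        outputs.append(output.detach().numpy())
    hooks = [child.register_forward_hook(save_output) for child in nnet.children()]
    X = torch.from_numpy(X).float()
    nnet(X)
    for h in hooks:
        h.remove()  # always remove hooks when done
    return [X.numpy()] + outputs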
In [17]:
Y_square = forward_all_layers(nnet, X[:10, :])     # first 10 samples are squares
Y_diamond = forward_all_layers(nnet, X[-10:, :])   # last 10 samples are diamonds
In [18]:
nnet
Out[18]:
Sequential(
  (0): Linear(in_features=400, out_features=10, bias=True)
  (1): Tanh()
  (2): Linear(in_features=10, out_features=2, bias=True)
  (3): LogSoftmax()
)
In [19]:
len(Y_square)
Out[19]:
5
In [20]:
Y_square[0].shape
Out[20]:
(10, 400)
In [21]:
Y_square[1].shape
Out[21]:
(10, 10)
In [22]:
plt.plot(Y_square[1]);
In [23]:
plt.plot(Y_square[2]);
In [24]:
both = np.vstack((Y_square[2], Y_diamond[2]))
In [25]:
plt.plot(both);
In [26]:
plt.figure(figsize=(15, 3))
for unit in range(10):
    plt.subplot(1, 10, unit + 1)
    plt.plot(both[:, unit])
plt.tight_layout()
In [27]:
plt.plot(both[:, 9])
Out[27]:
[<matplotlib.lines.Line2D at 0x7fa898a4fcd0>]
In [28]:
nnet
Out[28]:
Sequential(
  (0): Linear(in_features=400, out_features=10, bias=True)
  (1): Tanh()
  (2): Linear(in_features=10, out_features=2, bias=True)
  (3): LogSoftmax()
)
In [29]:
nnet[0].parameters()
Out[29]:
<generator object Module.parameters at 0x7fa898a1eac0>
In [30]:
list(nnet[0].parameters())
Out[30]:
[Parameter containing:
 tensor([[-0.0235,  0.0183, -0.0043,  ..., -0.0283, -0.0472,  0.0493],
         [-0.0305,  0.0305,  0.0375,  ...,  0.0076, -0.0254, -0.0216],
         [-0.0205, -0.0357, -0.0086,  ..., -0.0464,  0.0402,  0.0307],
         ...,
         [-0.0139,  0.0290,  0.0346,  ..., -0.0098,  0.0007,  0.0092],
         [ 0.0226,  0.0153,  0.0103,  ..., -0.0146,  0.0159,  0.0084],
         [ 0.0174,  0.0253, -0.0183,  ...,  0.0091,  0.0213,  0.0307]],
        requires_grad=True),
 Parameter containing:
 tensor([-0.8432, -0.7838, -0.5426, -0.7839,  0.6757, -0.6754, -0.7834,  0.7684,
          0.8641,  0.5593], requires_grad=True)]
In [31]:
W = list(nnet[0].parameters())[0]
W = W.detach().numpy()
W.shape
Out[31]:
(10, 400)
In [32]:
W = W.T
W.shape
Out[32]:
(400, 10)
In [ ]:
plt.plot(W);
In [ ]:
plt.plot(W[:, 0])
In [ ]:
plt.imshow(W[:, 0].reshape(20, 20), cmap='RdYlGn')
plt.colorbar()
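
One caveat when reading these weight images: imshow rescales each image to its own minimum and maximum, so the colormap's midpoint need not correspond to a weight of zero. A sketch using a symmetric color range instead (assuming W is the transposed (400, 10) weight matrix from above):

In [ ]:
scale = np.abs(W).max()  # symmetric limits so zero weights map to the midpoint color
plt.imshow(W[:, 0].reshape(20, 20), cmap='RdYlGn', vmin=-scale, vmax=scale)
plt.colorbar()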
In [ ]:
plt.figure(figsize=(15, 3))

for i in range(10):
    plt.subplot(2, 10, i + 1)
    plt.imshow(W[:, i].reshape(20,20), cmap='RdYlGn')
    plt.xticks([])
    plt.yticks([])
    plt.colorbar()
    
    plt.subplot(2, 10, i + 11)
    plt.plot(both[:, i])
In [ ]:
X.shape
In [ ]:
plt.imshow(X[4,:].reshape(20, 20), cmap='gray')

Let's automate these steps in a function, so we can try different numbers of hidden units and layers.

In [33]:
nnet
Out[33]:
Sequential(
  (0): Linear(in_features=400, out_features=10, bias=True)
  (1): Tanh()
  (2): Linear(in_features=10, out_features=2, bias=True)
  (3): LogSoftmax()
)
In [34]:
Wout = list(nnet[2].parameters())[0]
Wout = Wout.detach().numpy()
Wout = Wout.T
Wout.shape
Out[34]:
(10, 2)
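
Each row of Wout holds one hidden unit's weights to the two output units, so the sign of their difference suggests which class a positively-activated unit votes for (keep in mind Tanh activations can be negative, so this only summarizes the weight direction):

In [ ]:
votes = Wout[:, 0] - Wout[:, 1]  # > 0: pushes toward class 0 (squares) when the unit is positive
print(votes)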
In [35]:
def run_again(hiddens):
    
    nnet, learning_curve = train_for_classification(X, T, hidden_layers=hiddens,
                                                    n_epochs=1000, learning_rate=0.01)
    plt.figure()
    plt.plot(learning_curve)

    # Ys[2] is the Tanh layer's output, i.e. the hidden unit activations.
    Y_square = forward_all_layers(nnet, X[:10, :])
    Y_diamond = forward_all_layers(nnet, X[-10:, :])
    both = np.vstack((Y_square[2], Y_diamond[2]))
    
    W = list(nnet[0].parameters())[0]
    W = W.detach().numpy()
    W = W.T
    
    Wout = list(nnet[2].parameters())[0]
    Wout = Wout.detach().numpy()
    Wout = Wout.T

    plt.figure(figsize=(15, 3))
    
    n_units = hiddens[0]

    size = int(np.sqrt(X.shape[1]))
    
    for i in range(n_units):
        plt.subplot(2, n_units, i + 1)
        plt.imshow(W[:, i].reshape(size, size), cmap='RdYlGn')
        plt.colorbar()
        plt.xticks([])
        plt.yticks([])

        plt.subplot(2, n_units, i + 1 + n_units)
        plt.plot(both[:, i])
        plt.title(f'{Wout[i,0]:.1f},{Wout[i,1]:.1f}')
        
    Y = use(nnet, X)
    Y_classes = np.argmax(Y, axis=1).reshape(-1, 1)
    print(confusion_matrix(Y_classes, T))
    
    Ytest = use(nnet, Xtest)
    Y_classes_test = np.argmax(Ytest, axis=1).reshape(-1, 1)
    print(confusion_matrix(Y_classes_test, Ttest))
In [36]:
run_again([10])
Percent Correct is 100.00
       0.0    1.0
0.0  100.0    0.0
1.0    0.0  100.0
Percent Correct is 99.45
      0.0   1.0
0.0  99.8   0.2
1.0   0.9  99.1
In [37]:
if os.path.isfile('small_mnist.npz'):
    print('Reading data from \'small_mnist.npz\'.')
    small_mnist = np.load('small_mnist.npz')
else:
    import shlex
    import subprocess
    print('Downloading small_mnist.npz from CS545 site.')
    cmd = 'curl "https://www.cs.colostate.edu/~anderson/cs545/notebooks/small_mnist.npz" -o "small_mnist.npz"'
    subprocess.call(shlex.split(cmd))
    small_mnist = np.load('small_mnist.npz')


X = small_mnist['X']
T = small_mnist['T']

X.shape, T.shape
Reading data from 'small_mnist.npz'.
Out[37]:
((1000, 784), (1000, 1))
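
The download above shells out to curl. If curl is not installed, the standard library's urllib does the same job (same URL and filename as above):

In [ ]:
import urllib.request
urllib.request.urlretrieve(
    'https://www.cs.colostate.edu/~anderson/cs545/notebooks/small_mnist.npz',
    'small_mnist.npz')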
In [38]:
plt.imshow(-X[0, :].reshape(28, 28), cmap='gray')
Out[38]:
<matplotlib.image.AxesImage at 0x7fa87459ccd0>

Randomly partition the data into 60% for training and 40% for testing, using the following code cell.

In [39]:
n_samples = X.shape[0]
n_train = int(n_samples * 0.6)
rows = np.arange(n_samples)
np.random.shuffle(rows)

Xtrain = X[rows[:n_train], :]
Ttrain = T[rows[:n_train], :]
Xtest = X[rows[n_train:], :]
Ttest = T[rows[n_train:], :]
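
This shuffle ignores the class labels, so by chance some digits can end up over- or under-represented in the training set. A stratified variant that splits each class separately is sketched below; it is a drop-in replacement for the cell above, not part of the assignment:

In [ ]:
train_rows, test_rows = [], []
for c in np.unique(T):
    class_rows = np.where(T[:, 0] == c)[0]
    np.random.shuffle(class_rows)
    n_train_c = int(len(class_rows) * 0.6)  # same 60/40 split, applied per class
    train_rows.extend(class_rows[:n_train_c])
    test_rows.extend(class_rows[n_train_c:])
Xtrain, Ttrain = X[train_rows, :], T[train_rows, :]
Xtest, Ttest = X[test_rows, :], T[test_rows, :]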
In [59]:
def run_again_mnist(hiddens):
    
    nnet, learning_curve = train_for_classification(Xtrain, Ttrain, hidden_layers=hiddens,
                                                    n_epochs=1000, learning_rate=0.01)
    plt.figure()
    plt.plot(learning_curve)

    # Weights into the first hidden layer, one column per hidden unit.
    W = list(nnet[0].parameters())[0]
    W = W.detach().numpy()
    W = W.T

    plt.figure(figsize=(15, 15))
    
    n_units = hiddens[0]  # only the first hidden layer's weights are plotted

    size = int(np.sqrt(X.shape[1]))
    
    n_rows = int(np.sqrt(n_units) + 1)
    for i in range(n_units):
        plt.subplot(n_rows, n_rows, i + 1)
        plt.imshow(W[:, i].reshape(size, size), cmap='RdYlGn')
        plt.colorbar()
        plt.xticks([])
        plt.yticks([])
        
    Y = use(nnet, Xtrain)
    Y_classes = np.argmax(Y, axis=1).reshape(-1, 1)
    display(confusion_matrix(Y_classes, Ttrain))
    
    Ytest = use(nnet, Xtest)
    Y_classes_test = np.argmax(Ytest, axis=1).reshape(-1, 1)
    display(confusion_matrix(Y_classes_test, Ttest))
In [60]:
run_again_mnist([20, 20, 20])
Percent Correct is 100.00
       0      1      2      3      4      5      6      7      8      9
0  100.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0
1    0.0  100.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0
2    0.0    0.0  100.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0
3    0.0    0.0    0.0  100.0    0.0    0.0    0.0    0.0    0.0    0.0
4    0.0    0.0    0.0    0.0  100.0    0.0    0.0    0.0    0.0    0.0
5    0.0    0.0    0.0    0.0    0.0  100.0    0.0    0.0    0.0    0.0
6    0.0    0.0    0.0    0.0    0.0    0.0  100.0    0.0    0.0    0.0
7    0.0    0.0    0.0    0.0    0.0    0.0    0.0  100.0    0.0    0.0
8    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0  100.0    0.0
9    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0  100.0
Percent Correct is 84.25
           0          1          2          3          4          5          6          7          8          9
0  87.234043   0.000000   0.000000   2.127660   0.000000  10.638298   0.000000   0.000000   0.000000   0.000000
1   0.000000  93.617021   2.127660   0.000000   0.000000   2.127660   0.000000   0.000000   0.000000   2.127660
2   0.000000   2.777778  91.666667   0.000000   0.000000   0.000000   2.777778   0.000000   2.777778   0.000000
3   0.000000   0.000000   0.000000  79.069767   0.000000  11.627907   0.000000   4.651163   4.651163   0.000000
4   0.000000   0.000000   0.000000   0.000000  85.714286   0.000000   5.714286   2.857143   0.000000   5.714286
5   4.347826   0.000000   2.173913   4.347826   0.000000  76.086957   0.000000   0.000000  13.043478   0.000000
6   2.325581   0.000000   0.000000   0.000000   2.325581   0.000000  95.348837   0.000000   0.000000   0.000000
7   2.777778   2.777778   2.777778   0.000000   0.000000   0.000000   0.000000  86.111111   0.000000   5.555556
8   0.000000   3.030303   3.030303   9.090909   0.000000   3.030303   3.030303   9.090909  69.696970   0.000000
9   0.000000   0.000000   2.941176   0.000000   8.823529   2.941176   0.000000   8.823529   2.941176  73.529412

The network fits the training digits perfectly but reaches only 84.25% on the held-out digits, a clear sign of overfitting on this small training set. The 8s are classified correctly least often, and 5s are most often mistaken for 8s.