%reload_ext autoreload
%autoreload 2
%matplotlib inline
from fastai.vision import *
bs = 64
path = untar_data(URLs.PETS) / 'images'
tfms = get_transforms(max_rotate=20, max_zoom=1.3, max_lighting=0.4, max_warp=0.4,
p_affine=1., p_lighting=1.)
doc(get_transforms)
src = ImageItemList.from_folder(path).random_split_by_pct(0.2, seed=2)
def get_data(size, bs, padding_mode='reflection'):
    """Build a normalized DataBunch from the module-level `src` split.

    Labels are extracted from each filename with a regex. Fix: the dot
    before 'jpg' is now escaped, so only a literal '.jpg' suffix matches
    (the original pattern's bare '.' matched any character).

    Parameters
    ----------
    size : int
        Side length images are resized to by the transforms.
    bs : int
        Batch size for the DataBunch.
    padding_mode : str
        Padding used when transforming ('reflection' default, or 'zeros').

    Returns
    -------
    An ImageNet-normalized DataBunch built from `src` with `tfms` applied.
    """
    return (src.label_from_re(r'([^/]+)_\d+\.jpg$')
            .transform(tfms, size=size, padding_mode=padding_mode)
            .databunch(bs=bs)
            .normalize(imagenet_stats))
data = get_data(224, bs, 'zeros')
def _plot(i, j, ax):
    """Show one freshly-augmented copy of training sample 3 on `ax`.

    The grid position arguments `i` and `j` are ignored on purpose:
    every cell re-fetches the same item, so each call draws a new
    random augmentation of the same image.
    """
    img, label = data.train_ds[3]
    img.show(ax, y=label)
plot_multi(_plot, 3, 3, figsize=(8,8))
data = get_data(224, bs)
plot_multi(_plot, 3, 3, figsize=(8,8))
gc.collect()
14593
learn = create_cnn(data, models.resnet34, metrics=error_rate, bn_final=True)
learn.lr_find()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.recorder.plot()
learn.fit_one_cycle(3, slice(1e-2), pct_start=0.8)
epoch | train_loss | valid_loss | error_rate |
---|---|---|---|
1 | 2.431377 | 1.230201 | 0.290934 |
2 | 1.439160 | 0.372741 | 0.102842 |
3 | 0.878912 | 0.300877 | 0.087957 |
learn.lr_find()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.recorder.plot(skip_start=0)
learn.unfreeze()
learn.fit_one_cycle(2, max_lr=slice(1e-6,1e-3), pct_start=0.8)
epoch | train_loss | valid_loss | error_rate |
---|---|---|---|
1 | 0.694553 | 0.305744 | 0.079161 |
2 | 0.650280 | 0.301226 | 0.070365 |
data = get_data(352, bs)
learn.data = data
learn.lr_find()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.recorder.plot(skip_start=0)
learn.fit_one_cycle(2, max_lr=slice(1e-6,1e-4))
epoch | train_loss | valid_loss | error_rate |
---|---|---|---|
1 | 0.599940 | 0.277238 | 0.057510 |
2 | 0.561108 | 0.279816 | 0.056157 |
learn.save('352')
learn.save('352')
data = get_data(352,16)
learn = create_cnn(data, models.resnet34, metrics=error_rate, bn_final=True).load('352')
idx = 33
x, y = data.valid_ds[idx]
x.show()
data.valid_ds.y[idx]
Category Ragdoll
k = tensor([
[0. , -5/3, 1],
[-5/3, -5/3, 1],
[1. , 1 , 1],
]).expand(1, 3, 3, 3) / 6
k
tensor([[[[ 0.0000, -0.2778, 0.1667], [-0.2778, -0.2778, 0.1667], [ 0.1667, 0.1667, 0.1667]], [[ 0.0000, -0.2778, 0.1667], [-0.2778, -0.2778, 0.1667], [ 0.1667, 0.1667, 0.1667]], [[ 0.0000, -0.2778, 0.1667], [-0.2778, -0.2778, 0.1667], [ 0.1667, 0.1667, 0.1667]]]])
k.shape
torch.Size([1, 3, 3, 3])
from fastai.callbacks.hooks import *
t = data.valid_ds[idx][0].data
t.shape
torch.Size([3, 352, 352])
t[None].shape
torch.Size([1, 3, 352, 352])
edge = F.conv2d(t[None], k)
show_image(edge[0], figsize=(5,5))
<matplotlib.axes._subplots.AxesSubplot at 0x7f52a854b278>
data.c
37
learn.model
Sequential( (0): Sequential( (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace) (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) (4): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (5): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): 
BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): BasicBlock( (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (3): BasicBlock( (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (6): Sequential( (0): BasicBlock( (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, 
eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (3): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (4): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (5): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (7): Sequential( (0): BasicBlock( (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, 
eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): BasicBlock( (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) ) (1): Sequential( (0): AdaptiveConcatPool2d( (ap): AdaptiveAvgPool2d(output_size=1) (mp): AdaptiveMaxPool2d(output_size=1) ) (1): Lambda() (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (3): Dropout(p=0.25) (4): Linear(in_features=1024, out_features=512, bias=True) (5): ReLU(inplace) (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (7): Dropout(p=0.5) (8): Linear(in_features=512, out_features=37, bias=True) (9): BatchNorm1d(37, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True) ) )
learn.summary()
Input Size override by Learner.data.train_dl Input Size passed in: 16 ================================================================================ Layer (type) Output Shape Param # ================================================================================ Conv2d [16, 64, 176, 176] 9408 ________________________________________________________________________________ BatchNorm2d [16, 64, 176, 176] 128 ________________________________________________________________________________ ReLU [16, 64, 176, 176] 0 ________________________________________________________________________________ MaxPool2d [16, 64, 88, 88] 0 ________________________________________________________________________________ Conv2d [16, 64, 88, 88] 36864 ________________________________________________________________________________ BatchNorm2d [16, 64, 88, 88] 128 ________________________________________________________________________________ ReLU [16, 64, 88, 88] 0 ________________________________________________________________________________ Conv2d [16, 64, 88, 88] 36864 ________________________________________________________________________________ BatchNorm2d [16, 64, 88, 88] 128 ________________________________________________________________________________ Conv2d [16, 64, 88, 88] 36864 ________________________________________________________________________________ BatchNorm2d [16, 64, 88, 88] 128 ________________________________________________________________________________ ReLU [16, 64, 88, 88] 0 ________________________________________________________________________________ Conv2d [16, 64, 88, 88] 36864 ________________________________________________________________________________ BatchNorm2d [16, 64, 88, 88] 128 ________________________________________________________________________________ Conv2d [16, 64, 88, 88] 36864 ________________________________________________________________________________ BatchNorm2d [16, 64, 88, 88] 128 
________________________________________________________________________________ ReLU [16, 64, 88, 88] 0 ________________________________________________________________________________ Conv2d [16, 64, 88, 88] 36864 ________________________________________________________________________________ BatchNorm2d [16, 64, 88, 88] 128 ________________________________________________________________________________ Conv2d [16, 128, 44, 44] 73728 ________________________________________________________________________________ BatchNorm2d [16, 128, 44, 44] 256 ________________________________________________________________________________ ReLU [16, 128, 44, 44] 0 ________________________________________________________________________________ Conv2d [16, 128, 44, 44] 147456 ________________________________________________________________________________ BatchNorm2d [16, 128, 44, 44] 256 ________________________________________________________________________________ Conv2d [16, 128, 44, 44] 8192 ________________________________________________________________________________ BatchNorm2d [16, 128, 44, 44] 256 ________________________________________________________________________________ Conv2d [16, 128, 44, 44] 147456 ________________________________________________________________________________ BatchNorm2d [16, 128, 44, 44] 256 ________________________________________________________________________________ ReLU [16, 128, 44, 44] 0 ________________________________________________________________________________ Conv2d [16, 128, 44, 44] 147456 ________________________________________________________________________________ BatchNorm2d [16, 128, 44, 44] 256 ________________________________________________________________________________ Conv2d [16, 128, 44, 44] 147456 ________________________________________________________________________________ BatchNorm2d [16, 128, 44, 44] 256 ________________________________________________________________________________ ReLU [16, 
128, 44, 44] 0 ________________________________________________________________________________ Conv2d [16, 128, 44, 44] 147456 ________________________________________________________________________________ BatchNorm2d [16, 128, 44, 44] 256 ________________________________________________________________________________ Conv2d [16, 128, 44, 44] 147456 ________________________________________________________________________________ BatchNorm2d [16, 128, 44, 44] 256 ________________________________________________________________________________ ReLU [16, 128, 44, 44] 0 ________________________________________________________________________________ Conv2d [16, 128, 44, 44] 147456 ________________________________________________________________________________ BatchNorm2d [16, 128, 44, 44] 256 ________________________________________________________________________________ Conv2d [16, 256, 22, 22] 294912 ________________________________________________________________________________ BatchNorm2d [16, 256, 22, 22] 512 ________________________________________________________________________________ ReLU [16, 256, 22, 22] 0 ________________________________________________________________________________ Conv2d [16, 256, 22, 22] 589824 ________________________________________________________________________________ BatchNorm2d [16, 256, 22, 22] 512 ________________________________________________________________________________ Conv2d [16, 256, 22, 22] 32768 ________________________________________________________________________________ BatchNorm2d [16, 256, 22, 22] 512 ________________________________________________________________________________ Conv2d [16, 256, 22, 22] 589824 ________________________________________________________________________________ BatchNorm2d [16, 256, 22, 22] 512 ________________________________________________________________________________ ReLU [16, 256, 22, 22] 0 
________________________________________________________________________________ Conv2d [16, 256, 22, 22] 589824 ________________________________________________________________________________ BatchNorm2d [16, 256, 22, 22] 512 ________________________________________________________________________________ Conv2d [16, 256, 22, 22] 589824 ________________________________________________________________________________ BatchNorm2d [16, 256, 22, 22] 512 ________________________________________________________________________________ ReLU [16, 256, 22, 22] 0 ________________________________________________________________________________ Conv2d [16, 256, 22, 22] 589824 ________________________________________________________________________________ BatchNorm2d [16, 256, 22, 22] 512 ________________________________________________________________________________ Conv2d [16, 256, 22, 22] 589824 ________________________________________________________________________________ BatchNorm2d [16, 256, 22, 22] 512 ________________________________________________________________________________ ReLU [16, 256, 22, 22] 0 ________________________________________________________________________________ Conv2d [16, 256, 22, 22] 589824 ________________________________________________________________________________ BatchNorm2d [16, 256, 22, 22] 512 ________________________________________________________________________________ Conv2d [16, 256, 22, 22] 589824 ________________________________________________________________________________ BatchNorm2d [16, 256, 22, 22] 512 ________________________________________________________________________________ ReLU [16, 256, 22, 22] 0 ________________________________________________________________________________ Conv2d [16, 256, 22, 22] 589824 ________________________________________________________________________________ BatchNorm2d [16, 256, 22, 22] 512 ________________________________________________________________________________ 
Conv2d [16, 256, 22, 22] 589824 ________________________________________________________________________________ BatchNorm2d [16, 256, 22, 22] 512 ________________________________________________________________________________ ReLU [16, 256, 22, 22] 0 ________________________________________________________________________________ Conv2d [16, 256, 22, 22] 589824 ________________________________________________________________________________ BatchNorm2d [16, 256, 22, 22] 512 ________________________________________________________________________________ Conv2d [16, 512, 11, 11] 1179648 ________________________________________________________________________________ BatchNorm2d [16, 512, 11, 11] 1024 ________________________________________________________________________________ ReLU [16, 512, 11, 11] 0 ________________________________________________________________________________ Conv2d [16, 512, 11, 11] 2359296 ________________________________________________________________________________ BatchNorm2d [16, 512, 11, 11] 1024 ________________________________________________________________________________ Conv2d [16, 512, 11, 11] 131072 ________________________________________________________________________________ BatchNorm2d [16, 512, 11, 11] 1024 ________________________________________________________________________________ Conv2d [16, 512, 11, 11] 2359296 ________________________________________________________________________________ BatchNorm2d [16, 512, 11, 11] 1024 ________________________________________________________________________________ ReLU [16, 512, 11, 11] 0 ________________________________________________________________________________ Conv2d [16, 512, 11, 11] 2359296 ________________________________________________________________________________ BatchNorm2d [16, 512, 11, 11] 1024 ________________________________________________________________________________ Conv2d [16, 512, 11, 11] 2359296 
________________________________________________________________________________ BatchNorm2d [16, 512, 11, 11] 1024 ________________________________________________________________________________ ReLU [16, 512, 11, 11] 0 ________________________________________________________________________________ Conv2d [16, 512, 11, 11] 2359296 ________________________________________________________________________________ BatchNorm2d [16, 512, 11, 11] 1024 ________________________________________________________________________________ AdaptiveAvgPool2d [16, 512, 1, 1] 0 ________________________________________________________________________________ AdaptiveMaxPool2d [16, 512, 1, 1] 0 ________________________________________________________________________________ Lambda [16, 1024] 0 ________________________________________________________________________________ BatchNorm1d [16, 1024] 2048 ________________________________________________________________________________ Dropout [16, 1024] 0 ________________________________________________________________________________ Linear [16, 512] 524800 ________________________________________________________________________________ ReLU [16, 512] 0 ________________________________________________________________________________ BatchNorm1d [16, 512] 1024 ________________________________________________________________________________ Dropout [16, 512] 0 ________________________________________________________________________________ Linear [16, 37] 18981 ________________________________________________________________________________ BatchNorm1d [16, 37] 74 ________________________________________________________________________________ Total params: 21831599
m = learn.model.eval()
xb, _ = data.one_item(x)
xb_im = Image(data.denorm(xb)[0])
xb = xb.cuda()
from fastai.callbacks.hooks import *
def hooked_backward(cat=y):
    """Forward + backward pass through `m` on `xb`, capturing hooks.

    Registers two output hooks on the convolutional body `m[0]`: one for
    its activations and one for the gradients flowing back through it.
    Backpropagates from the score of class `cat` (default: the `y` bound
    at definition time).

    Returns the (activation hook, gradient hook) pair; their captured
    tensors live in each hook's `.stored` attribute.
    """
    with hook_output(m[0]) as hook_a, hook_output(m[0], grad=True) as hook_g:
        preds = m(xb)
        preds[0, int(cat)].backward()
    return hook_a, hook_g
hook_a, hook_g = hooked_backward() # hook activations, hook gradients
acts = hook_a.stored[0].cpu() # class attribute self.stored. move tensor from GPU back to CPU
acts.shape
torch.Size([512, 11, 11])
avg_acts = acts.mean(0)
avg_acts.shape
torch.Size([11, 11])
def show_heatmap(hm, sz=352, alpha=0.6):
    """Overlay an activation heatmap on the current image `xb_im`.

    Parameters
    ----------
    hm : 2-D tensor or array
        Coarse activation map; imshow's bilinear interpolation upsamples
        it to the image extent.
    sz : int
        Pixel size of the displayed image; the heatmap is stretched to
        cover (0, sz) x (sz, 0). Default 352 matches the data size used
        in this notebook (previously hard-coded).
    alpha : float
        Opacity of the heatmap overlay (previously hard-coded 0.6).
    """
    _, ax = plt.subplots()
    xb_im.show(ax)
    ax.imshow(hm, alpha=alpha, extent=(0, sz, sz, 0),
              interpolation='bilinear', cmap='magma')
show_heatmap(avg_acts)
# Sanity check
type(hook_g.stored), len(hook_g.stored), type(hook_g.stored[0]), hook_g.stored[0].shape, hook_g.stored[0][0].shape
(list, 1, torch.Tensor, torch.Size([1, 512, 11, 11]), torch.Size([512, 11, 11]))
# Sanity check
grad.mean(1).shape
torch.Size([512, 11])
grad = hook_g.stored[0][0].cpu()
grad_chan = grad.mean(1).mean(1) # mean over axis 1: [512, 11, 11] -> [512, 11] -> [512]
grad.shape, grad_chan.shape
(torch.Size([512, 11, 11]), torch.Size([512]))
# Sanity check
print(grad_chan[...,None,None].shape) # None is effectively adding a new dim to tensor
print(acts.shape)
torch.Size([512, 1, 1]) torch.Size([512, 11, 11])
# Sanity check
(acts*grad_chan[...,None,None]).shape
torch.Size([512, 11, 11])
mult = (acts*grad_chan[...,None,None]).mean(0)
# Sanity check
mult.shape
torch.Size([11, 11])
show_heatmap(mult)
fn = path / '../../other/ragdoll_and_dog.jpg'
x = open_image(fn); x
xb, _ = data.one_item(x)
xb_im = Image(data.denorm(xb)[0])
xb = xb.cuda()
hook_a, hook_g = hooked_backward()
acts = hook_a.stored[0].cpu()
grad = hook_g.stored[0][0].cpu()
grad_chan = grad.mean(1).mean(1)
mult = (acts*grad_chan[...,None,None]).mean(0)
show_heatmap(mult)
data.classes[0]
'Abyssinian'
hook_a, hook_g = hooked_backward(0)
acts = hook_a.stored[0].cpu()
grad = hook_g.stored[0][0].cpu()
grad_chan = grad.mean(1).mean(1)
mult = (acts*grad_chan[...,None,None]).mean(0)
show_heatmap(mult)