Found this interesting new idea from https://github.com/sdoria/SimpleSelfAttention. I have been trying to incorporate attention into another model to get better results, so I decided to give this a try and see if it works.
from fastai.vision import *
from fastai.callbacks.tracker import *
%matplotlib inline
As always, we get our images from Google. We are experimenting with a better way of downloading images: we feed in the URL of the Google Images search-result page. This seems to give better photos, in the sense that they are more relevant to our query.
from google_images_download import google_images_download
# Client for scraping Google Images search-result pages.
response = google_images_download.googleimagesdownload()

suits = ['spades', 'hearts', 'diamonds', 'clubs']
numbers = ['ace', 'two', 'three', 'four', 'five', 'six', 'seven',
           'eight', 'nine', 'ten', 'jack', 'queen', 'king']

# Base arguments shared by every download call; 'url', 'image_directory' and
# 'prefix' get overwritten per query later on.
args = {
    'exact_size': "350,350",
    'limit': 10,
    'output_directory': 'images',
    'silent_mode': True,
    'image_directory': '',
    'prefix': 'ace_spades',
    'type': 'photo',
}
We use multiprocessing so things are faster.
from multiprocessing import Pool
def f(s):
    """Download images for every card number of suit `s`.

    Runs in a worker process via `Pool.map`. Builds a per-call copy of the
    shared `args` dict instead of mutating the module-level one in place
    (the original also bound the download result to an unused local `p`).
    """
    for n in numbers:
        url = f'https://www.google.com/search?q={n}+of+{s}+poker+card&client=safari&source=lnms&tbm=isch&sa=X&ved=0ahUKEwiGwtjpy7viAhUC148KHdc5DuAQ_AUIDigB&biw=1440&bih=718'
        response.download({**args,
                           'url': url,
                           'image_directory': s,
                           'prefix': f'{n}_{s}'})
    print(f'Completed download for {s}.')
if __name__ == '__main__':
    # Context manager so the pool is terminated and its workers reaped;
    # the original created the pool and never closed/joined it.
    with Pool(4) as pool:
        print(pool.map(f, suits))
Check for corrupted images and delete any that we find.
##A function to recursively get all jpg files in a directory.
def setify(o):
    """Coerce `o` to a `set`.

    Self-contained replacement for fastai's `set(listify(o))` idiom, so this
    helper no longer relies on `listify` from the star import: sets pass
    through unchanged, `None` becomes the empty set, strings stay whole
    (not split into characters), other iterables are expanded, and any
    remaining scalar becomes a one-element set.
    """
    if isinstance(o, set):
        return o
    if o is None:
        return set()
    if isinstance(o, str):
        return {o}
    try:
        return set(o)
    except TypeError:
        return {o}
def _get_files(p, fs, extensions=None):
p = Path(p)
res = [p/f for f in fs if not f.startswith('.')
and ((not extensions) or f'.{f.split(".")[-1].lower()}' in extensions)]
return res
#export
def get_files(path, extensions=None, recurse=False, include=None):
    "Collect files under `path`, optionally recursing and filtering by `extensions`."
    path = Path(path)
    extensions = {e.lower() for e in setify(extensions)}
    if not recurse:
        names = [entry.name for entry in os.scandir(path) if entry.is_file()]
        return _get_files(path, names, extensions)
    out = []
    # os.walk yields (dirpath, dirnames, filenames); pruning `dirnames` in
    # place controls which subdirectories get visited.
    for i, (dirpath, dirnames, filenames) in enumerate(os.walk(path)):
        if include is not None and i == 0:
            # at the top level keep only the requested directories
            dirnames[:] = [d for d in dirnames if d in include]
        else:
            # below the top level just skip hidden directories
            dirnames[:] = [d for d in dirnames if not d.startswith('.')]
        out += _get_files(dirpath, filenames, extensions)
    return out
path = Path('./test/')
res = get_files(path, extensions='.jpg', recurse=True)

from PIL import Image

# Delete files Pillow cannot verify. Catch only the errors PIL raises for
# broken images (OSError, and SyntaxError for some malformed formats) — the
# original bare `except:` would also swallow KeyboardInterrupt and delete
# files on completely unrelated failures.
count = 0
for r in res:
    try:
        with Image.open(r) as im:   # context manager closes the file handle
            im.verify()             # raises if the file is truncated/corrupt
    except (OSError, SyntaxError):
        os.remove(r)
        count += 1
print(f'Removed {count} files.')
Removed 0 files.
path = Path('./')
# Label images by their parent folder (the suit) under ./images/, holding out
# a random 20% validation split with a fixed seed for reproducibility.
src = (ImageList.from_folder(path,extensions=['.jpg','.png'],recurse=True,include='images')
       .split_by_rand_pct(seed=42)
       .label_from_folder()
       )
# Apply the default fastai augmentations, resize to 224x224, batch, and
# normalize with ImageNet statistics (we train from scratch, but the stats
# are a reasonable default).
data = (src
        .transform(get_transforms(),size=(224,224))
        .databunch(bs=15)
        .normalize(imagenet_stats))
data.save()              # serialize the processed DataBunch to disk
data = load_data(path)   # reload it (round-trips the saved pipeline)
data.batch_size = 20     # bump the batch size after loading
print(data.classes)
print(data.c)            # number of classes
['clubs', 'diamonds', 'hearts', 'spades'] 4
data.show_batch()  # eyeball one batch to sanity-check labels and transforms
Some images are no good, but it's okay we can move on regardless.
# Activation shared by every conv layer; in-place ReLU to save memory.
act_fn = nn.ReLU(inplace=True)


def noop(x):
    "Identity function, used where an optional layer is disabled."
    return x
class Flatten(nn.Module):
    "Flatten each sample in a batch to a single vector, keeping the batch dim."

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)
def conv(ni, nf, ks=3, stride=1, bias=False):
    "2-D convolution with 'same' padding for odd kernel sizes."
    pad = ks // 2
    return nn.Conv2d(ni, nf, kernel_size=ks, stride=stride,
                     padding=pad, bias=bias)
def init_cnn(m):
    "Recursively Kaiming-init conv/linear weights and zero every bias found."
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias, 0)
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
    for child in m.children():
        init_cnn(child)
def conv_layer(ni, nf, ks=3, stride=1, zero_bn=False, act=True):
    "conv -> batchnorm (-> shared ReLU); `zero_bn` zeroes the BN scale."
    norm = nn.BatchNorm2d(nf)
    # zero-initializing the BN weight lets a residual branch start as identity
    nn.init.constant_(norm.weight, 0. if zero_bn else 1.)
    modules = [conv(ni, nf, ks, stride=stride), norm]
    if act:
        modules.append(act_fn)
    return nn.Sequential(*modules)
#Unmodified from https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py
def conv1d(ni:int, no:int, ks:int=1, stride:int=1, padding:int=0, bias:bool=False):
    "Create and initialize a `nn.Conv1d` layer with spectral normalization."
    layer = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
    nn.init.kaiming_normal_(layer.weight)
    if bias:
        layer.bias.data.zero_()
    return spectral_norm(layer)
# Adapted from SelfAttention layer at https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py
# Inspired by https://arxiv.org/pdf/1805.08318.pdf
class SimpleSelfAttention(nn.Module):
    """Simplified self-attention layer (adapted from sdoria/SimpleSelfAttention).

    Unlike the fastai `SelfAttention` layer, `forward` also returns the
    attention matrix so it can be inspected later via an output hook.
    """
    def __init__(self, n_in:int, ks=1):#, n_out:int):
        super().__init__()
        # spectrally-normalized 1-D conv (1x1 by default) over flattened positions
        self.conv = conv1d(n_in, n_in, ks, padding=ks//2, bias=False)
        # learnable residual gate, initialized to 0 so the layer starts as identity
        self.gamma = nn.Parameter(tensor([0.]))
    def forward(self,x):
        size = x.size()
        x = x.view(*size[:2],-1)  # flatten spatial dims: (B, C, H*W)
        # (B, N, C) x (B, C, N) -> (B, N, N) attention between positions
        w = torch.bmm(x.permute(0,2,1).contiguous(),self.conv(x))
        # gated residual: gamma == 0 at init, so output == input initially
        o = self.gamma * torch.bmm(x,w) + x
        # NOTE: returns a tuple (output, attention); ResBlock indexes [0] for
        # the activations and an output hook can read [1] for interpretation.
        return o.view(*size).contiguous(), w
#unmodified from https://github.com/fastai/fastai/blob/9b9014b8967186dc70c65ca7dcddca1a1232d99d/fastai/vision/models/xresnet.py
def conv(ni, nf, ks=3, stride=1, bias=False):
    "Same-padded 2-D convolution (redefinition of the helper above)."
    return nn.Conv2d(ni, nf,
                     kernel_size=ks,
                     stride=stride,
                     padding=ks // 2,
                     bias=bias)
def noop(x):
    "Return `x` unchanged (identity placeholder layer)."
    return x
def conv_layer(ni, nf, ks=3, stride=1, zero_bn=False, act=True):
    "Sequential conv + BN, optionally followed by the shared ReLU."
    bn = nn.BatchNorm2d(nf)
    # zero BN scale lets the tail of a residual branch start as identity
    nn.init.constant_(bn.weight, 0. if zero_bn else 1.)
    body = [conv(ni, nf, ks, stride=stride), bn]
    return nn.Sequential(*body, act_fn) if act else nn.Sequential(*body)
# Modified from https://github.com/fastai/fastai/blob/9b9014b8967186dc70c65ca7dcddca1a1232d99d/fastai/vision/models/xresnet.py
# Added self attention
class ResBlock(nn.Module):
    """XResNet residual block, optionally followed by SimpleSelfAttention.

    `expansion` is 1 for basic blocks and 4 for bottlenecks; `ni`/`nh` are
    given in unexpanded-width units. When `sa` is True the attention layer is
    applied to the conv branch before the residual add.
    """
    def __init__(self, expansion, ni, nh, stride=1,sa=False):
        super().__init__()
        nf,ni = nh*expansion,ni*expansion  # expanded output/input channel counts
        # basic (3x3, 3x3) vs bottleneck (1x1, 3x3, 1x1); the last conv_layer
        # has a zero-initialized BN so the block starts as an identity mapping
        layers = [conv_layer(ni, nh, 3, stride=stride),
                  conv_layer(nh, nf, 3, zero_bn=True, act=False)
                  ] if expansion == 1 else [
                  conv_layer(ni, nh, 1),
                  conv_layer(nh, nh, 3, stride=stride),
                  conv_layer(nh, nf, 1, zero_bn=True, act=False)
                  ]
        self.sa = sa
        self.attn = SimpleSelfAttention(nf,ks=1) if sa else noop
        self.convs = nn.Sequential(*layers)
        # identity path: 1x1 conv when channel counts differ, avg-pool when strided
        self.idconv = noop if ni==nf else conv_layer(ni, nf, 1, act=False)
        self.pool = noop if stride==1 else nn.AvgPool2d(2, ceil_mode=True)
    def forward(self, x):
        # We only return the attn vector for interpretation
        # we do this so that we can still fit this in nn.Sequential
        if self.sa:
            # SimpleSelfAttention returns (output, attention); keep output only
            return act_fn(self.attn(self.convs(x))[0] + self.idconv(self.pool(x)))
        else:
            return act_fn(self.attn(self.convs(x)) + self.idconv(self.pool(x)))
# Modified from https://github.com/fastai/fastai/blob/9b9014b8967186dc70c65ca7dcddca1a1232d99d/fastai/vision/models/xresnet.py
# Added self attention
class XResNet_sa(nn.Sequential):
    "XResNet with a SimpleSelfAttention layer added to one residual stage."

    @classmethod
    def create(cls, expansion, layers, c_in=3, c_out=1000):
        "Assemble stem convs, max-pool, residual stages, and the linear head."
        stem_sizes = [c_in, (c_in+1)*8, 64, 64]
        stem = [conv_layer(stem_sizes[k], stem_sizes[k+1],
                           stride=2 if k == 0 else 1)
                for k in range(3)]
        widths = [64//expansion, 64, 128, 256, 512]
        # attention is enabled on stage index len(layers)-4 (stage 0 for a
        # standard 4-stage network)
        stages = [cls._make_layer(expansion, widths[k], widths[k+1],
                                  n_blocks=depth,
                                  stride=1 if k == 0 else 2,
                                  sa=(k == len(layers)-4))
                  for k, depth in enumerate(layers)]
        net = cls(*stem,
                  nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                  *stages,
                  nn.AdaptiveAvgPool2d(1), Flatten(),
                  nn.Linear(widths[-1]*expansion, c_out))
        init_cnn(net)
        return net

    @staticmethod
    def _make_layer(expansion, ni, nf, n_blocks, stride, sa=False):
        "Stack `n_blocks` ResBlocks; only the stage's last block gets attention."
        blocks = []
        for j in range(n_blocks):
            blocks.append(ResBlock(expansion,
                                   ni if j == 0 else nf, nf,
                                   stride if j == 0 else 1,
                                   sa and j == n_blocks-1))
        return nn.Sequential(*blocks)
def xresnet50_sa(**kwargs):
    "XResNet-50 (bottleneck expansion 4, stage depths 3-4-6-3) with self-attention."
    return XResNet_sa.create(4, [3, 4, 6, 3], **kwargs)
#to help prevent cuda memory errors
try:
del m;gc.collect()
except:
pass
m = xresnet50_sa(c_out=data.c).cuda()
learn = Learner(data, m, metrics=error_rate)
learn.loss_func = FlattenedLoss(LabelSmoothingCrossEntropy)
learn.unfreeze()
lr_find(learn)  # sweep learning rates to pick one from the loss curve
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.recorder.plot()  # inspect the LR-finder curve; 1e-3 chosen from the slope
# 20 one-cycle epochs, checkpointing whenever validation loss improves
learn.fit_one_cycle(20,1e-3,callbacks=[SaveModelCallback(learn,every='improvement',name='1st_224_cards')])
epoch | train_loss | valid_loss | error_rate | time |
---|---|---|---|---|
0 | 1.280328 | 1.272623 | 0.578431 | 00:15 |
1 | 1.197646 | 1.416744 | 0.617647 | 00:15 |
2 | 1.152464 | 1.533333 | 0.509804 | 00:15 |
3 | 1.174556 | 1.623811 | 0.519608 | 00:14 |
4 | 1.177189 | 2.156969 | 0.558824 | 00:15 |
5 | 1.110141 | 1.626028 | 0.401961 | 00:16 |
6 | 1.056969 | 2.196778 | 0.558824 | 00:15 |
7 | 1.000436 | 1.571214 | 0.470588 | 00:15 |
8 | 0.955225 | 0.908581 | 0.196078 | 00:14 |
9 | 0.881647 | 0.921629 | 0.205882 | 00:15 |
10 | 0.827706 | 0.936471 | 0.205882 | 00:17 |
11 | 0.782261 | 1.796283 | 0.313725 | 00:16 |
12 | 0.758876 | 1.029887 | 0.303922 | 00:14 |
13 | 0.754478 | 0.754220 | 0.127451 | 00:14 |
14 | 0.720812 | 0.746815 | 0.215686 | 00:15 |
15 | 0.676691 | 0.700019 | 0.117647 | 00:16 |
16 | 0.638353 | 0.691785 | 0.107843 | 00:15 |
17 | 0.608052 | 0.673471 | 0.137255 | 00:16 |
18 | 0.585816 | 0.655005 | 0.088235 | 00:16 |
19 | 0.566775 | 0.656597 | 0.107843 | 00:14 |
Better model found at epoch 0 with valid_loss value: 1.2726233005523682. Better model found at epoch 8 with valid_loss value: 0.9085805416107178. Better model found at epoch 13 with valid_loss value: 0.754219651222229. Better model found at epoch 14 with valid_loss value: 0.7468152046203613. Better model found at epoch 15 with valid_loss value: 0.7000190615653992. Better model found at epoch 16 with valid_loss value: 0.6917850375175476. Better model found at epoch 17 with valid_loss value: 0.6734710335731506. Better model found at epoch 18 with valid_loss value: 0.6550053358078003.
learn.load('1st_224_cards');  # restore the best checkpoint saved during training
/home/jupyter/anaconda3/lib/python3.7/site-packages/torch/serialization.py:256: UserWarning: Couldn't retrieve source code for container of type Flatten. It won't be checked for correctness upon loading. "type " + obj.__name__ + ". It won't be checked " /home/jupyter/anaconda3/lib/python3.7/site-packages/torch/serialization.py:256: UserWarning: Couldn't retrieve source code for container of type XResNet_sa. It won't be checked for correctness upon loading. "type " + obj.__name__ + ". It won't be checked " /home/jupyter/anaconda3/lib/python3.7/site-packages/torch/serialization.py:256: UserWarning: Couldn't retrieve source code for container of type ResBlock. It won't be checked for correctness upon loading. "type " + obj.__name__ + ". It won't be checked " /home/jupyter/anaconda3/lib/python3.7/site-packages/torch/serialization.py:256: UserWarning: Couldn't retrieve source code for container of type SimpleSelfAttention. It won't be checked for correctness upon loading. "type " + obj.__name__ + ". It won't be checked "
# Build interpretation helpers (confusion matrix, top losses) on the validation set
interp = ClassificationInterpretation.from_learner(learn,ds_type=DatasetType.Valid)
Model gets most of the predictions correct on the validation set, things are looking good.
interp.plot_confusion_matrix()  # per-class error breakdown on the validation set
Most of the errors are due to noisy labels, we can easily forgive the model for making these mistakes.
# NOTE(review): with the custom FlattenedLoss the stored scores appear to be
# raw logits, so we softmax them manually before plotting — confirm against
# the fastai get_preds activation handling.
interp.probs = torch.softmax(interp.probs,dim=-1)
interp.plot_top_losses(9, figsize=(15,11))
Here comes the fun part: we want to find out exactly where the model 'looks' while calculating a prediction.
# !wget https://cdn4.vectorstock.com/i/1000x1000/50/23/black-ace-of-hearts-vector-23815023.jpg -O test/black_heart.jpg
import copy

# Interpret a CPU copy of the network so we leave the trained GPU model untouched.
model = copy.deepcopy(m).cpu()
We can either load an image from disk, or load an image from the validation set.
# x_im = open_image('test/red_spade.jpg')
# Grab one sample from the validation set; evaluating `x_im` on its own
# renders the image in the notebook
x_im = data.valid_ds.x[1]
x_im
We convert the image to a tensor that the model accepts as an input.
# one_item returns (x, y); [0] keeps the input tensor — presumably a
# normalized (1, C, H, W) batch, TODO confirm against fastai docs
x = data.one_item(x_im)[0]
Seems to be giving the right prediction.
data.classes  # class ordering used to decode the model's argmax below
['clubs', 'diamonds', 'hearts', 'spades']
# Forward pass without autograd bookkeeping; decode the argmax to a class name
with torch.no_grad():
    res = model(x.cpu())
    print(data.classes[res.argmax(-1)])
diamonds
We attach a hook to the attention activation from the Self-Attention layer so we can retrieve the activations.
from fastai.callbacks.hooks import *
def hook_attention():
    "Run one forward pass with an output hook on the attention layer; return the hook."
    # model[4][2] is the last ResBlock of the first residual stage — the one
    # holding the SimpleSelfAttention layer (see XResNet_sa.create) — TODO
    # confirm the index if the architecture changes
    with hook_output(model[4][2].attn) as hook_sa:
        res = model(x.cpu())  # output unused; the pass only populates the hook
    return hook_sa

h = hook_attention()
# SimpleSelfAttention.forward returns (output, attention); keep the attention
attn = h.stored[1]
We calculate the mean weight for each position on the N x N attention grid, then scale the values so they lie between 0 and 1.
# Average the attention each position receives -> column vector of shape (N, 1)
attn = attn.mean(dim=1).transpose(1,0)
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
# BUG FIX: the scaled values were previously assigned to an unused `res`
# variable and discarded, so the later plot used the UNscaled attention.
# Keep the result in `attn`, converted back to a tensor so the subsequent
# `attn.view(hw, hw)` reshape still works.
attn = torch.from_numpy(scaler.fit_transform(attn.detach().numpy())).float()
Reshape it back to the dims of a square image. (We are assuming we are working with square tensors!)
# The N positions came from a square feature map; recover its side length
hw = int(np.sqrt(attn.shape[0]))
attn = attn.view(hw,hw)
sz = list(x_im.shape[-2:])  # (height, width) of the displayed image
Voila, the model seems to have its eyes on the right places!
# Overlay the attention map on the image; `extent` stretches the small
# hw x hw grid across the full image, and alpha keeps the card visible
_,ax = plt.subplots(1,1, figsize=(10,10))
x_im.show(ax=ax)
ax.imshow(attn, cmap='viridis', interpolation='nearest',alpha=0.7,extent=(0,*sz[::-1],0))
<matplotlib.image.AxesImage at 0x7ff6a1f8a278>