# Siamese-pair tutorial (fastai): build a dataset/DataBlock of image pairs from
# the Oxford-IIIT PETS dataset, train a Siamese model to predict whether two
# images show the same breed, and add custom show_batch/show_results/predict
# support. Originally a notebook; cell markers kept as comments.
#hide
#skip
# ! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab
#all_slow
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import get_image_files

# Download (or reuse the cached copy of) the PETS dataset and list the images.
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")
files[0]

import PIL
img = PIL.Image.open(files[0])
img

import torch
import numpy as np


def open_image(fname, size=224):
    """Open image `fname`, resize to `size` x `size`, and return a CHW float
    tensor scaled to [0, 1]."""
    img = PIL.Image.open(fname).convert('RGB')
    img = img.resize((size, size))
    t = torch.Tensor(np.array(img))          # HWC uint8 values promoted to float32
    return t.permute(2, 0, 1).float()/255.0  # HWC -> CHW, scale to [0, 1]


open_image(files[0]).shape

import re


def label_func(fname):
    """Extract the pet-breed label from a PETS filename like `breed_12.jpg`."""
    # NOTE: the dot before `jpg` must be escaped, otherwise any character
    # would match there (e.g. `breed_1Xjpg`).
    return re.match(r'^(.*)_\d+\.jpg$', fname.name).groups()[0]


label_func(files[0])

# All distinct labels, and a label -> files index used when drawing pairs.
labels = list(set(files.map(label_func)))
len(labels)

lbl2files = {l: [f for f in files if label_func(f) == l] for l in labels}

import random


class SiameseDataset(torch.utils.data.Dataset):
    """Plain-PyTorch dataset of (img1, img2, same?) triples.

    For the validation set the second image of each pair is drawn once at
    init time so validation pairs stay fixed across epochs; training pairs
    are redrawn on every access.
    """
    def __init__(self, files, is_valid=False):
        self.files, self.is_valid = files, is_valid
        if is_valid: self.files2 = [self._draw(f) for f in files]

    def __getitem__(self, i):
        file1 = self.files[i]
        (file2, same) = self.files2[i] if self.is_valid else self._draw(file1)
        img1, img2 = open_image(file1), open_image(file2)
        return (img1, img2, torch.Tensor([same]).squeeze())

    def __len__(self): return len(self.files)

    def _draw(self, f):
        # 50/50 coin flip: pick a second image of the same class or of a
        # uniformly-random different class.
        same = random.random() < 0.5
        cls = label_func(f)
        if not same: cls = random.choice([l for l in labels if l != cls])
        return random.choice(lbl2files[cls]), same


# 80/20 random train/valid split at the file level.
idxs = np.random.permutation(range(len(files)))
cut = int(0.8 * len(files))
train_files = files[idxs[:cut]]
valid_files = files[idxs[cut:]]

train_ds = SiameseDataset(train_files)
valid_ds = SiameseDataset(valid_files, is_valid=True)

from fastai.data.core import DataLoaders
dls = DataLoaders.from_dsets(train_ds, valid_ds)
b = dls.one_batch()
dls = dls.cuda()

from fastai.vision.all import *


class SiameseTransform(Transform):
    """First fastai version: a Transform over *indices* that mirrors
    `SiameseDataset` but returns `TensorImage`s so batch transforms apply."""
    def __init__(self, files, is_valid=False):
        self.files, self.is_valid = files, is_valid
        # Fixed validation pairs, drawn once.
        if is_valid: self.files2 = [self._draw(f) for f in files]

    def encodes(self, i):
        file1 = self.files[i]
        (file2, same) = self.files2[i] if self.is_valid else self._draw(file1)
        img1, img2 = open_image(file1), open_image(file2)
        return (TensorImage(img1), TensorImage(img2), torch.Tensor([same]).squeeze())

    def _draw(self, f):
        same = random.random() < 0.5
        cls = label_func(f)
        if not same: cls = random.choice([l for l in labels if l != cls])
        return random.choice(lbl2files[cls]), same


train_tl = TfmdLists(range(len(train_files)), SiameseTransform(train_files))
valid_tl = TfmdLists(range(len(valid_files)), SiameseTransform(valid_files, is_valid=True))
dls = DataLoaders.from_dsets(train_tl, valid_tl,
                             after_batch=[Normalize.from_stats(*imagenet_stats), *aug_transforms()])
dls = dls.cuda()


class SiameseImage(fastuple):
    """A (img1, img2[, similarity]) tuple that knows how to display itself:
    the two images side by side, separated by a 10px black line."""
    def show(self, ctx=None, **kwargs):
        if len(self) > 2:
            img1, img2, similarity = self
        else:
            img1, img2 = self
            similarity = 'Undetermined'
        if not isinstance(img1, Tensor):
            # PIL images: match sizes, then convert to CHW tensors.
            if img2.size != img1.size: img2 = img2.resize(img1.size)
            t1, t2 = tensor(img1), tensor(img2)
            t1, t2 = t1.permute(2, 0, 1), t2.permute(2, 0, 1)
        else:
            t1, t2 = img1, img2
        line = t1.new_zeros(t1.shape[0], t1.shape[1], 10)
        return show_image(torch.cat([t1, line, t2], dim=2), title=similarity, ctx=ctx, **kwargs)


img = PILImage.create(files[0])
img1 = PILImage.create(files[1])
s = SiameseImage(img, img1, False)
s.show();

# Because SiameseImage is a fastuple, item transforms are applied elementwise.
tst = Resize(224)(s)
tst = ToTensor()(tst)
tst.show();


class SiameseTransform(Transform):
    """Second version: a Transform over *file paths* that draws pairs within
    a given split, so validation pairs never leak files from the train set."""
    def __init__(self, files, splits):
        # One label -> files index per split (0 = train, 1 = valid).
        self.splbl2files = [{l: [f for f in files[splits[i]] if label_func(f) == l]
                             for l in labels} for i in range(2)]
        # Fixed validation pairs, keyed by the first file of the pair.
        self.valid = {f: self._draw(f, 1) for f in files[splits[1]]}

    def encodes(self, f):
        # Validation files get their precomputed partner; training files are
        # redrawn on each call.
        f2, same = self.valid.get(f, self._draw(f, 0))
        img1, img2 = PILImage.create(f), PILImage.create(f2)
        return SiameseImage(img1, img2, same)

    def _draw(self, f, split=0):
        same = random.random() < 0.5
        cls = label_func(f)
        if not same: cls = random.choice(L(l for l in labels if l != cls))
        return random.choice(self.splbl2files[split][cls]), same


splits = RandomSplitter()(files)
tfm = SiameseTransform(files, splits)

# Sanity check: no validation partner comes from the training split.
valids = [v[0] for k, v in tfm.valid.items()]
assert not [v for v in valids if v in files[splits[0]]]

tls = TfmdLists(files, tfm, splits=splits)
show_at(tls.valid, 0)

dls = tls.dataloaders(after_item=[Resize(224), ToTensor],
                      after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats)])
b = dls.one_batch()
type(b)


@typedispatch
def show_batch(x: SiameseImage, y, samples, ctxs=None, max_n=6, nrows=None, ncols=2, figsize=None, **kwargs):
    """Custom batch display for SiameseImage batches."""
    if figsize is None: figsize = (ncols*6, max_n//ncols * 3)
    # Pass `nrows` through (it was previously hard-coded to None, silently
    # ignoring the caller's argument).
    if ctxs is None: ctxs = get_grid(min(x[0].shape[0], max_n), nrows=nrows, ncols=ncols, figsize=figsize)
    for i, ctx in enumerate(ctxs):
        SiameseImage(x[0][i], x[1][i], ['Not similar', 'Similar'][x[2][i].item()]).show(ctx=ctx)


b = dls.one_batch()
dls._types
dls.show_batch()


class ImageTuple(fastuple):
    """A pair of images for use inside a DataBlock; displayable once both
    elements are same-shape tensors."""
    @classmethod
    def create(cls, fns): return cls(tuple(PILImage.create(f) for f in fns))

    def show(self, ctx=None, **kwargs):
        t1, t2 = self
        # Before ToTensor/Resize the elements may be PIL images or differ in
        # shape; nothing sensible to draw yet.
        if not isinstance(t1, Tensor) or not isinstance(t2, Tensor) or t1.shape != t2.shape:
            return ctx
        line = t1.new_zeros(t1.shape[0], t1.shape[1], 10)
        return show_image(torch.cat([t1, line, t2], dim=2), ctx=ctx, **kwargs)


img = ImageTuple.create((files[0], files[1]))
tst = ToTensor()(img)
type(tst[0]), type(tst[1])

img1 = Resize(224)(img)
tst = ToTensor()(img1)
tst.show();


def ImageTupleBlock():
    """TransformBlock for ImageTuple inputs."""
    return TransformBlock(type_tfms=ImageTuple.create, batch_tfms=IntToFloatTensor)


splits_files = [files[splits[i]] for i in range(2)]
splits_sets = mapped(set, splits_files)


def get_split(f):
    """Return 0/1 according to which split file `f` belongs to."""
    for i, s in enumerate(splits_sets):
        if f in s: return i
    raise ValueError(f'File {f} is not present in any split.')


splbl2files = [{l: [f for f in s if label_func(f) == l] for l in labels} for s in splits_sets]


def splitter(items):
    """Split (f1, f2, same) items according to which split f1 came from."""
    def get_split_files(i):
        return [j for j, (f1, f2, same) in enumerate(items) if get_split(f1) == i]
    return get_split_files(0), get_split_files(1)


def draw_other(f):
    """Draw a partner for `f` from the same split: 50% same class, 50% other."""
    same = random.random() < 0.5
    cls = label_func(f)
    split = get_split(f)
    if not same: cls = random.choice(L(l for l in labels if l != cls))
    return random.choice(splbl2files[split][cls]), same


def get_tuples(files):
    """Materialize the item list: one [f1, f2, same] triple per file."""
    return [[f, *draw_other(f)] for f in files]


def get_x(t): return t[:2]
def get_y(t): return t[2]


siamese = DataBlock(
    blocks=(ImageTupleBlock, CategoryBlock),
    get_items=get_tuples,
    get_x=get_x, get_y=get_y,
    splitter=splitter,
    item_tfms=Resize(224),
    batch_tfms=[Normalize.from_stats(*imagenet_stats)],
)

dls = siamese.dataloaders(files)
b = dls.one_batch()
explode_types(b)


@typedispatch
def show_batch(x: ImageTuple, y, samples, ctxs=None, max_n=6, nrows=None, ncols=2, figsize=None, **kwargs):
    """Custom batch display for ImageTuple batches; delegates element display
    to the generic show_batch after laying out the grid."""
    if figsize is None: figsize = (ncols*6, max_n//ncols * 3)
    if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)
    ctxs = show_batch[object](x, y, samples, ctxs=ctxs, max_n=max_n, **kwargs)
    return ctxs


dls.show_batch()


class SiameseModel(Module):
    """Shared-encoder Siamese classifier: encode both images with the same
    body, concatenate the features, classify with a head."""
    def __init__(self, encoder, head):
        self.encoder, self.head = encoder, head

    def forward(self, x1, x2):
        ftrs = torch.cat([self.encoder(x1), self.encoder(x2)], dim=1)
        return self.head(ftrs)


model_meta[resnet34]
encoder = create_body(resnet34, cut=-2)
encoder[-1]
# create_head doubles its input internally (avg+max pool), so 512*2 accounts
# for the two concatenated encoders' 512-dim features each.
head = create_head(512*2, 2, ps=0.5)
model = SiameseModel(encoder, head)
head


def siamese_splitter(model):
    """Two parameter groups so the pretrained encoder can be frozen
    independently of the head."""
    return [params(model.encoder), params(model.head)]


def loss_func(out, targ):
    """Cross-entropy on the 2-class output; targets arrive as floats."""
    return CrossEntropyLossFlat()(out, targ.long())


class SiameseTransform(Transform):
    """Final version: identical to the split-aware version above, except the
    similarity flag is cast to int so CategoryBlock-style targets work."""
    def __init__(self, files, splits):
        self.splbl2files = [{l: [f for f in files[splits[i]] if label_func(f) == l]
                             for l in labels} for i in range(2)]
        self.valid = {f: self._draw(f, 1) for f in files[splits[1]]}

    def encodes(self, f):
        f2, same = self.valid.get(f, self._draw(f, 0))
        img1, img2 = PILImage.create(f), PILImage.create(f2)
        return SiameseImage(img1, img2, int(same))

    def _draw(self, f, split=0):
        same = random.random() < 0.5
        cls = label_func(f)
        if not same: cls = random.choice(L(l for l in labels if l != cls))
        return random.choice(self.splbl2files[split][cls]), same


splits = RandomSplitter()(files)
tfm = SiameseTransform(files, splits)
tls = TfmdLists(files, tfm, splits=splits)
dls = tls.dataloaders(after_item=[Resize(224), ToTensor],
                      after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats)])

valids = [v[0] for k, v in tfm.valid.items()]
assert not [v for v in valids if v in files[splits[0]]]

# Standard fastai transfer-learning loop: train the head with the encoder
# frozen, then fine-tune everything with discriminative learning rates.
learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(),
                splitter=siamese_splitter, metrics=accuracy)
learn.freeze()
learn.lr_find()
learn.fit_one_cycle(4, 3e-3)
learn.unfreeze()
learn.fit_one_cycle(4, slice(1e-6, 1e-4))


@typedispatch
def show_results(x: SiameseImage, y, samples, outs, ctxs=None, max_n=6, nrows=None, ncols=2, figsize=None, **kwargs):
    """Show actual vs. predicted similarity for a batch of pairs."""
    if figsize is None: figsize = (ncols*6, max_n//ncols * 3)
    # Pass `nrows` through instead of hard-coding None.
    if ctxs is None: ctxs = get_grid(min(x[0].shape[0], max_n), nrows=nrows, ncols=ncols, figsize=figsize)
    for i, ctx in enumerate(ctxs):
        title = f'Actual: {["Not similar","Similar"][x[2][i].item()]} \n Prediction: {["Not similar","Similar"][y[2][i].item()]}'
        SiameseImage(x[0][i], x[1][i], title).show(ctx=ctx)


learn.show_results()


@patch
def siampredict(self: Learner, item, rm_type_tfms=None, with_input=False):
    """Predict on a SiameseImage and display the pair with its verdict."""
    res = self.predict(item, rm_type_tfms=None, with_input=False)
    if res[0] == tensor(0):
        SiameseImage(item[0], item[1], 'Prediction: Not similar').show()
    else:
        SiameseImage(item[0], item[1], 'Prediction: Similar').show()
    return res


imgtest = PILImage.create(files[0])
imgval = PILImage.create(files[100])
siamtest = SiameseImage(imgval, imgtest)
siamtest.show();
res = learn.siampredict(siamtest)