! pip install -q fastai2 ! pip install -q psutil from google.colab import drive drive.mount('/content/drive') from pathlib import Path # save models on gdrive path = Path('/content/drive/My Drive/') assert path.is_dir() # save models in local folder # path = Path('.').resolve() from fastai2.basics import * from fastai2.vision.all import * from fastai2.callback.all import * from fastai2.distributed import * from fastai2.optimizer import * class SkipMetricPartException(Exception): pass class AvgPartMetric(AvgMetric): "Average the values of `func` allowing to raise exception and omit accumulation" def accumulate(self, learn): try: val, bs = self.func(learn.pred, *learn.yb) self.total += to_detach(val)*bs self.count += bs except SkipMetricPartException: pass except Exception as e: print(e) def accuracy_with_na(inp, targ, thresh=0.4, na_idx=0, axis=-1, sigmoid=True): "Compute accuracy assuming that prediction below threshold belongs to first category (#na#)" if sigmoid: inp = inp.sigmoid() valm, argm = inp.max(dim=axis) argm[valm < thresh] = na_idx # treat all values below threshold as category #na# inp, targ = flatten_check(argm, targ) return (inp==targ).float().mean() def _accuracy_without_na(inp, targ, na_idx=0, axis=-1): "Compute accuracy with `targ` when `pred` omiting #na# category" idxs = targ!=na_idx if idxs.any(): inp, targ = flatten_check(inp[idxs].argmax(dim=axis), targ[idxs]) return (inp==targ).float().mean(), targ.shape[axis] else: # skip accumulating metric if there is only category #na# in batch raise SkipMetricPartException _accuracy_without_na.__name__ = 'accuracy_without_na' accuracy_without_na = AvgPartMetric(_accuracy_without_na) from fastai2.basics import * from fastai2.vision.all import * from fastai2.callback.all import * from torch import nn class BCENaLoss(nn.Module): y_int = True def __init__(self, logits=True, reduction='mean'): super().__init__() self.reduction = reduction self.logits = logits def forward(self, input, target): target = 
F.one_hot(target, input.shape[1]).float() target[:, 0] = 0 # first category is #na# category so it should be zeroed if self.logits: return F.binary_cross_entropy_with_logits(input, target, reduction=self.reduction) # sigmoid + bce else: return F.binary_cross_entropy(input, target, reduction=self.reduction) # no sigmoid @delegates(keep=True) class BCENaLossFlat(BaseLoss): "Same as `FocalLoss`, but flattens input and target." def __init__(self, *args, axis=-1, thresh=0.5, **kwargs): super().__init__(BCENaLoss, *args, axis=axis, **kwargs) self.thresh = thresh def decodes(self, x): valm, argm = x.max(dim=self.axis) argm[valm < self.thresh] = 0 return argm def activation(self, x): return torch.sigmoid(x) @delegates() class BCEWithLogitsLossOneHotFlat(BCEWithLogitsLossFlat): def __call__(self, inp, targ, **kwargs): return super().__call__(inp, F.one_hot(targ, inp.shape[1]), **kwargs) def decodes(self, x): return x.argmax(dim=-1) def activation(self, x): return torch.sigmoid(x) def get_train_dls(bs=48, size=128, workers=None, augs=True, item_tfms=[], batch_tfms=[], add_na=False, train_na=None): dspath = untar_data(URLs.IMAGENETTE_160) if workers is None: workers = min(8, num_cpus()//(num_distrib() or 1)) norm_tfms = [Normalize.from_stats(*imagenet_stats)] resize_tfms = [Resize(size, method=ResizeMethod.Pad, pad_mode=PadMode.Reflection)] augs_tfms = aug_transforms() if augs else [] # categories known and trained train_cats = [ 'n03417042', 'n02979186', 'n03394916', 'n03445777', ] # categories na and trained na_train_cats = [ 'n03028079', 'n01440764', 'n03888257', ] # categories not trained (unseen by network) test_cats = [ 'n02102040', 'n03000684', 'n03425413' ] if train_na is None: train_na = add_na def get_items(dspath): if train_na: # train and validate known and unknown categories train_folders = train_cats + na_train_cats valid_folders = train_cats + na_train_cats else: # train and validate only known categories train_folders = train_cats valid_folders = train_cats 
return get_image_files(dspath/'train', folders=train_folders) + get_image_files(dspath/'val', folders=valid_folders) # create train and valid dataloaders dbl = DataBlock( blocks=(ImageBlock, CategoryBlock(vocab=train_cats, add_na=add_na)), get_items=get_items, get_y=parent_label, splitter=GrandparentSplitter(train_name='train', valid_name='val'), item_tfms=item_tfms + resize_tfms, batch_tfms=batch_tfms + augs_tfms + norm_tfms, ) dls = dbl.dataloaders(dspath, bs=bs, num_workers=workers) if not add_na: # support any unknown category as random from existing categories # it does not affect training, only enables validation for not known categories items = dls.valid.tfms[1][1].vocab.items o2i = dls.valid.tfms[1][1].vocab.o2i dls.valid.tfms[1][1].vocab.o2i = defaultdict(lambda: random.randint(0, len(items) - 1), o2i) # add custom test dataloader to validate all cats (even unseen before) test_items = get_image_files(dspath, folders=['val']) dls.test = dls.test_dl(test_items, with_labels=True) return dls dls = get_train_dls() dls.show_batch() dls = get_train_dls(add_na=True) dls.show_batch() dls.test.show_batch() learn = cnn_learner( get_train_dls(), xresnet18, pretrained=True, loss_func=BCEWithLogitsLossOneHotFlat(), metrics=[accuracy], path=path, cbs=[ShowGraphCallback] ) if torch.cuda.is_available(): learn = learn.to_fp16() learn.fine_tune(20, freeze_epochs=2) # confusion_matrix for known categories interp = ClassificationInterpretation.from_learner(learn) interp.plot_confusion_matrix(figsize=(12,12), dpi=60) learn.remove_cb(ShowGraphCallback) # fix fastai callback error during validation v_loss, v_accuracy = learn.validate(dl=learn.dls.test) # validate with all known, unknown and unseen categories print(f'validation loss: {v_loss}, [accuracy]: {v_accuracy}') # confusion_matrix for known, unknown and unseen categories interp = ClassificationInterpretation.from_learner(learn, dl=learn.dls.test) interp.plot_confusion_matrix(figsize=(12,12), dpi=60) learn.save('bce-loss') 
# Experiment 2: plain BCE loss, trained with the #na# category.
learn = cnn_learner(
    get_train_dls(add_na=True),
    xresnet18,
    pretrained=True,
    loss_func=BCEWithLogitsLossOneHotFlat(),
    metrics=[accuracy, accuracy_without_na, accuracy_with_na],
    path=path,
    cbs=[ShowGraphCallback]
)
if torch.cuda.is_available():
    learn = learn.to_fp16()
learn.fine_tune(20, freeze_epochs=2)

# confusion_matrix for known and unknown categories
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix(figsize=(12,12), dpi=60)

learn.remove_cb(ShowGraphCallback)  # fix fastai callback error during validation
mets = learn.validate(dl=learn.dls.test)  # validate with all known, unknown and unseen categories
# Labels follow the order of the `metrics` list passed to the learner above.
print(f'validation loss: {mets[0]}, accuracy: {mets[1]}, accuracy without na: {mets[2]}, [accuracy with na]: {mets[3]}')

# confusion_matrix for known, unknown and unseen categories
interp = ClassificationInterpretation.from_learner(learn, dl=learn.dls.test)
interp.plot_confusion_matrix(figsize=(12,12), dpi=60)

learn.save('bce-loss_with_na')

# Experiment 3: BCE loss with zeroed #na# target column, trained with the #na# category.
learn = cnn_learner(
    get_train_dls(add_na=True),
    xresnet18,
    pretrained=True,
    loss_func=BCENaLossFlat(),
    metrics=[accuracy, accuracy_without_na, accuracy_with_na],
    path=path,
    cbs=[ShowGraphCallback]
)
if torch.cuda.is_available():
    learn = learn.to_fp16()
learn.fine_tune(20, freeze_epochs=2)

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix(figsize=(12,12), dpi=60)

learn.remove_cb(ShowGraphCallback)  # fix fastai callback error during validation
mets = learn.validate(dl=learn.dls.test)  # validate with all known, unknown and unseen categories
# Labels follow the order of the `metrics` list passed to the learner above.
print(f'validation loss: {mets[0]}, accuracy: {mets[1]}, accuracy without na: {mets[2]}, [accuracy with na]: {mets[3]}')

# confusion_matrix for known, unknown and unseen categories
interp = ClassificationInterpretation.from_learner(learn, dl=learn.dls.test)
interp.plot_confusion_matrix(figsize=(12,12), dpi=60)

learn.save('bce-na-loss_with_na')


# function to print and compare above experiments
def validate_experiment(add_na=False, bcena=False):
    """Reload one of the saved experiments and print a one-line summary.

    add_na: the experiment was trained with the #na# category.
    bcena: the experiment used `BCENaLossFlat` instead of
        `BCEWithLogitsLossOneHotFlat`.

    Raises `ValueError` for `add_na=False, bcena=True`, because no model is
    saved for that combination.
    """
    # Names used with `learn.save(...)` by the experiments in this file.
    saved_models = {
        (False, False): 'bce-loss',
        (True, False): 'bce-loss_with_na',
        (True, True): 'bce-na-loss_with_na',
    }
    combo = (bool(add_na), bool(bcena))
    if combo not in saved_models:
        # Fail before any data is loaded instead of hitting an unbound
        # `lpath` later on.
        raise ValueError(f'no saved experiment for add_na={add_na}, bcena={bcena}')
    lpath = saved_models[combo]

    # get learner
    loss_func = BCENaLossFlat() if bcena else BCEWithLogitsLossOneHotFlat()
    metrics = [accuracy, accuracy_with_na, accuracy_without_na] if add_na else [accuracy]
    learn = cnn_learner(
        get_train_dls(add_na=add_na),
        xresnet18,
        pretrained=True,
        loss_func=loss_func,
        metrics=metrics,
        path=path
    )
    learn = learn.load(lpath)
    if torch.cuda.is_available():
        learn = learn.to_fp16()

    # validate without na
    dls = get_train_dls(add_na=add_na, train_na=False)
    interp = ClassificationInterpretation.from_learner(learn, dl=dls.valid)
    d, t = flatten_check(interp.decoded, interp.targs)
    known_tp = (d==t).long().sum()

    # validate with na
    interp = ClassificationInterpretation.from_learner(learn, dl=learn.dls.test)
    d, t = flatten_check(interp.decoded, interp.targs)
    mets = learn.final_record
    # NOTE(review): this takes the last recorded value; with the `metrics`
    # order above that is presumably `accuracy_without_na` when `add_na` is
    # true, not plain `accuracy` - confirm which one is intended.
    acc = mets[-1]

    if add_na:
        i = t==0  # samples whose target is index 0 (#na#)
        j = d==0  # samples decoded as index 0 (#na#)
        unknown_tp = (d[i]==t[i]).long().sum()  # unknown true positive
        unknown_fn = (d[i]!=t[i]).long().sum()  # unknown false negative
        unknown_fp = (d[j]!=t[j]).long().sum()  # unknown false positive
    else:
        unknown_tp = 0
        # NOTE(review): hard-coded count, presumably the number of unknown
        # pictures in the test loader - verify against the dataset.
        unknown_fn = 2386  # all unknown pictures are false negative here
        unknown_fp = 0

    pr = f'{lpath:^20} => known TP: {known_tp:4d} | #na# TP: {unknown_tp:4d} | #na# FN: {unknown_fn:4d} | #na# FP: {unknown_fp:4d} | accuracy: {acc:2.4f}'
    print(pr)


validate_experiment()
validate_experiment(add_na=True)
validate_experiment(add_na=True, bcena=True)