# default_exp interpret

#export
from local.test import *
from local.data.all import *
from local.optimizer import *
from local.learner import *

#hide
from local.test_utils import *

#export
@typedispatch
def plot_top_losses(x, y, *args, **kwargs):
    "Generic `plot_top_losses`: always raises, naming the types of `x` and `y` it was called with."
    raise Exception(f"plot_top_losses is not implemented for {type(x)},{type(y)}")

#export
_all_ = ["plot_top_losses"]

#export
class Interpretation():
    "Interpretation base class, can be inherited for task specific Interpretation classes"
    def __init__(self, dl, inputs, preds, targs, decoded, losses):
        # Stores every constructor argument as an attribute of the same name.
        store_attr(self, "dl,inputs,preds,targs,decoded,losses")

    @classmethod
    def from_learner(cls, learn, ds_idx=1, dl=None, act=None):
        "Construct interpretation object from a learner, using `dl` or else `learn.dbunch.dls[ds_idx]`"
        if dl is None: dl = learn.dbunch.dls[ds_idx]
        # Forward the caller's `act`; it used to be hard-coded to `None`, which
        # silently discarded any activation the caller asked for.
        return cls(dl, *learn.get_preds(dl=dl, with_input=True, with_loss=True, with_decoded=True, act=act))

    def top_losses(self, k=None, largest=True):
        "`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`)."
        return self.losses.topk(ifnone(k, len(self.losses)), largest=largest)

    def plot_top_losses(self, k, largest=True, **kwargs):
        "Show the `k` items with the largest(/smallest) losses via the type-dispatched `plot_top_losses`."
        losses,idx = self.top_losses(k, largest)
        # Tensor inputs can be indexed directly; anything else is re-collated
        # item by item through the dataloader's own batching hooks.
        if isinstance(self.inputs[0], Tensor): inps = tuple(o[idx] for o in self.inputs)
        else: inps = self.dl.create_batch(self.dl.before_batch([tuple(o[i] for o in self.inputs) for i in idx]))
        # Batch of inputs + targets, and a parallel batch of inputs + decoded predictions.
        b = inps + tuple(o[idx] for o in (self.targs if is_listy(self.targs) else (self.targs,)))
        x,y,its = self.dl._pre_show_batch(b, max_n=k)
        b_out = inps + tuple(o[idx] for o in (self.decoded if is_listy(self.decoded) else (self.decoded,)))
        x1,y1,outs = self.dl._pre_show_batch(b_out, max_n=k)
        if its is not None:
            # Only the entries after the inputs are passed on from `outs`
            # (presumably the decoded predictions - confirm against `_pre_show_batch`).
            plot_top_losses(x, y, its, outs.itemgot(slice(len(self.inputs), None)), self.preds[idx], losses, **kwargs)
        #TODO: figure out if this is needed
        #its None means that a batch knows how to show itself as a whole, so we pass x, x1
        #else: show_results(x, x1, its, ctxs=ctxs, max_n=max_n, **kwargs)

learn = synth_learner()
interp = Interpretation.from_learner(learn)
x,y = learn.dbunch.valid_ds.tensors
test_eq(*interp.inputs, x)
test_eq(interp.targs, y)
out = learn.model.a * x + learn.model.b
test_eq(interp.preds, out)
test_eq(interp.losses, (out-y)[:,0]**2)

#export
class ClassificationInterpretation(Interpretation):
    "Interpretation methods for classification models."
    def __init__(self, dl, inputs, preds, targs, decoded, losses):
        super().__init__(dl, inputs, preds, targs, decoded, losses)
        self.vocab = self.dl.vocab
        # When the dataloader exposes several vocabs, keep the last one.
        if is_listy(self.vocab): self.vocab = self.vocab[-1]

    def confusion_matrix(self):
        "Confusion matrix as an `np.ndarray`."
        x = torch.arange(0, len(self.vocab))
        # Broadcast the class ids against targets (axis 0) and decoded predictions
        # (axis 1), then count matches over the remaining axis, so
        # cm[i, j] = number of items with target i and prediction j.
        # NOTE(review): assumes `decoded` and `targs` are 1-d tensors of class indices - confirm.
        cm = ((self.decoded==x[:,None]) & (self.targs==x[:,None,None])).sum(2)
        return to_np(cm)

    def plot_confusion_matrix(self, normalize=False, title='Confusion matrix', cmap="Blues", norm_dec=2,
                              plot_txt=True, **kwargs):
        "Plot the confusion matrix, with `title` and using `cmap`."
        # This function is mainly copied from the sklearn docs
        cm = self.confusion_matrix()
        # Row-normalize: each row (actual class) is divided by its own total.
        if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        fig = plt.figure(**kwargs)
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        tick_marks = np.arange(len(self.vocab))
        plt.xticks(tick_marks, self.vocab, rotation=90)
        plt.yticks(tick_marks, self.vocab, rotation=0)

        if plot_txt:
            # Cells above half of the maximum get white text, the rest black.
            thresh = cm.max() / 2.
            for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
                coeff = f'{cm[i, j]:.{norm_dec}f}' if normalize else f'{cm[i, j]}'
                plt.text(j, i, coeff, horizontalalignment="center", verticalalignment="center",
                         color="white" if cm[i, j] > thresh else "black")

        # Pin the y axis so it runs from the last row at the bottom to the first at the top.
        ax = fig.gca()
        ax.set_ylim(len(self.vocab)-.5,-.5)

        plt.tight_layout()
        plt.ylabel('Actual')
        plt.xlabel('Predicted')
        plt.grid(False)

    def most_confused(self, min_val=1):
        "Sorted descending list of largest non-diagonal entries of confusion matrix, presented as actual, predicted, number of occurrences."
        cm = self.confusion_matrix()
        # Zero the diagonal so correct predictions are never reported as confusions.
        np.fill_diagonal(cm, 0)
        res = [(self.vocab[i],self.vocab[j],cm[i,j]) for i,j in zip(*np.where(cm>=min_val))]
        return sorted(res, key=itemgetter(2), reverse=True)

#hide
from local.notebook.export import notebook2script
notebook2script(all_fs=True)