#|hide
#| eval: false
! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab
#|export
from __future__ import annotations
from fastai.basics import *
#|hide
from nbdev.showdoc import *
#|default_exp callback.rnn
Callbacks that use the outputs of language models to add AR (activation regularization) and TAR (temporal activation regularization) penalties to the loss
#|export
@docs
class ModelResetter(Callback):
    "`Callback` that calls `model.reset()` before each training phase, before each validation phase, and after fitting"

    # Each hook stays a separate function object: `@docs` attaches a distinct
    # docstring (taken from `_docs` below) to every one of them.
    def before_train(self):
        self.model.reset()

    def before_validate(self):
        self.model.reset()

    def after_fit(self):
        self.model.reset()

    _docs = {
        "before_train": "Reset the model before training",
        "before_validate": "Reset the model before validation",
        "after_fit": "Reset the model after fitting",
    }
#|export
class RNNCallback(Callback):
    "Keep only the true output as `learn.pred` for the loss, and stash the raw and dropped-out outputs on this callback"

    def after_pred(self):
        # `self.pred` must unpack into exactly three items; any item that is
        # list-like is replaced by its last element, others pass through as-is.
        finals = [item[-1] if is_listy(item) else item for item in self.pred]
        self.learn.pred, self.raw_out, self.out = finals
#|export
class RNNRegularizer(Callback):
    "Add AR (weight `alpha`) and TAR (weight `beta`) penalties to `learn.loss_grad` during training"
    # Ordered right after `RNNCallback`, whose `out`/`raw_out` attributes are read below;
    # `run_valid=False` is set so this callback is not meant to run on validation.
    order,run_valid = RNNCallback.order+1,False

    def __init__(self, alpha=0., beta=0.):
        # `alpha` scales the AR term, `beta` the TAR term; a falsy value disables that term.
        store_attr()

    def after_loss(self):
        "Add the enabled penalty terms to `self.learn.loss_grad`"
        if not self.training: return
        # NOTE(review): `self.rnn` is presumably the `RNNCallback` instance registered
        # on the learner -- confirm it is always present when this callback is used.
        if self.alpha:
            # AR: mean of the squared values of the stored `out`.
            self.learn.loss_grad += self.alpha * self.rnn.out.float().pow(2).mean()
        if self.beta:
            h = self.rnn.raw_out
            # TAR: mean squared difference between consecutive slices along dim 1.
            # The guard must look at dim 1 (the axis being differenced), not at
            # `len(h)` (dim 0): with a single slice along dim 1 the difference is
            # empty and `.mean()` would be NaN, poisoning the loss.
            if h.shape[1] > 1:
                self.learn.loss_grad += self.beta * (h[:,1:] - h[:,:-1]).float().pow(2).mean()
#|export
def rnn_cbs(alpha=0., beta=0.):
    "Return the callbacks for RNN training: `ModelResetter`, `RNNCallback`, plus `RNNRegularizer` when `alpha` or `beta` is truthy"
    # The regularizer is only created when at least one penalty weight is set.
    extra = []
    if alpha or beta:
        extra.append(RNNRegularizer(alpha=alpha, beta=beta))
    return [ModelResetter(), RNNCallback(), *extra]
#|hide
from nbdev import nbdev_export
# Run nbdev's export step; the recorded output of this cell lists each notebook as "Converted".
nbdev_export()
Converted 00_torch_core.ipynb. Converted 01_layers.ipynb. Converted 01a_losses.ipynb. Converted 02_data.load.ipynb. Converted 03_data.core.ipynb. Converted 04_data.external.ipynb. Converted 05_data.transforms.ipynb. Converted 06_data.block.ipynb. Converted 07_vision.core.ipynb. Converted 08_vision.data.ipynb. Converted 09_vision.augment.ipynb. Converted 09b_vision.utils.ipynb. Converted 09c_vision.widgets.ipynb. Converted 10_tutorial.pets.ipynb. Converted 10b_tutorial.albumentations.ipynb. Converted 11_vision.models.xresnet.ipynb. Converted 12_optimizer.ipynb. Converted 13_callback.core.ipynb. Converted 13a_learner.ipynb. Converted 13b_metrics.ipynb. Converted 14_callback.schedule.ipynb. Converted 14a_callback.data.ipynb. Converted 15_callback.hook.ipynb. Converted 15a_vision.models.unet.ipynb. Converted 16_callback.progress.ipynb. Converted 17_callback.tracker.ipynb. Converted 18_callback.fp16.ipynb. Converted 18a_callback.training.ipynb. Converted 18b_callback.preds.ipynb. Converted 19_callback.mixup.ipynb. Converted 20_interpret.ipynb. Converted 20a_distributed.ipynb. Converted 21_vision.learner.ipynb. Converted 22_tutorial.imagenette.ipynb. Converted 23_tutorial.vision.ipynb. Converted 24_tutorial.siamese.ipynb. Converted 24_vision.gan.ipynb. Converted 30_text.core.ipynb. Converted 31_text.data.ipynb. Converted 32_text.models.awdlstm.ipynb. Converted 33_text.models.core.ipynb. Converted 34_callback.rnn.ipynb. Converted 35_tutorial.wikitext.ipynb. Converted 36_text.models.qrnn.ipynb. Converted 37_text.learner.ipynb. Converted 38_tutorial.text.ipynb. Converted 39_tutorial.transformers.ipynb. Converted 40_tabular.core.ipynb. Converted 41_tabular.data.ipynb. Converted 42_tabular.model.ipynb. Converted 43_tabular.learner.ipynb. Converted 44_tutorial.tabular.ipynb. Converted 45_collab.ipynb. Converted 46_tutorial.collab.ipynb. Converted 50_tutorial.datablock.ipynb. Converted 60_medical.imaging.ipynb. Converted 61_tutorial.medical_imaging.ipynb. 
Converted 65_medical.text.ipynb. Converted 70_callback.wandb.ipynb. Converted 71_callback.tensorboard.ipynb. Converted 72_callback.neptune.ipynb. Converted 73_callback.captum.ipynb. Converted 97_test_utils.ipynb. Converted 99_pytorch_doc.ipynb. Converted dev-setup.ipynb. Converted index.ipynb. Converted quick_start.ipynb. Converted tutorial.ipynb.