#hide
#skip
! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab
#default_exp callback.training
#export
from fastai.basics import *
from fastai.callback.progress import *
from fastai.callback.fp16 import *
#hide
from nbdev.showdoc import *
from fastai.test_utils import *
from fastai.vision.all import *
Various callbacks to customize training behavior
#export
class ShortEpochCallback(Callback):
    "Fit just `pct` of an epoch, then stop"
    def __init__(self, pct=0.01, short_valid=True):
        self.pct,self.short_valid = pct,short_valid
    def after_batch(self):
        # Once `pct` of the epoch's batches have run, cancel the phase:
        # training is always cut short; validation only if `short_valid`.
        frac_done = self.iter/self.n_iter
        if frac_done >= self.pct:
            if self.training: raise CancelTrainException
            if self.short_valid: raise CancelValidException
# Stop after 1% of the training batches; validation is shortened too by default
learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback())
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 00:00 |
# With `short_valid=False`, training is cut short but validation runs in full
learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback(short_valid=False))
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 14.867975 | 00:00 |
# export
class GradientAccumulation(Callback):
    "Accumulate gradients before updating weights"
    order,run_valid = MixedPrecision.order-4,False
    def __init__(self, n_acc=32): store_attr()
    def before_fit(self):
        # Start every fit with an empty accumulation window
        self.count = 0
    def after_loss(self):
        # Scale the loss down so gradients summed over the window match
        # a single batch of `n_acc` items
        self.learn.loss_grad /= self.n_acc/find_bs(self.learn.yb)
    def before_step(self):
        "Skip weight update if we have not seen enough items"
        # Undo the scaling so the recorder logs the correct loss value
        self.learn.loss_grad *= self.n_acc/find_bs(self.learn.yb)
        self.count += find_bs(self.learn.yb)
        if self.count < self.n_acc: raise CancelBatchException()  # skip step/zero_grad
        self.count = 0
#hide
class GetGrads(Callback):
    "Test helper: stash a detached copy of every parameter's gradient just before the step"
    run_valid,order = False,GradientAccumulation.order+1
    def before_step(self):
        cloned = [p.grad.clone() for p in self.model.parameters()]
        self.grads = to_detach(L(cloned))
def _test_acc(bs,n,cbs=None,cuda=False):
    "Train one epoch at batch size `bs` on `n` batches; return (train_loss, valid_loss, grads)"
    with no_random(99):
        data = synth_dbunch(bs=bs,n_train=n,n_valid=n,cuda=cuda)
        learner = synth_learner(data=data,cbs=[GetGrads]+L(cbs))
        learner.fit(1, lr=0.01)
        tr_loss,val_loss = learner.recorder.values[-1]
        return tr_loss,val_loss,learner.get_grads.grads
# Compare one batch of 8 vs. accumulating 8 batches of 1:
# gradients and valid loss must match; the smoothed train loss differs.
acc_cb = GradientAccumulation(n_acc=8)
train1,valid1,grads1 = _test_acc(8,1)
train2,valid2,grads2 = _test_acc(1,8,acc_cb)
#grads should be same, valid loss same, train loss different
test_close(grads2,grads1)
test_close(valid2, valid1)
test_ne(train2, train1)
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 0.834062 | 0.295950 | 00:00 |
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 0.824550 | 0.295950 | 00:00 |
#hide
#cuda
# Same check under mixed precision on the GPU (looser tolerance for fp16)
fp16_cb = MixedPrecision(init_scale=1024)
train1,valid1,grads1 = _test_acc(8,1, fp16_cb, cuda=True)
train2,valid2,grads2 = _test_acc(1,8, [acc_cb,fp16_cb], cuda=True)
test_close(grads2,grads1, eps=0.01)
test_close(valid2, valid1)
test_ne(train2, train1)
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 0.834062 | 0.295950 | 00:00 |
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 0.824550 | 0.295950 | 00:00 |
When the number of steps per accumulation is higher than the number of batches, the parameters (and therefore validation loss) don't change at all:
# With `n_acc` larger than the items seen in the epoch, no step ever happens,
# so the parameters (and hence the valid loss) stay unchanged
learn = synth_learner()
learn.fit(1, lr=0.01, cbs=GradientAccumulation(n_acc=1000))
# ensure valid_loss didn't change
assert learn.recorder.values[-1][1] == learn.recorder.values[0][1]
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 10.941168 | 10.280428 | 00:00 |
# export
class GradientClip(Callback):
    "Clip norm of gradients"
    order = MixedPrecision.order+1
    def __init__(self, max_norm:float=1., norm_type:float=2.0):
        store_attr()
    def before_step(self):
        # Rescale gradients in place so their `norm_type` norm never exceeds `max_norm`
        nn.utils.clip_grad_norm_(self.parameters(), max_norm=self.max_norm, norm_type=self.norm_type)
Normally, if we use a learning rate that is too high, our training will diverge. This even happens if we use mixed-precision training, which avoids infinities by using dynamic loss scaling, but still diverges:
# A too-high learning rate diverges even with dynamic loss scaling (fp16)
fp16 = MixedPrecision()
set_seed(99)
learn = synth_learner(lr=1.1, cuda=True)
learn.fit(3, cbs=fp16)
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 38.214169 | 25.269012 | 00:00 |
1 | 377.146088 | 890.011780 | 00:00 |
2 | 839.391907 | 9965.712891 | 00:00 |
By adding the `GradientClip` callback, the gradient norm (of type `norm_type`, default 2) is clipped to at most `max_norm` (default 1) using `nn.utils.clip_grad_norm_`, which can avoid loss divergence:
# Same run with gradient clipping added: training stays stable
set_seed(99)
learn = synth_learner(lr=1.1, cuda=True)
learn.fit(3, cbs=[GradientClip,fp16])
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 2.039427 | 2.372183 | 00:00 |
1 | 1.402424 | 0.300724 | 00:00 |
2 | 1.013551 | 0.332668 | 00:00 |
#export
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def set_bn_eval(m:nn.Module, use_eval=True)->None:
    "Set bn layers in eval mode for all recursive children of `m`."
    for l in m.children():
        # Only touch frozen batchnorm layers: trainable ones should keep
        # updating their running statistics
        if isinstance(l, bn_types) and not next(l.parameters()).requires_grad:
            if use_eval: l.eval()
            else:        l.train()
        # Bug fix: propagate `use_eval` into the recursion — previously the
        # call defaulted back to `use_eval=True` for any nested bn layer,
        # so `set_bn_eval(m, use_eval=False)` only affected direct children
        set_bn_eval(l, use_eval)
class BnFreeze(Callback):
    "Freeze moving average statistics in all non-trainable batchnorm layers."
    # Bug fix: the docstring used to sit *after* `run_after=...`, where it is a
    # no-op string expression rather than the class `__doc__` (which fastai's
    # doc tooling reads). It must be the first statement in the class body.
    run_after=TrainEvalCallback
    def before_train(self):
        # Put every frozen bn layer in eval mode so running stats don't update
        set_bn_eval(self.model)
`BnFreeze` is useful when you'd like to train two separate models that have a common feature extractor / body. The only part of the model that's different is the head you attach for transfer learning. `Learner.freeze()` doesn't suffice here, as the `BatchNorm` layers are trainable by default, and the running mean and std of batches are tracked. For the feature extractors to fully match, you need to set `train_bn=False`, and these stats need to be frozen as well, which is precisely the function of `BnFreeze`.
#slow
# Small MNIST sample for the BnFreeze demo below
path = untar_data(URLs.MNIST_TINY)
dls = ImageDataLoaders.from_folder(path, valid_pct=0.2)
We first demonstrate the mismatch of the running stats when using only `train_bn=False`, by creating a `Learner`...:
#slow
# `train_bn=False` alone: bn params are frozen but running stats still update
learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False)
...and grab the first `BatchNorm` layer, and store its running mean:
#slow
# Snapshot the first bn layer's running mean before training
m = learn1.model[0][1].running_mean.clone()
You can see that the running mean has now changed:
#slow
# After one epoch the running stats have drifted despite the frozen params
learn1.fit(1, lr=0.02)
test_ne(to_detach(learn1.model[0][1].running_mean), m)
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 1.152701 | 0.468892 | 00:02 |
When we use the `BnFreeze` callback, the running statistics will not be changed during training. This is often important for getting good results from transfer learning.
#slow
# With BnFreeze the running stats stay exactly as they were before training
learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False, cbs=BnFreeze)
m = learn1.model[0][1].running_mean.detach().clone()
learn1.fit(1, lr=0.02)
test_eq(to_detach(learn1.model[0][1].running_mean), m)
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 0.488634 | 0.277683 | 00:02 |
#hide
# nbdev boilerplate: write this notebook's `#export` cells into the library
from nbdev.export import notebook2script
notebook2script()
Converted 00_torch_core.ipynb. Converted 01_layers.ipynb. Converted 01a_losses.ipynb. Converted 02_data.load.ipynb. Converted 03_data.core.ipynb. Converted 04_data.external.ipynb. Converted 05_data.transforms.ipynb. Converted 06_data.block.ipynb. Converted 07_vision.core.ipynb. Converted 08_vision.data.ipynb. Converted 09_vision.augment.ipynb. Converted 09b_vision.utils.ipynb. Converted 09c_vision.widgets.ipynb. Converted 10_tutorial.pets.ipynb. Converted 10b_tutorial.albumentations.ipynb. Converted 11_vision.models.xresnet.ipynb. Converted 12_optimizer.ipynb. Converted 13_callback.core.ipynb. Converted 13a_learner.ipynb. Converted 13b_metrics.ipynb. Converted 14_callback.schedule.ipynb. Converted 14a_callback.data.ipynb. Converted 15_callback.hook.ipynb. Converted 15a_vision.models.unet.ipynb. Converted 16_callback.progress.ipynb. Converted 17_callback.tracker.ipynb. Converted 18_callback.fp16.ipynb. Converted 18a_callback.training.ipynb. Converted 18b_callback.preds.ipynb. Converted 19_callback.mixup.ipynb. Converted 20_interpret.ipynb. Converted 20a_distributed.ipynb. Converted 21_vision.learner.ipynb. Converted 22_tutorial.imagenette.ipynb. Converted 23_tutorial.vision.ipynb. Converted 24_tutorial.siamese.ipynb. Converted 24_vision.gan.ipynb. Converted 30_text.core.ipynb. Converted 31_text.data.ipynb. Converted 32_text.models.awdlstm.ipynb. Converted 33_text.models.core.ipynb. Converted 34_callback.rnn.ipynb. Converted 35_tutorial.wikitext.ipynb. Converted 36_text.models.qrnn.ipynb. Converted 37_text.learner.ipynb. Converted 38_tutorial.text.ipynb. Converted 39_tutorial.transformers.ipynb. Converted 40_tabular.core.ipynb. Converted 41_tabular.data.ipynb. Converted 42_tabular.model.ipynb. Converted 43_tabular.learner.ipynb. Converted 44_tutorial.tabular.ipynb. Converted 45_collab.ipynb. Converted 46_tutorial.collab.ipynb. Converted 50_tutorial.datablock.ipynb. Converted 60_medical.imaging.ipynb. Converted 61_tutorial.medical_imaging.ipynb. 
Converted 65_medical.text.ipynb. Converted 70_callback.wandb.ipynb. Converted 71_callback.tensorboard.ipynb. Converted 72_callback.neptune.ipynb. Converted 73_callback.captum.ipynb. Converted 97_test_utils.ipynb. Converted 99_pytorch_doc.ipynb. Converted dev-setup.ipynb. Converted index.ipynb. Converted quick_start.ipynb. Converted tutorial.ipynb.