#hide
#skip
! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab
#default_exp callback.tracker
#export
from fastai.basics import *
from fastai.callback.progress import *
from fastai.callback.fp16 import MixedPrecision
#hide
from nbdev.showdoc import *
from fastai.test_utils import *
Callbacks that make decisions depending on how a monitored metric/loss behaves.
# export
class TerminateOnNaNCallback(Callback):
    "A `Callback` that terminates training if loss is NaN."
    order=-9
    def after_batch(self):
        "Test if `last_loss` is NaN and interrupts training."
        if torch.isinf(self.loss) or torch.isnan(self.loss): raise CancelFitException
learn = synth_learner()
learn.fit(10, lr=100, cbs=TerminateOnNaNCallback())
epoch | train_loss | valid_loss | time |
---|---|---|---|
assert len(learn.recorder.losses) < 10 * len(learn.dls.train)
for l in learn.recorder.losses:
    assert not torch.isinf(l) and not torch.isnan(l)
# export
class TrackerCallback(Callback):
    "A `Callback` that keeps track of the best value in `monitor`."
    order,remove_on_fetch = 60,True
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., reset_on_fit=True):
        if comp is None: comp = np.less if 'loss' in monitor or 'error' in monitor else np.greater
        if comp == np.less: min_delta *= -1
        self.monitor,self.comp,self.min_delta,self.reset_on_fit,self.best = monitor,comp,min_delta,reset_on_fit,None
    def before_fit(self):
        "Prepare the monitored value"
        self.run = not hasattr(self, "lr_finder") and not hasattr(self, "gather_preds")
        if self.reset_on_fit or self.best is None: self.best = float('inf') if self.comp == np.less else -float('inf')
        assert self.monitor in self.recorder.metric_names[1:]
        self.idx = list(self.recorder.metric_names[1:]).index(self.monitor)
    def after_epoch(self):
        "Compare the last value to the best up to now"
        val = self.recorder.values[-1][self.idx]
        if self.comp(val - self.min_delta, self.best): self.best,self.new_best = val,True
        else: self.new_best = False
    def after_fit(self): self.run=True
When implementing a `Callback` that has behavior depending on the best value of a metric or loss, subclass this `Callback` and use its `best` (for best value so far) and `new_best` (there was a new best value this epoch) attributes. If you want to maintain `best` over subsequent calls to `fit` (e.g., `Learner.fit_one_cycle`), set `reset_on_fit=False`.

`comp` is the comparison operator used to determine if a value is better than another (defaults to `np.less` if 'loss' or 'error' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to beat the current best (in the direction given by `comp`) by at least that amount.
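As a minimal sketch of that subclassing pattern (the `PrintBestCallback` name and its behavior are illustrative, not part of fastai), it suffices to call `super().after_epoch()` and then inspect `best`/`new_best`:

class PrintBestCallback(TrackerCallback):
    "Illustrative subclass: report each time the monitored value improves."
    order = TrackerCallback.order+1
    def after_epoch(self):
        super().after_epoch()  # updates self.best and self.new_best
        if self.new_best: print(f'New best {self.monitor}: {self.best}')

learn = synth_learner(n_trn=2)
learn.fit(2, cbs=PrintBestCallback(monitor='valid_loss'))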
#hide
class FakeRecords(Callback):
    order=51
    def __init__(self, monitor, values): self.monitor,self.values = monitor,values
    def before_fit(self): self.idx = list(self.recorder.metric_names[1:]).index(self.monitor)
    def after_epoch(self): self.recorder.values[-1][self.idx] = self.values[self.epoch]

class TestTracker(Callback):
    order=61
    def before_fit(self): self.bests,self.news = [],[]
    def after_epoch(self):
        self.bests.append(self.tracker.best)
        self.news.append(self.tracker.new_best)
#hide
learn = synth_learner(n_trn=2, cbs=TestTracker())
cbs=[TrackerCallback(monitor='valid_loss'), FakeRecords('valid_loss', [0.2,0.1])]
with learn.no_logging(): learn.fit(2, cbs=cbs)
test_eq(learn.test_tracker.bests, [0.2, 0.1])
test_eq(learn.test_tracker.news, [True,True])
#With a min_delta
cbs=[TrackerCallback(monitor='valid_loss', min_delta=0.15), FakeRecords('valid_loss', [0.2,0.1])]
with learn.no_logging(): learn.fit(2, cbs=cbs)
test_eq(learn.test_tracker.bests, [0.2, 0.2])
test_eq(learn.test_tracker.news, [True,False])
#hide
#By default, metrics have to get bigger at each epoch to count as an improvement.
def tst_metric(out,targ): return F.mse_loss(out,targ)
learn = synth_learner(n_trn=2, cbs=TestTracker(), metrics=tst_metric)
cbs=[TrackerCallback(monitor='tst_metric'), FakeRecords('tst_metric', [0.2,0.1])]
with learn.no_logging(): learn.fit(2, cbs=cbs)
test_eq(learn.test_tracker.bests, [0.2, 0.2])
test_eq(learn.test_tracker.news, [True,False])
#This can be overridden by passing `comp=np.less`.
learn = synth_learner(n_trn=2, cbs=TestTracker(), metrics=tst_metric)
cbs=[TrackerCallback(monitor='tst_metric', comp=np.less), FakeRecords('tst_metric', [0.2,0.1])]
with learn.no_logging(): learn.fit(2, cbs=cbs)
test_eq(learn.test_tracker.bests, [0.2, 0.1])
test_eq(learn.test_tracker.news, [True,True])
#hide
#Setting reset_on_fit=False will maintain the "best" value over subsequent calls to fit
learn = synth_learner(n_val=2, cbs=TrackerCallback(monitor='tst_metric', reset_on_fit=False), metrics=tst_metric)
tracker_cb = learn.cbs.filter(lambda cb: isinstance(cb, TrackerCallback))[0]
with learn.no_logging(): learn.fit(1)
first_best = tracker_cb.best
with learn.no_logging(): learn.fit(1)
test_eq(tracker_cb.best, first_best)
#hide
#A tracker callback is not run during an lr_find
from fastai.callback.schedule import *
learn = synth_learner(n_trn=2, cbs=TrackerCallback(monitor='tst_metric'), metrics=tst_metric)
learn.lr_find(num_it=15, show_plot=False)
assert not hasattr(learn, 'new_best')
# export
class EarlyStoppingCallback(TrackerCallback):
    "A `TrackerCallback` that terminates training when monitored quantity stops improving."
    order=TrackerCallback.order+3
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., patience=1, reset_on_fit=True):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta, reset_on_fit=reset_on_fit)
        self.patience = patience
    def before_fit(self): self.wait = 0; super().before_fit()
    def after_epoch(self):
        "Compare the value monitored to its best score and maybe stop training."
        super().after_epoch()
        if self.new_best: self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                print(f'No improvement since epoch {self.epoch-self.wait}: early stopping')
                raise CancelFitException()
`comp` is the comparison operator used to determine if a value is better than another (defaults to `np.less` if 'loss' or 'error' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to beat the current best (in the direction given by `comp`) by at least that amount. `patience` is the number of epochs you're willing to wait without improvement.
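Note that a monitored name containing neither 'loss' nor 'error' defaults `comp` to `np.greater`, so early-stopping on such a quantity that you want to minimize needs an explicit `comp=np.less`. A minimal sketch (the `tst_metric` helper mirrors the one in the hidden test cells above):

def tst_metric(out, targ): return F.mse_loss(out, targ)
learn = synth_learner(n_trn=2, metrics=tst_metric)
learn.fit(n_epoch=200, lr=1e-7,
          cbs=EarlyStoppingCallback(monitor='tst_metric', comp=np.less, min_delta=0.1, patience=2))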
learn = synth_learner(n_trn=2, metrics=F.mse_loss)
learn.fit(n_epoch=200, lr=1e-7, cbs=EarlyStoppingCallback(monitor='mse_loss', min_delta=0.1, patience=2))
epoch | train_loss | valid_loss | mse_loss | time |
---|---|---|---|---|
0 | 10.651194 | 14.263412 | 14.263412 | 00:00 |
1 | 10.655529 | 14.263385 | 14.263385 | 00:00 |
2 | 10.675529 | 14.263347 | 14.263347 | 00:00 |
No improvement since epoch 0: early stopping
learn.validate()
(#2) [14.263346672058105,14.263346672058105]
learn = synth_learner(n_trn=2)
learn.fit(n_epoch=200, lr=1e-7, cbs=EarlyStoppingCallback(monitor='valid_loss', min_delta=0.1, patience=2))
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 26.303347 | 31.155645 | 00:00 |
1 | 26.319504 | 31.155575 | 00:00 |
2 | 26.335766 | 31.155474 | 00:00 |
No improvement since epoch 0: early stopping
#hide
test_eq(len(learn.recorder.values), 3)
#export
class SaveModelCallback(TrackerCallback):
    "A `TrackerCallback` that saves the model's best during training and loads it at the end."
    _only_train_loop,order = True,TrackerCallback.order+1
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., fname='model', every_epoch=False, at_end=False,
                 with_opt=False, reset_on_fit=True):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta, reset_on_fit=reset_on_fit)
        assert not (every_epoch and at_end), "every_epoch and at_end cannot both be set to True"
        # keep track of file path for loggers
        self.last_saved_path = None
        store_attr('fname,every_epoch,at_end,with_opt')
    def _save(self, name): self.last_saved_path = self.learn.save(name, with_opt=self.with_opt)
    def after_epoch(self):
        "Compare the value monitored to its best score and save if best."
        if self.every_epoch: self._save(f'{self.fname}_{self.epoch}')
        else: # every improvement
            super().after_epoch()
            if self.new_best:
                print(f'Better model found at epoch {self.epoch} with {self.monitor} value: {self.best}.')
                self._save(f'{self.fname}')
    def after_fit(self, **kwargs):
        "Load the best model."
        if self.at_end: self._save(f'{self.fname}')
        elif not self.every_epoch: self.learn.load(f'{self.fname}', with_opt=self.with_opt)
`comp` is the comparison operator used to determine if a value is better than another (defaults to `np.less` if 'loss' or 'error' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to beat the current best (in the direction given by `comp`) by at least that amount. The model will be saved in `learn.path/learn.model_dir/fname.pth`, either at every epoch if `every_epoch=True` (with the epoch number appended to `fname`), or at each improvement of the monitored quantity otherwise.
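The callback also stores the path returned by `learn.save` in its `last_saved_path` attribute (kept for loggers, per the comment in the source), so you can retrieve where the checkpoint went after training; a small sketch:

learn = synth_learner(n_trn=2, path=Path.cwd()/'tmp')
cb = SaveModelCallback(with_opt=True)  # also checkpoint the optimizer state
learn.fit(n_epoch=2, cbs=cb)
print(cb.last_saved_path)  # e.g. .../tmp/models/model.pth
shutil.rmtree(Path.cwd()/'tmp')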
learn = synth_learner(n_trn=2, path=Path.cwd()/'tmp')
learn.fit(n_epoch=2, cbs=SaveModelCallback())
assert (Path.cwd()/'tmp/models/model.pth').exists()
learn = synth_learner(n_trn=2, path=Path.cwd()/'tmp')
learn.fit(n_epoch=2, cbs=SaveModelCallback(fname='end',at_end=True))
assert (Path.cwd()/'tmp/models/end.pth').exists()
learn.fit(n_epoch=2, cbs=SaveModelCallback(every_epoch=True))
for i in range(2): assert (Path.cwd()/f'tmp/models/model_{i}.pth').exists()
shutil.rmtree(Path.cwd()/'tmp')
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 14.472381 | 14.357326 | 00:00 |
1 | 14.362669 | 14.045964 | 00:00 |
Better model found at epoch 0 with valid_loss value: 14.357325553894043.
Better model found at epoch 1 with valid_loss value: 14.045964241027832.
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 10.074560 | 11.895357 | 00:00 |
1 | 9.999896 | 11.651211 | 00:00 |
Better model found at epoch 0 with valid_loss value: 11.895357131958008.
Better model found at epoch 1 with valid_loss value: 11.65121078491211.
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 9.702823 | 11.311175 | 00:00 |
1 | 9.553051 | 10.910662 | 00:00 |
# export
class ReduceLROnPlateau(TrackerCallback):
    "A `TrackerCallback` that reduces learning rate when a metric has stopped improving."
    order=TrackerCallback.order+2
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., patience=1, factor=10., min_lr=0, reset_on_fit=True):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta, reset_on_fit=reset_on_fit)
        self.patience,self.factor,self.min_lr = patience,factor,min_lr
    def before_fit(self): self.wait = 0; super().before_fit()
    def after_epoch(self):
        "Compare the value monitored to its best score and reduce LR by `factor` if no improvement."
        super().after_epoch()
        if self.new_best: self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                old_lr = self.opt.hypers[-1]['lr']
                for h in self.opt.hypers: h['lr'] = max(h['lr'] / self.factor, self.min_lr)
                self.wait = 0
                if self.opt.hypers[-1]["lr"] < old_lr:
                    print(f'Epoch {self.epoch}: reducing lr to {self.opt.hypers[-1]["lr"]}')
learn = synth_learner(n_trn=2)
learn.fit(n_epoch=4, lr=1e-7, cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2))
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 6.122743 | 7.348515 | 00:00 |
1 | 6.119377 | 7.348499 | 00:00 |
2 | 6.125790 | 7.348477 | 00:00 |
3 | 6.131386 | 7.348475 | 00:00 |
Epoch 2: reducing lr to 1e-08
#hide
test_eq(learn.opt.hypers[-1]['lr'], 1e-8)
learn = synth_learner(n_trn=2)
learn.fit(n_epoch=6, lr=5e-8, cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2, min_lr=1e-8))
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 16.747515 | 15.265999 | 00:00 |
1 | 16.725756 | 15.265974 | 00:00 |
2 | 16.735016 | 15.265943 | 00:00 |
3 | 16.733360 | 15.265934 | 00:00 |
4 | 16.733513 | 15.265925 | 00:00 |
5 | 16.730352 | 15.265915 | 00:00 |
Epoch 2: reducing lr to 1e-08
#hide
test_eq(learn.opt.hypers[-1]['lr'], 1e-8)
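`factor` is the amount the learning rate is divided by at each reduction (10. by default) and `min_lr` is a floor it will never go below. A sketch with a gentler `factor=2.` (illustrative values):

learn = synth_learner(n_trn=2)
learn.fit(n_epoch=4, lr=1e-4,
          cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2, factor=2.))
# each reduction divides the lr by 2 instead of the default 10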
Each of the three derived `TrackerCallback`s (`SaveModelCallback`, `ReduceLROnPlateau`, and `EarlyStoppingCallback`) has an adjusted `order` so they can run alongside each other without interference. That order is as follows:

Note: the actual `Callback` order number is in parentheses.

- `TrackerCallback` (60)
- `SaveModelCallback` (61)
- `ReduceLROnPlateau` (62)
- `EarlyStoppingCallback` (63)
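Since these are plain class attributes, the ordering can be checked directly; a quick sketch:

for cb in (TrackerCallback, SaveModelCallback, ReduceLROnPlateau, EarlyStoppingCallback):
    print(f'{cb.__name__}: {cb.order}')  # 60, 61, 62, 63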
#hide
from nbdev.export import notebook2script
notebook2script()