#default_exp callback.tracker
#export
from local.test import *
from local.basics import *
from local.callback.progress import *
from local.notebook.showdoc import *
from local.test_utils import *
Callbacks that make decisions depending on how a monitored metric/loss behaves
#export
class ShortEpochCallback(Callback):
    "Fit just `pct` of an epoch, then stop"
    def __init__(self,pct=0.01,short_valid=True): self.pct,self.short_valid = pct,short_valid
    def after_batch(self):
        if self.iter/self.n_iter < self.pct: return
        if self.training:    raise CancelTrainException
        if self.short_valid: raise CancelValidException
learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback())
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | | | 00:00 |
learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback(short_valid=False))
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | | 4.103489 | 00:00 |
# export
class TerminateOnNaNCallback(Callback):
    "A `Callback` that terminates training if loss is NaN."
    run_before=Recorder
    def after_batch(self):
        "Test if `loss` is NaN or infinite and interrupt training."
        if torch.isinf(self.loss) or torch.isnan(self.loss): raise CancelFitException
learn = synth_learner()
learn.fit(10, lr=100, cbs=TerminateOnNaNCallback())
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 2862784155043519244231962471774027776.000000 | | 00:00 |
assert len(learn.recorder.losses) < 10 * len(learn.dbunch.train_dl)
for l in learn.recorder.losses:
    assert not torch.isinf(l) and not torch.isnan(l)
# export
class TrackerCallback(Callback):
    "A `Callback` that keeps track of the best value in `monitor`."
    run_after=Recorder

    def __init__(self, monitor='valid_loss', comp=None, min_delta=0.):
        if comp is None: comp = np.less if 'loss' in monitor else np.greater
        if comp == np.less: min_delta *= -1
        self.monitor,self.comp,self.min_delta = monitor,comp,min_delta

    def begin_fit(self):
        "Prepare the monitored value"
        self.run = not hasattr(self, "lr_finder") and not hasattr(self, "gather_preds")
        self.best = float('inf') if self.comp == np.less else -float('inf')
        assert self.monitor in self.recorder.metric_names[1:]
        self.idx = list(self.recorder.metric_names[1:]).index(self.monitor)

    def after_epoch(self):
        "Compare the last value to the best up to now"
        val = self.recorder.values[-1][self.idx]
        if self.comp(val - self.min_delta, self.best): self.best,self.new_best = val,True
        else: self.new_best = False

    def after_fit(self): self.run=True
When implementing a `Callback` that has behavior depending on the best value of a metric or loss, subclass this `Callback` and use its `best` (for the best value so far) and `new_best` (there was a new best value this epoch) attributes. `comp` is the comparison operator used to determine if a value is better than another (defaults to `np.less` if 'loss' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to beat the current best (depending on `comp`) by at least that amount.
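For instance, here is a minimal sketch of such a subclass that just reports each new best (the `PrintBestCallback` name is illustrative, not part of the library):
class PrintBestCallback(TrackerCallback):
    "A toy `TrackerCallback` subclass that prints whenever a new best value is seen"
    def after_epoch(self):
        super().after_epoch()  #updates `self.best` and `self.new_best`
        if self.new_best: print(f'New best {self.monitor}: {self.best}')

learn = synth_learner()
learn.fit(2, cbs=PrintBestCallback(monitor='valid_loss'))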
#hide
class FakeRecords(Callback):
    run_after=Recorder
    run_before=TrackerCallback
    def __init__(self, monitor, values): self.monitor,self.values = monitor,values
    def begin_fit(self):   self.idx = list(self.recorder.metric_names[1:]).index(self.monitor)
    def after_epoch(self): self.recorder.values[-1][self.idx] = self.values[self.epoch]

class TestTracker(Callback):
    run_after=TrackerCallback
    def begin_fit(self): self.bests,self.news = [],[]
    def after_epoch(self):
        self.bests.append(self.tracker.best)
        self.news.append(self.tracker.new_best)
#hide
learn = synth_learner(n_trn=2, cbs=TestTracker())
cbs=[TrackerCallback(monitor='valid_loss'), FakeRecords('valid_loss', [0.2,0.1])]
with learn.no_logging(): learn.fit(2, cbs=cbs)
test_eq(learn.test_tracker.bests, [0.2, 0.1])
test_eq(learn.test_tracker.news, [True,True])
#With a min_delta
cbs=[TrackerCallback(monitor='valid_loss', min_delta=0.15), FakeRecords('valid_loss', [0.2,0.1])]
with learn.no_logging(): learn.fit(2, cbs=cbs)
test_eq(learn.test_tracker.bests, [0.2, 0.2])
test_eq(learn.test_tracker.news, [True,False])
#hide
#By default metrics have to be bigger at each epoch.
def tst_metric(out,targ): return F.mse_loss(out,targ)
learn = synth_learner(n_trn=2, cbs=TestTracker(), metrics=tst_metric)
cbs=[TrackerCallback(monitor='tst_metric'), FakeRecords('tst_metric', [0.2,0.1])]
with learn.no_logging(): learn.fit(2, cbs=cbs)
test_eq(learn.test_tracker.bests, [0.2, 0.2])
test_eq(learn.test_tracker.news, [True,False])
#This can be overwritten by passing `comp=np.less`.
learn = synth_learner(n_trn=2, cbs=TestTracker(), metrics=tst_metric)
cbs=[TrackerCallback(monitor='tst_metric', comp=np.less), FakeRecords('tst_metric', [0.2,0.1])]
with learn.no_logging(): learn.fit(2, cbs=cbs)
test_eq(learn.test_tracker.bests, [0.2, 0.1])
test_eq(learn.test_tracker.news, [True,True])
#hide
#A tracker callback is not run during an lr_find
from local.callback.schedule import *
learn = synth_learner(n_trn=2, cbs=TrackerCallback(monitor='tst_metric'), metrics=tst_metric)
learn.lr_find(num_it=5, show_plot=False)
assert not hasattr(learn, 'new_best')
# export
class EarlyStoppingCallback(TrackerCallback):
    "A `TrackerCallback` that terminates training when monitored quantity stops improving."
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., patience=1):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta)
        self.patience = patience

    def begin_fit(self): self.wait = 0; super().begin_fit()
    def after_epoch(self):
        "Compare the value monitored to its best score and maybe stop training."
        super().after_epoch()
        if self.new_best: self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                print(f'No improvement since epoch {self.epoch-self.wait}: early stopping')
                raise CancelFitException()
`comp` is the comparison operator used to determine if a value is better than another (defaults to `np.less` if 'loss' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to beat the current best (depending on `comp`) by at least that amount. `patience` is the number of epochs you're willing to wait without improvement.
learn = synth_learner(n_trn=2)
learn.fit(n_epoch=200, lr=1e-7, cbs=EarlyStoppingCallback(monitor='valid_loss', min_delta=0.1, patience=2))
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 8.969569 | 7.550271 | 00:00 |
1 | 8.964047 | 7.550250 | 00:00 |
2 | 8.963997 | 7.550220 | 00:00 |
No improvement since epoch 0: early stopping
#hide
test_eq(len(learn.recorder.values), 3)
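Since 'tst_metric' doesn't contain 'loss', `comp` would default to `np.greater`; for an error-like metric such as the MSE helper defined above you have to pass `comp=np.less` explicitly. A minimal sketch reusing `tst_metric`:
learn = synth_learner(n_trn=2, metrics=tst_metric)
learn.fit(n_epoch=200, lr=1e-7, cbs=EarlyStoppingCallback(monitor='tst_metric', comp=np.less, min_delta=0.1, patience=2))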
# export
class SaveModelCallback(TrackerCallback):
    "A `TrackerCallback` that saves the model's best during training and loads it at the end."
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., fname='model', every_epoch=False, add_save=None, with_opt=False):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta)
        store_attr(self, 'fname,every_epoch,add_save,with_opt')

    def _save(self, name):
        self.learn.save(name, with_opt=self.with_opt)
        if self.add_save is not None:
            with self.add_save.open('wb') as f: self.learn.save(f, with_opt=self.with_opt)

    def after_epoch(self):
        "Compare the value monitored to its best score and save if best."
        if self.every_epoch: self._save(f'{self.fname}_{self.epoch}')
        else: #every improvement
            super().after_epoch()
            if self.new_best: self._save(f'{self.fname}')

    def after_fit(self):
        "Load the best model."
        if not self.every_epoch: self.learn.load(f'{self.fname}')
`comp` is the comparison operator used to determine if a value is better than another (defaults to `np.less` if 'loss' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to beat the current best (depending on `comp`) by at least that amount. The model will be saved in `learn.path/learn.model_dir/name.pth`, either at the end of every epoch if `every_epoch=True`, or at each improvement of the monitored quantity otherwise.
learn = synth_learner(n_trn=2, path=Path.cwd()/'tmp')
learn.fit(n_epoch=2, cbs=SaveModelCallback())
assert (Path.cwd()/'tmp/models/model.pth').exists()
learn.fit(n_epoch=2, cbs=SaveModelCallback(every_epoch=True))
for i in range(2): assert (Path.cwd()/f'tmp/models/model_{i}.pth').exists()
shutil.rmtree(Path.cwd()/'tmp')
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 12.352231 | 9.437100 | 00:00 |
1 | 12.247625 | 9.188762 | 00:00 |

epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 11.802291 | 8.847794 | 00:00 |
1 | 11.562449 | 8.428464 | 00:00 |
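Since `SaveModelCallback` and `EarlyStoppingCallback` are both `TrackerCallback`s, they compose naturally when monitoring the same quantity; a sketch of this common combination (the `tmp` path mirrors the example above):
learn = synth_learner(n_trn=2, path=Path.cwd()/'tmp')
cbs = [SaveModelCallback(monitor='valid_loss'),
       EarlyStoppingCallback(monitor='valid_loss', min_delta=0.1, patience=2)]
learn.fit(n_epoch=10, lr=1e-7, cbs=cbs)  #stops early, then reloads the best checkpoint
shutil.rmtree(Path.cwd()/'tmp')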
# export
class ReduceLROnPlateau(TrackerCallback):
    "A `TrackerCallback` that reduces learning rate when a metric has stopped improving."
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., patience=1, factor=10.):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta)
        self.patience,self.factor = patience,factor

    def begin_fit(self): self.wait = 0; super().begin_fit()
    def after_epoch(self):
        "Compare the value monitored to its best score and reduce LR by `factor` if no improvement."
        super().after_epoch()
        if self.new_best: self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                for h in self.opt.hypers: h['lr'] /= self.factor
                self.wait = 0
                print(f'Epoch {self.epoch}: reducing lr to {self.opt.hypers[-1]["lr"]}')
learn = synth_learner(n_trn=2)
learn.fit(n_epoch=4, lr=1e-7, cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2))
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 13.372304 | 11.866179 | 00:00 |
1 | 13.377043 | 11.866154 | 00:00 |
2 | 13.395395 | 11.866118 | 00:00 |
3 | 13.400988 | 11.866114 | 00:00 |
Epoch 2: reducing lr to 1e-08
#hide
test_eq(learn.opt.hypers[-1]['lr'], 1e-8)
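`factor` is what the learning rate is divided by, so values other than the default of 10. give gentler decay; a sketch halving the learning rate instead:
learn = synth_learner(n_trn=2)
learn.fit(n_epoch=4, lr=1e-7, cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2, factor=2.))
#after two epochs without a 0.1 improvement, lr drops from 1e-7 to 5e-8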
#hide
from local.notebook.export import notebook2script
notebook2script(all_fs=True)