#hide
#skip
! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab
#export
from fastai.basics import *
#hide
from nbdev.showdoc import *
#default_exp callback.progress
Callbacks and helper functions to track the progress of training or to log results
from fastai.test_utils import *
# export
@docs
class ProgressCallback(Callback):
    "A `Callback` to handle the display of progress bars"
    # order=60 runs this callback after the `Recorder` (asserted in `before_fit`);
    # `_stateattrs` names the bar attributes handled specially in callback state.
    order,_stateattrs = 60,('mbar','pbar')

    def before_fit(self):
        # Create the epoch-level master bar and hijack the learner's logger so
        # metric rows are written into the bar's table; keep the old logger to
        # restore it in `after_fit`.
        assert hasattr(self.learn, 'recorder')
        # NOTE(review): `create_mbar` is not defined in this view — presumably
        # set on the callback/learner elsewhere; confirm.
        if self.create_mbar: self.mbar = master_bar(list(range(self.n_epoch)))
        if self.learn.logger != noop:
            self.old_logger,self.learn.logger = self.logger,self._write_stats
            self._write_stats(self.recorder.metric_names)  # header row of the table
        else: self.old_logger = noop

    def before_epoch(self):
        # Advance the master bar to the current epoch (only if one was created).
        if getattr(self, 'mbar', False): self.mbar.update(self.epoch)

    def before_train(self): self._launch_pbar()
    def before_validate(self): self._launch_pbar()
    def after_train(self): self.pbar.on_iter_end()
    def after_validate(self): self.pbar.on_iter_end()

    def after_batch(self):
        # Advance the batch-level bar; show the smoothed loss as the bar comment.
        self.pbar.update(self.iter+1)
        if hasattr(self, 'smooth_loss'): self.pbar.comment = f'{self.smooth_loss:.4f}'

    def _launch_pbar(self):
        # Start a batch-level progress bar over the current dataloader, nested
        # under the master bar when one exists.
        self.pbar = progress_bar(self.dl, parent=getattr(self, 'mbar', None), leave=False)
        self.pbar.update(0)

    def after_fit(self):
        # Close the master bar and restore the learner's original logger.
        if getattr(self, 'mbar', False):
            self.mbar.on_iter_end()
            delattr(self, 'mbar')
        if hasattr(self, 'old_logger'): self.learn.logger = self.old_logger

    def _write_stats(self, log):
        # Render one row of stats in the master bar's table (floats to 6 d.p.).
        if getattr(self, 'mbar', False): self.mbar.write([f'{l:.6f}' if isinstance(l, float) else str(l) for l in log], table=True)

    _docs = dict(before_fit="Setup the master bar over the epochs",
                 before_epoch="Update the master bar",
                 before_train="Launch a progress bar over the training dataloader",
                 before_validate="Launch a progress bar over the validation dataloader",
                 after_train="Close the progress bar over the training dataloader",
                 after_validate="Close the progress bar over the validation dataloader",
                 after_batch="Update the current progress bar",
                 after_fit="Close the master bar")
# Make sure `ProgressCallback` is registered among the default callbacks exactly once.
if hasattr(defaults, 'callbacks'):
    if ProgressCallback not in defaults.callbacks: defaults.callbacks.append(ProgressCallback)
else:
    defaults.callbacks = [TrainEvalCallback, Recorder, ProgressCallback]
# Train a tiny synthetic learner to demonstrate the progress-bar output below.
learn = synth_learner()
learn.fit(5)
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 8.960221 | 8.501486 | 00:00 |
1 | 7.650368 | 5.475908 | 00:00 |
2 | 6.193127 | 3.202425 | 00:00 |
3 | 4.902714 | 1.781969 | 00:00 |
4 | 3.847687 | 0.968699 | 00:00 |
#export
@patch
@contextmanager
def no_bar(self:Learner):
    "Context manager that deactivates the use of progress bars"
    had_progress = hasattr(self, 'progress')
    if had_progress: self.remove_cb(self.progress)
    try:
        yield self
    finally:
        # Re-attach a fresh ProgressCallback once the block exits, even on error.
        if had_progress: self.add_cb(ProgressCallback())
# With `no_bar`, stats are printed as plain lists instead of rendered bars.
learn = synth_learner()
with learn.no_bar(): learn.fit(5)
[0, 16.774219512939453, 16.614517211914062, '00:00'] [1, 14.62364387512207, 11.538640975952148, '00:00'] [2, 12.198295593261719, 7.462512016296387, '00:00'] [3, 9.962362289428711, 4.619643688201904, '00:00'] [4, 8.045241355895996, 2.791717052459717, '00:00']
#hide
#Check validate works without any training
def tst_metric(out, targ):
    "Dummy test metric: mean squared error between `out` and `targ`."
    loss = F.mse_loss(out, targ)
    return loss
# `validate` must work on a freshly-created learner, before any training.
learn = synth_learner(n_trn=5, metrics=tst_metric)
preds,targs = learn.validate()
#hide
#Check get_preds works without any training
# This cell is meant to check `get_preds` (see the comment above), but it
# duplicated the `validate()` call from the previous cell — call `get_preds`.
learn = synth_learner(n_trn=5, metrics=tst_metric)
preds,targs = learn.get_preds()
# Render the auto-generated documentation for `ProgressCallback.before_fit`.
show_doc(ProgressCallback.before_fit)
ProgressCallback.before_fit
[source]
ProgressCallback.before_fit
()
Setup the master bar over the epochs
# Render the docs for the epoch- and train-phase hooks.
show_doc(ProgressCallback.before_epoch)
show_doc(ProgressCallback.before_train)
ProgressCallback.before_train
[source]
ProgressCallback.before_train
()
Launch a progress bar over the training dataloader
# Render the docs for the validation-phase hook.
show_doc(ProgressCallback.before_validate)
ProgressCallback.before_validate
[source]
ProgressCallback.before_validate
()
Launch a progress bar over the validation dataloader
# Render the docs for the batch- and train-end hooks.
show_doc(ProgressCallback.after_batch)
show_doc(ProgressCallback.after_train)
ProgressCallback.after_train
[source]
ProgressCallback.after_train
()
Close the progress bar over the training dataloader
# Render the docs for the validation-end hook.
show_doc(ProgressCallback.after_validate)
ProgressCallback.after_validate
[source]
ProgressCallback.after_validate
()
Close the progress bar over the validation dataloader
# Render the docs for the fit-end hook.
show_doc(ProgressCallback.after_fit)
# export
class ShowGraphCallback(Callback):
    "Update a graph of training and validation loss"
    # Run after ProgressCallback (order 65); skip the validation-only passes.
    order,run_valid = 65,False

    def before_fit(self):
        "Decide whether to run and reset the per-epoch batch counts"
        self.run = not (hasattr(self.learn, 'lr_finder') or hasattr(self, "gather_preds"))
        if self.run:
            self.nb_batches = []
            assert hasattr(self.learn, 'progress')

    def after_train(self):
        "Record the cumulative batch count at the end of each training phase"
        self.nb_batches.append(self.train_iter)

    def after_epoch(self):
        "Plot validation loss in the pbar graph"
        if not self.nb_batches: return
        recorder = self.learn.recorder
        train_losses = recorder.losses
        valid_losses = [v[1] for v in recorder.values]
        iters = range_of(train_losses)
        # Extend the x axis to the projected total batch count of the full run.
        epochs_left = self.n_epoch - len(self.nb_batches)
        x_bounds = (0, epochs_left * self.nb_batches[0] + len(train_losses))
        y_bounds = (0, max((max(Tensor(train_losses)), max(Tensor(valid_losses)))))
        self.progress.mbar.update_graph([(iters, train_losses), (self.nb_batches, valid_losses)],
                                        x_bounds, y_bounds)
#slow
# Demo: train with the live loss graph attached (cell is marked #slow).
learn = synth_learner(cbs=ShowGraphCallback())
learn.fit(5)
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 11.676161 | 8.622957 | 00:00 |
1 | 10.183996 | 6.069672 | 00:00 |
2 | 8.519948 | 3.890609 | 00:00 |
3 | 6.959560 | 2.382288 | 00:00 |
4 | 5.621220 | 1.414858 | 00:00 |
# `predict` still works after training with the graph callback attached.
learn.predict(torch.tensor([[0.1]]))
(tensor([1.9139]), tensor([1.9139]), tensor([1.9139]))
# export
class CSVLogger(Callback):
    "Log the results displayed in `learn.path/fname`"
    # Run at the same order as ProgressCallback so both see the final metrics.
    order=60
    def __init__(self, fname='history.csv', append=False):
        # fname: csv file path relative to `learn.path`; append: keep existing rows.
        self.fname,self.append = Path(fname),append
    def read_log(self):
        "Convenience method to quickly access the log."
        return pd.read_csv(self.path/self.fname)
    def before_fit(self):
        "Prepare file with metric names."
        if hasattr(self, "gather_preds"): return
        # BUG FIX: create the parent of the full log path, not of `self.path`.
        # The old `self.path.parent.mkdir(...)` never created `self.path` itself,
        # so opening `self.path/self.fname` failed when `self.path` (or a subdir
        # inside `fname`) did not exist yet.
        (self.path/self.fname).parent.mkdir(parents=True, exist_ok=True)
        self.file = (self.path/self.fname).open('a' if self.append else 'w')
        self.file.write(','.join(self.recorder.metric_names) + '\n')
        # Hijack the learner's logger; every stats row also goes to the csv.
        self.old_logger,self.learn.logger = self.logger,self._write_line
    def _write_line(self, log):
        "Write a line with `log` and call the old logger."
        self.file.write(','.join([str(t) for t in log]) + '\n')
        # Flush + fsync so the log survives a crash mid-training.
        self.file.flush()
        os.fsync(self.file.fileno())
        self.old_logger(log)
    def after_fit(self):
        "Close the file and clean up."
        if hasattr(self, "gather_preds"): return
        self.file.close()
        self.learn.logger = self.old_logger
The results are appended to an existing file if `append=True`; otherwise they overwrite it.
# Demo: train with `CSVLogger` attached; stats are also written to history.csv.
learn = synth_learner(cbs=CSVLogger())
learn.fit(5)
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 10.500990 | 8.331024 | 00:00 |
1 | 9.115092 | 5.783391 | 00:00 |
2 | 7.573916 | 3.695323 | 00:00 |
3 | 6.161108 | 2.222861 | 00:00 |
4 | 4.948495 | 1.308835 | 00:00 |
show_doc(CSVLogger.read_log)
# The logged values must round-trip: the csv contents match the recorder's
# metric names and values row for row.
df = learn.csv_logger.read_log()
test_eq(df.columns.values, learn.recorder.metric_names)
for i,v in enumerate(learn.recorder.values):
    test_close(df.iloc[i][:3], [i] + v)
os.remove(learn.path/learn.csv_logger.fname)  # clean up the demo file
# Render the docs for the CSVLogger fit hooks.
show_doc(CSVLogger.before_fit)
show_doc(CSVLogger.after_fit)
#hide
# Export every cell marked #export into the fastai library modules.
from nbdev.export import notebook2script
notebook2script()
Converted 00_torch_core.ipynb. Converted 01_layers.ipynb. Converted 01a_losses.ipynb. Converted 02_data.load.ipynb. Converted 03_data.core.ipynb. Converted 04_data.external.ipynb. Converted 05_data.transforms.ipynb. Converted 06_data.block.ipynb. Converted 07_vision.core.ipynb. Converted 08_vision.data.ipynb. Converted 09_vision.augment.ipynb. Converted 09b_vision.utils.ipynb. Converted 09c_vision.widgets.ipynb. Converted 10_tutorial.pets.ipynb. Converted 10b_tutorial.albumentations.ipynb. Converted 11_vision.models.xresnet.ipynb. Converted 12_optimizer.ipynb. Converted 13_callback.core.ipynb. Converted 13a_learner.ipynb. Converted 13b_metrics.ipynb. Converted 14_callback.schedule.ipynb. Converted 14a_callback.data.ipynb. Converted 15_callback.hook.ipynb. Converted 15a_vision.models.unet.ipynb. Converted 16_callback.progress.ipynb. Converted 17_callback.tracker.ipynb. Converted 18_callback.fp16.ipynb. Converted 18a_callback.training.ipynb. Converted 18b_callback.preds.ipynb. Converted 19_callback.mixup.ipynb. Converted 20_interpret.ipynb. Converted 20a_distributed.ipynb. Converted 21_vision.learner.ipynb. Converted 22_tutorial.imagenette.ipynb. Converted 23_tutorial.vision.ipynb. Converted 24_tutorial.siamese.ipynb. Converted 24_vision.gan.ipynb. Converted 30_text.core.ipynb. Converted 31_text.data.ipynb. Converted 32_text.models.awdlstm.ipynb. Converted 33_text.models.core.ipynb. Converted 34_callback.rnn.ipynb. Converted 35_tutorial.wikitext.ipynb. Converted 36_text.models.qrnn.ipynb. Converted 37_text.learner.ipynb. Converted 38_tutorial.text.ipynb. Converted 39_tutorial.transformers.ipynb. Converted 40_tabular.core.ipynb. Converted 41_tabular.data.ipynb. Converted 42_tabular.model.ipynb. Converted 43_tabular.learner.ipynb. Converted 44_tutorial.tabular.ipynb. Converted 45_collab.ipynb. Converted 46_tutorial.collab.ipynb. Converted 50_tutorial.datablock.ipynb. Converted 60_medical.imaging.ipynb. Converted 61_tutorial.medical_imaging.ipynb. 
Converted 65_medical.text.ipynb. Converted 70_callback.wandb.ipynb. Converted 71_callback.tensorboard.ipynb. Converted 72_callback.neptune.ipynb. Converted 73_callback.captum.ipynb. Converted 97_test_utils.ipynb. Converted 99_pytorch_doc.ipynb. Converted dev-setup.ipynb. Converted index.ipynb. Converted quick_start.ipynb. Converted tutorial.ipynb.