The fast.ai
library has a callback to track training metrics history. However, the history is reported via the console or a Jupyter widget, and there is no callback to store these results in CSV format. In this notebook, the author proposes an approach to implement a callback, similar to `CSVLogger` from the Keras library, which saves the tracked metrics into a persistent file.
%reload_ext autoreload
%autoreload 1
from fastai import *
from fastai.torch_core import *
from fastai.vision import *
from fastai.metrics import *
from torchvision.models import resnet18
@dataclass
class CSVLogger(LearnerCallback):
    "A `LearnerCallback` that saves the history of tracked metrics into a CSV `filename` during training."
    filename:str='history.csv'

    def __post_init__(self):
        # Resolve the target path up front; the file handle is opened lazily in
        # `on_train_begin` so the callback can be constructed before training starts.
        self.path = Path(self.filename)
        self.file = None

    @property
    def header(self):
        # Column names tracked by the recorder (epoch, losses, metric names).
        return self.learn.recorder.names

    def read_logged_file(self):
        "Read the logged CSV back into a `DataFrame`."
        return pd.read_csv(self.path)

    def on_train_begin(self, metrics_names:StrList, **kwargs:Any)->None:
        "Create the output folder if needed, open the file and write the CSV header row."
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self.file = self.path.open('w')
        self.file.write(','.join(self.header) + '\n')

    def on_epoch_end(self, epoch:int, smooth_loss:Tensor, last_metrics:MetricsList, **kwargs:Any)->None:
        # NOTE: was annotated `->bool` but never returned a value; `None` reflects
        # the actual behavior (fast.ai treats a `None` return as "don't stop").
        "Append one row with the epoch number, smoothed train loss and epoch metrics."
        self.write_stats([epoch, smooth_loss] + last_metrics)

    def on_train_end(self, **kwargs:Any)->None:
        "Flush and close the underlying file when training finishes."
        self.file.flush()
        self.file.close()

    def write_stats(self, stats:TensorOrNumList)->None:
        "Format `stats` (ints verbatim, floats/tensors to 6 decimal places) and write one CSV row."
        # Zipping with `header` also truncates `stats` to the tracked columns.
        stats = [str(stat) if isinstance(stat, int) else f'{stat:.6f}'
                 for name,stat in zip(self.header,stats)]
        str_stats = ','.join(stats)
        self.file.write(str_stats + '\n')
Let's train an MNIST classifier and track its metrics. All the metrics listed in the `metrics`
array, as well as the epoch number and the train and validation losses, should be saved into the file. Then we can read this file back and process it further.
# Grab the tiny MNIST sample and build an image data bunch from its folder layout.
mnist_path = untar_data(URLs.MNIST_TINY)
bunch = ImageDataBunch.from_folder(mnist_path)
# ResNet-18 backbone; track accuracy and error rate alongside the losses.
learn = ConvLearner(bunch, resnet18, metrics=[accuracy, error_rate])
cb = CSVLogger(learn)
learn.fit(3, callbacks=[cb])
VBox(children=(HBox(children=(IntProgress(value=0, max=3), HTML(value='0.00% [0/3 00:00<00:00]'))), HTML(value…
Total time: 00:03 epoch train loss valid loss accuracy error_rate 1 0.410295 0.839821 0.608011 0.391989 (00:01) 2 0.333742 0.425074 0.844063 0.155937 (00:00) 3 0.237944 0.391234 0.845494 0.154506 (00:00)
# Load the CSV written by the callback back into a DataFrame for inspection.
log_df = cb.read_logged_file()
log_df
epoch | train loss | valid loss | accuracy | error_rate | |
---|---|---|---|---|---|
0 | 1 | 0.410295 | 0.839821 | 0.608011 | 0.391989 |
1 | 2 | 0.333742 | 0.425074 | 0.844063 | 0.155937 |
2 | 3 | 0.237944 | 0.391234 | 0.845494 | 0.154506 |
The tests are present in the test_logger.py file and can be invoked with the command:
$ python -m pytest test_logger.py
To keep all the PR's code in a single place, here is the content of the aforementioned file:
from unittest import mock
import pytest
from fastai import *
from fastai.vision import *
from fastai.metrics import *
from torchvision.models import resnet18
from logger import CSVLogger
def test_callback_has_required_properties_after_init(history):
    """A freshly constructed logger knows its filename but has not opened any file yet."""
    logger = CSVLogger(mock.Mock(), filename=history)

    assert logger.filename
    assert not logger.path.exists()
    assert logger.file is None
def test_callback_writes_learn_metrics_during_training(classifier_and_logger):
    """Training for N epochs yields a closed CSV with one row per epoch and matching columns."""
    n_epochs = 3
    model, logger = classifier_and_logger

    model.fit(n_epochs, callbacks=[logger])
    logged = logger.read_logged_file()

    assert logger.path.exists()
    assert logger.file.closed
    assert not logged.empty
    assert len(logged) == n_epochs
    assert model.recorder.names == logged.columns.tolist()
def test_callback_works_with_fit_one_cycle_method(classifier_and_logger, monkeypatch):
    """`fit_one_cycle` should forward the user callback plus its own scheduler to `fit`."""
    class MockFit:
        # Stands in for `Learner.fit` and records how many callbacks it received.
        def __init__(self):
            self.n_callbacks = 0
        def __call__(self, *args, **kwargs):
            self.n_callbacks = len(kwargs['callbacks'])

    classifier, cb = classifier_and_logger
    mock_fit = MockFit()
    # Use the `monkeypatch` fixture (previously accepted but unused) instead of a raw
    # `setattr`, so the patched `fit` is restored automatically after the test and
    # does not leak into other tests sharing the classifier.
    monkeypatch.setattr(classifier, 'fit', mock_fit, raising=False)
    classifier.fit_one_cycle(1, callbacks=[cb])
    # One user callback (the CSV logger) + the OneCycleScheduler added by fit_one_cycle.
    assert mock_fit.n_callbacks == 2
@pytest.fixture
def history(tmpdir):
    """Path to a temporary CSV file for the logger to write into."""
    return tmpdir.join('history.csv')
@pytest.fixture
def classifier():
    """Small ConvLearner trained on MNIST_TINY; metrics are attached by dependent fixtures.

    The original fixture declared an unused `history` dependency, which forced
    needless tmpdir setup and misleadingly coupled the learner to the log path.
    """
    path = untar_data(URLs.MNIST_TINY)
    bunch = ImageDataBunch.from_folder(path)
    return ConvLearner(bunch, resnet18)
@pytest.fixture
def classifier_and_logger(classifier, history):
    """Pair the classifier (with accuracy/error-rate metrics attached) with a CSV logger."""
    classifier.metrics = [accuracy, error_rate]
    return classifier, CSVLogger(classifier, filename=history)
@pytest.fixture
def patched_fit(monkeypatch):
    """Expose `monkeypatch` under a descriptive name for tests that patch `fit`.

    NOTE(review): the original body called `monkeypatch.setattr()` with no
    arguments, which raises TypeError the moment the fixture is requested.
    The fixture is currently unused; either wire it into a test or remove it.
    """
    return monkeypatch