The fast.ai
library has a callback to track training metrics history. However, the history is reported via console or Jupyter widget, and there is no callback to store these results in CSV format. This notebook proposes an approach to implementing a callback, similar to CSVLogger from the Keras library, which saves the tracked metrics into a persistent file.
%reload_ext autoreload
%autoreload 2
from fastai import *
from fastai.torch_core import *
from fastai.vision import *
from fastai.metrics import *
from torchvision.models import resnet18
@dataclass
class CSVLogger(LearnerCallback):
    "A `LearnerCallback` that saves history of training metrics into CSV file."
    filename: str = 'history'

    def __post_init__(self):
        # Place the CSV next to the learner's own working directory.
        self.path = self.learn.path/f'{self.filename}.csv'
        self.file = None  # opened lazily in on_train_begin

    @property
    def header(self):
        # Column names come from the recorder, so the CSV matches console output.
        return self.learn.recorder.names

    def read_logged_file(self):
        "Read the saved file back as a `pd.DataFrame`."
        return pd.read_csv(self.path)

    def on_train_begin(self, metrics_names: StrList, **kwargs: Any) -> None:
        "Create the output file (and any missing parent folders) and write the CSV header."
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self.file = self.path.open('w')
        self.file.write(','.join(self.header) + '\n')

    def on_epoch_end(self, epoch: int, smooth_loss: Tensor, last_metrics: MetricsList, **kwargs: Any) -> None:
        "Append one CSV row: epoch number, smoothed train loss, then the epoch metrics."
        self.write_stats([epoch, smooth_loss] + last_metrics)

    def on_train_end(self, **kwargs: Any) -> None:
        "Close the file handle (close() flushes); safe to call even if training never began."
        # Guard: self.file is None when fit() was never invoked, and close() is not idempotent-safe to skip.
        if self.file is not None and not self.file.closed:
            self.file.close()

    def write_stats(self, stats: TensorOrNumList) -> None:
        "Format each stat (ints verbatim, missing metrics as '#na#', others to 6 decimals) and write a row."
        stats = [str(stat) if isinstance(stat, int)
                 else '#na#' if stat is None  # a metric can be absent for an epoch
                 else f'{stat:.6f}'
                 for name, stat in zip(self.header, stats)]
        self.file.write(','.join(stats) + '\n')
Let's train an MNIST classifier and track its metrics. All the metrics listed in the metrics array — along with the epoch number, train loss, and valid loss — should be saved into the file. Then we can read this file back and process the results as needed.
# Download the tiny MNIST sample and build a data bunch from its folder layout.
path = untar_data(URLs.MNIST_TINY)
data = ImageDataBunch.from_folder(path)
# Build a resnet18-based classifier tracking accuracy and error rate.
learn = ConvLearner(data, resnet18, metrics=[accuracy, error_rate])
# Attach the CSV logger and train for 3 epochs; rows are written to <learn.path>/history.csv.
cb = CSVLogger(learn)
learn.fit(3, callbacks=[cb])
VBox(children=(HBox(children=(IntProgress(value=0, max=3), HTML(value='0.00% [0/3 00:00<00:00]'))), HTML(value…
Total time: 00:03 epoch train loss valid loss accuracy error_rate 1 0.410295 0.839821 0.608011 0.391989 (00:01) 2 0.333742 0.425074 0.844063 0.155937 (00:00) 3 0.237944 0.391234 0.845494 0.154506 (00:00)
# Load the persisted metrics back into a DataFrame for further processing.
log_df = cb.read_logged_file()
log_df
epoch | train loss | valid loss | accuracy | error_rate | |
---|---|---|---|---|---|
0 | 1 | 0.410295 | 0.839821 | 0.608011 | 0.391989 |
1 | 2 | 0.333742 | 0.425074 | 0.844063 | 0.155937 |
2 | 3 | 0.237944 | 0.391234 | 0.845494 | 0.154506 |
The tests are present in the test_logger.py file and can be invoked with the command:
$ python -m pytest test_logger.py
To keep all PRs code in a single place, here is the content of aforementioned file:
from io import StringIO
from contextlib import redirect_stdout
import pytest
from fastai import *
from fastai.vision import *
from fastai.metrics import *
from fastprogress import fastprogress
from logger import CSVLogger
def test_callback_has_required_properties_after_init(classifier):
    """A freshly-created logger is wired to the learner but has not touched disk yet."""
    logger = CSVLogger(classifier)

    assert logger.filename
    assert not logger.path.exists()
    assert logger.learn is classifier
    assert logger.file is None
def test_callback_writes_learn_metrics_during_training(classifier_and_logger):
    """After a fit, the CSV file exists, is closed, and mirrors the recorder's columns."""
    learner, logger = classifier_and_logger
    epochs = 3

    learner.fit(epochs, callbacks=[logger])

    history = logger.read_logged_file()
    assert logger.path.exists()
    assert logger.file.closed
    assert not history.empty
    assert len(history) == epochs
    assert learner.recorder.names == history.columns.tolist()
def test_callback_written_metrics_are_equal_to_reported_via_stdout(classifier_and_logger, no_bar):
    """The CSV content matches what the default console reporter printed during training."""
    learner, logger = classifier_and_logger
    captured = StringIO()

    with redirect_stdout(captured):
        learner.fit(3, callbacks=[logger])

    logged = logger.read_logged_file()
    printed = convert_into_dataframe(captured)
    pd.testing.assert_frame_equal(logged, printed)
@pytest.fixture
def classifier(tmpdir):
    """A small CNN learner on MNIST_TINY, writing its artifacts into a temp dir."""
    data_path = untar_data(URLs.MNIST_TINY)
    bunch = ImageDataBunch.from_folder(data_path)
    save_dir = str(tmpdir.join('classifier'))
    return Learner(bunch, simple_cnn((3, 10, 10)), path=save_dir)
@pytest.fixture
def classifier_and_logger(classifier):
    """Pair a classifier (tracking accuracy and error rate) with a fresh CSVLogger."""
    classifier.metrics = [accuracy, error_rate]
    return classifier, CSVLogger(classifier)
@pytest.fixture
def no_bar():
    """Temporarily disable fastprogress bars so captured stdout carries only metric lines.

    The flag is restored in a ``finally`` so the global state is never left
    flipped, even if an exception is thrown into the generator during teardown.
    """
    fastprogress.NO_BAR = True
    try:
        yield
    finally:
        fastprogress.NO_BAR = False
def convert_into_dataframe(buffer):
    """Parse whitespace-separated metric lines captured from stdout into a DataFrame.

    The first non-empty line supplies the column names; every following line is
    converted to floats, except the ``epoch`` column which is cast back to int.
    """
    raw_lines = [line.strip() for line in buffer.getvalue().split('\n') if line]
    columns = raw_lines[0].split()
    rows = []
    for line in raw_lines[1:]:
        values = [float(token) for token in line.split()]
        rows.append(dict(zip(columns, values)))
    frame = pd.DataFrame(rows, columns=columns)
    frame['epoch'] = frame['epoch'].astype(int)
    return frame