In our previous notebook we implemented data loader classes from scratch so that our training cycle could take advantage of mini-batch training.
We observed that training via mini-batches allows us to leverage the parallel processing capabilities of Nvidia GPUs, which means we can run forward passes on several training inputs (64 in our case) simultaneously. This allows us to train our models much faster than if we were forced to process one input at a time.
In this notebook we demonstrate how to implement a callback system from scratch, and use it to hook into our model at various points during the training cycle.
Fundamentally, callbacks allow us to observe and, if we choose, influence how our model is training, all while the training cycle is still ongoing. Useful things we might use callbacks to accomplish include tracking and printing metrics, switching the model between train and eval mode, stopping training early, and adjusting hyperparameters such as the learning rate mid-training.
Simply put, callbacks are an indispensable tool for the training of deep neural networks.
Virtually all the code that appears in this notebook is the creation of Sylvain Gugger and Jeremy Howard. The original version of this notebook that they made for the course lecture can be found here. I simply re-typed, line-by-line, the pieces of logic necessary to implement the functionality that their notebook demonstrated. In some cases I changed the order of code cells and/or variable names to fit an organization and style that seemed more intuitive to me. Any and all mistakes are my own.
On the other hand, all long-form text explanations in this notebook are solely my own creation. Writing extensive descriptions of the concepts and code in plain and simple English forces me to make sure that I actually understand how they work.
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
#export
from exports.nb_03 import *
Before we implement the classes and methods that will allow us to handle callbacks, let's quickly refactor the code we used to create our data sets and loaders, as well as the code we used to initialize our model.
As with all previous notebooks so far, we're using the MNIST dataset.
x_train,y_train,x_valid,y_valid = get_data()
train_ds,valid_ds = Dataset(x_train,y_train), Dataset(x_valid,y_valid)
nh,bs = 50,64 # hidden layer size, batch size
c = y_train.max().item() + 1 # number of classes
loss_func = F.cross_entropy
Up until now, our approach has been to create a method called `fit()` that defined how our model's training loop would run. We then called it whenever we wanted to initiate a training cycle. It looked like this:
fit(epochs, model, loss_func, opt, train_dl, valid_dl)
What if we stored the `model`, `loss_func`, `opt`, `train_dl`, and `valid_dl` parameters inside another class called `Learner()`? This would not only make our `fit()` call much simpler, but, because `Learner` objects are mutable, it would have the nice side-effect of making any adjustments to the `Learner` immediately visible inside the training loop while the model is training. We could, for example, update the learning rate stored inside the `Learner` object at a particular point during the training cycle, and our model would immediately begin training at the updated learning rate.
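To make that idea concrete, here is a minimal, framework-free sketch. The `ToyOpt` and `ToyLearner` classes below are hypothetical stand-ins, not the notebook's actual classes; the point is only that a loop reading state through a mutable object sees updates immediately:

```python
# Hypothetical stand-ins: a container holding a learning rate, and a
# "learner" that merely holds a reference to it.
class ToyOpt:
    def __init__(self, lr): self.lr = lr

class ToyLearner:
    def __init__(self, opt): self.opt = opt

learn = ToyLearner(ToyOpt(0.5))
seen = []
for step in range(4):
    if step == 2: learn.opt.lr = 0.05  # adjust mid-"training"
    seen.append(learn.opt.lr)          # the loop reads through the learner

print(seen)  # [0.5, 0.5, 0.05, 0.05]
```

Because the loop dereferences `learn.opt.lr` on every step rather than caching it, the change made at step 2 takes effect on that very step.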
#export
class DataBunch():
    def __init__(self, train_dl, valid_dl, c=None):
        self.train_dl,self.valid_dl,self.c = train_dl,valid_dl,c

    @property
    def train_ds(self): return self.train_dl.dataset

    @property
    def valid_ds(self): return self.valid_dl.dataset
The `DataBunch()` class gives us a handy way to manage both the training and validation datasets and loaders, all under one roof.
data = DataBunch(*get_dls(train_ds, valid_ds, bs), c) # get_dls() defined in notebook03
#export
def get_model(data, lr=0.5, nh=50):
    m = data.train_ds.x.shape[1] # size of the inputs
    model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,data.c))
    return model, optim.SGD(model.parameters(), lr=lr)
class Learner():
    def __init__(self, model, opt, loss_func, data):
        self.model,self.opt,self.loss_func,self.data = model,opt,loss_func,data
learn = Learner(*get_model(data), loss_func, data)
def fit(epochs, learn):
    for epoch in range(epochs):
        learn.model.train() # put the model in train mode
        for xb,yb in learn.data.train_dl:
            loss = learn.loss_func(learn.model(xb), yb)
            loss.backward()
            learn.opt.step()
            learn.opt.zero_grad()

        learn.model.eval() # put the model in inference mode
        with torch.no_grad():
            tot_loss, tot_acc = 0., 0.
            for xb,yb in learn.data.valid_dl:
                pred = learn.model(xb)
                # Technically not proper to calculate tot_loss & tot_acc
                # in this way because we don't ensure each val batch has
                # the same size (by setting drop_last=True for the val
                # data loader). The last val batch is smaller than all
                # previous val batches.
                tot_loss += learn.loss_func(pred, yb)
                tot_acc  += accuracy(pred, yb)
        nv = len(learn.data.valid_dl)
        print(epoch, tot_loss/nv, tot_acc/nv)
    return tot_loss/nv, tot_acc/nv
loss, acc = fit(1, learn)
0 tensor(0.2717) tensor(0.9133)
We can refactor our `fit()` training loop so that it is easy to identify when a single batch is trained, and when all batches are trained. This simpler structure will let us easily specify where we wire in our various callbacks. Each of the three functions below has a `cb` parameter to accept a `Callback` object. Again, the ultimate goal of all of this is to make our training loop more flexible.
def one_batch(xb, yb, cb):
    if not cb.begin_batch(xb,yb): return
    loss = cb.learn.loss_func(cb.learn.model(xb), yb)
    if not cb.after_loss(loss): return
    loss.backward()
    if cb.after_backward(): cb.learn.opt.step()
    if cb.after_step(): cb.learn.opt.zero_grad()

def all_batches(dl, cb):
    for xb,yb in dl:
        one_batch(xb, yb, cb)
        if cb.do_stop(): return
def fit(epochs, learn, cb):
    if not cb.begin_fit(learn): return
    for epoch in range(epochs):
        if not cb.begin_epoch(epoch): continue
        all_batches(learn.data.train_dl, cb)

        if cb.begin_validate():
            with torch.no_grad(): all_batches(learn.data.valid_dl, cb)
        if cb.do_stop() or not cb.after_epoch(): break
    cb.after_fit()
class Callback():
    def begin_fit(self, learn):
        self.learn = learn
        return True
    def after_fit(self): return True
    def begin_epoch(self, epoch):
        self.epoch = epoch
        return True
    def begin_validate(self): return True
    def after_epoch(self): return True
    def begin_batch(self, xb, yb):
        self.xb,self.yb = xb,yb
        return True
    def after_loss(self, loss):
        self.loss = loss
        return True
    def after_backward(self): return True
    def after_step(self): return True
class CallbackHandler():
    def __init__(self, cbs=None):
        self.cbs = cbs if cbs else []

    def begin_fit(self, learn):
        self.learn,self.in_train = learn,True
        learn.stop = False
        res = True
        for cb in self.cbs: res = res and cb.begin_fit(learn)
        return res

    def after_fit(self):
        res = not self.in_train
        for cb in self.cbs: res = res and cb.after_fit()
        return res

    def begin_epoch(self, epoch):
        self.learn.model.train()
        self.in_train = True
        res = True
        for cb in self.cbs: res = res and cb.begin_epoch(epoch)
        return res

    def begin_validate(self):
        self.learn.model.eval()
        self.in_train = False
        res = True
        for cb in self.cbs: res = res and cb.begin_validate()
        return res

    def after_epoch(self):
        res = True
        for cb in self.cbs: res = res and cb.after_epoch()
        return res

    def begin_batch(self, xb, yb):
        res = True
        for cb in self.cbs: res = res and cb.begin_batch(xb, yb)
        return res

    def after_loss(self, loss):
        res = self.in_train
        for cb in self.cbs: res = res and cb.after_loss(loss)
        return res

    def after_backward(self):
        res = True
        for cb in self.cbs: res = res and cb.after_backward()
        return res

    def after_step(self):
        res = True
        for cb in self.cbs: res = res and cb.after_step()
        return res

    def do_stop(self):
        try:     return self.learn.stop
        finally: self.learn.stop = False
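One subtlety worth noting about the `res = res and cb.f()` pattern above: Python's `and` short-circuits, so once `res` becomes False, the remaining callbacks' hooks for that event are never invoked. A tiny standalone sketch (the helper functions here are invented for illustration) makes this visible:

```python
# Record which callback hooks actually run when chained with `and`.
calls = []

def make_cb(name, ret):
    # Returns a hook that logs its name and reports success/failure.
    def hook():
        calls.append(name)
        return ret
    return hook

cbs = [make_cb('a', False), make_cb('b', True)]

res = True
for cb in cbs: res = res and cb()  # same pattern as CallbackHandler

print(res, calls)  # False ['a'] -- 'b' was never called
```

Whether that skipping is desirable depends on the callbacks: it is fine for pure "should we continue?" checks, but a callback that also maintains internal state would silently miss an event.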
class TestCallback(Callback):
    def begin_fit(self, learn):
        super().begin_fit(learn)
        self.n_iters = 0
        return True

    def after_step(self):
        self.n_iters += 1
        print(self.n_iters)
        if self.n_iters >= 10: self.learn.stop = True
        return True
fit(1, learn, cb=CallbackHandler([TestCallback()]))
1 2 3 4 5 6 7 8 9 10
The above structure is very similar to how version 1.0 of the fastai library implements callbacks, the exception being that fastai's callback handler can also change and return `xb`, `yb`, and `loss`.
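As a rough illustration of that difference, a callback in that style might return a transformed batch rather than just True/False. The `NormalizeCallback` below is an invented example to show the shape of the idea, not fastai's actual API:

```python
# Hypothetical callback whose begin_batch returns a *modified* batch,
# which the handler would then feed into the model in place of the original.
class NormalizeCallback:
    def __init__(self, mean, std): self.mean, self.std = mean, std

    def begin_batch(self, xb, yb):
        # Shift and scale the inputs; labels pass through unchanged.
        return [(x - self.mean) / self.std for x in xb], yb

cb = NormalizeCallback(mean=2.0, std=2.0)
xb, yb = [0.0, 2.0, 4.0], [0, 1, 0]
xb, yb = cb.begin_batch(xb, yb)

print(xb)  # [-1.0, 0.0, 1.0]
```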
Now, while the above architecture is a nice first attempt at a workable callback handler, there are ways we can make things simpler and more flexible. It would be more straightforward if a single class had access to everything and could thus change anything at any time. After all, seeing as we're passing `cb` to each of the `one_batch()`, `all_batches()`, and `fit()` functions, it makes sense to store everything under one class.
## Runner() Class

In fact, this is what we will do shortly. We will create a class called `Runner()` that contains all the methods that compose our model training cycle, as well as the optimizer, model, loss function, and data.

First we'll rewrite our `Callback()` class to be compatible with the soon-to-be-implemented `Runner()` class:
#export
import re
# Helper function to convert the formatting of callback names
# so they can be displayed the way we want: all lower-case,
# with underscores in-between each word.
_camel_re1 = re.compile('(.)([A-Z][a-z])')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')
def camel2snake(name):
    s1 = re.sub(_camel_re1, r'\1_\2', name)
    return re.sub(_camel_re2, r'\1_\2', s1).lower()
class Callback():
    _order=0
    def set_runner(self, run): self.run=run
    def __getattr__(self, k): return getattr(self.run, k)

    @property
    def name(self):
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')
Next comes what will be the core callback of our training cycle. It is responsible for switching between train and eval mode, as well as maintaining a count of the iterations that have elapsed during an epoch. This callback will always be included by default by our `Runner()` class:
#export
class TrainEvalCallback(Callback):
    def begin_fit(self):
        self.run.n_epochs=0.
        self.run.n_iter=0

    def after_batch(self):
        if not self.in_train: return
        self.run.n_epochs += 1./self.iters
        self.run.n_iter   += 1

    def begin_epoch(self):
        self.run.n_epochs=self.epoch
        self.model.train()
        self.run.in_train=True

    def begin_validate(self):
        self.model.eval()
        self.run.in_train=False
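To see what that bookkeeping produces, here is a standalone sketch of the two counters, mirroring the callback's logic in plain Python rather than using the callback itself: `n_epochs` advances by `1/iters` after each training batch, so it reads as a continuous position within training, while `n_iter` counts batches across all epochs.

```python
iters = 4                     # batches per epoch (hypothetical small number)
n_epochs, n_iter = 0.0, 0     # what begin_fit initializes

for epoch in range(2):
    n_epochs = epoch          # begin_epoch resets to the integer epoch
    for _ in range(iters):    # after_batch advances the counters
        n_epochs += 1./iters
        n_iter   += 1

print(n_epochs, n_iter)  # 2.0 8
```

Mid-epoch, `n_epochs` takes fractional values like 1.25, which is handy later for schedules that vary a hyperparameter smoothly over training.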
A quick test of what we just wrote:
cbname = 'TrainEvalCallback'
camel2snake(cbname)
'train_eval_callback'
TrainEvalCallback().name
'train_eval'
A quick helper function that transforms inputs into lists:
#export
from typing import *
def listify(o):
    if o is None: return []
    if isinstance(o, list): return o
    if isinstance(o, str): return [o]
    if isinstance(o, Iterable): return list(o)
    return [o]
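A quick check of its behavior (the definition is repeated here so the snippet runs on its own):

```python
from typing import Iterable

# Same definition as above, repeated so this cell stands alone.
def listify(o):
    if o is None: return []
    if isinstance(o, list): return o
    if isinstance(o, str): return [o]
    if isinstance(o, Iterable): return list(o)
    return [o]

assert listify(None)       == []
assert listify('accuracy') == ['accuracy']  # strings are NOT iterated char-by-char
assert listify((1, 2))     == [1, 2]        # other iterables become lists
assert listify(3)          == [3]           # scalars get wrapped
```

Note the order of the checks matters: the `str` test must come before the `Iterable` test, since strings are iterable and would otherwise be split into characters.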
Finally, our new `Runner()` class:
#export
class Runner():
    def __init__(self, cbs=None, cb_funcs=None):
        cbs = listify(cbs)
        for cbf in listify(cb_funcs):
            cb = cbf()
            setattr(self, cb.name, cb)
            cbs.append(cb)
        self.stop,self.cbs = False,[TrainEvalCallback()]+cbs

    @property
    def opt(self):       return self.learn.opt
    @property
    def model(self):     return self.learn.model
    @property
    def loss_func(self): return self.learn.loss_func
    @property
    def data(self):      return self.learn.data

    def one_batch(self, xb, yb):
        self.xb,self.yb = xb,yb
        if self('begin_batch'): return
        self.pred = self.model(self.xb)
        if self('after_pred'): return
        self.loss = self.loss_func(self.pred, self.yb)
        if self('after_loss') or not self.in_train: return
        self.loss.backward()
        if self('after_backward'): return
        self.opt.step()
        if self('after_step'): return
        self.opt.zero_grad()

    def all_batches(self, dl):
        self.iters = len(dl)
        for xb,yb in dl:
            if self.stop: break
            self.one_batch(xb,yb)
            self('after_batch')
        self.stop = False

    def fit(self, epochs, learn):
        self.epochs,self.learn,self.loss = epochs,learn,tensor(0.)
        try:
            for cb in self.cbs: cb.set_runner(self)
            if self('begin_fit'): return
            for epoch in range(epochs):
                self.epoch = epoch
                if not self('begin_epoch'): self.all_batches(self.data.train_dl)

                with torch.no_grad():
                    if not self('begin_validate'): self.all_batches(self.data.valid_dl)
                if self('after_epoch'): break
        finally:
            self('after_fit')
            self.learn = None

    def __call__(self, cb_name):
        for cb in sorted(self.cbs, key=lambda x: x._order):
            f = getattr(cb, cb_name, None)
            if f and f(): return True
        return False
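The `__call__` dispatch can be isolated into a small standalone sketch. The classes `A` and `B` below are hypothetical callbacks used only to show two properties: lower `_order` runs first regardless of list position, and hooks a callback doesn't define are simply skipped.

```python
fired = []

class A:
    _order = 1
    def begin_fit(self): fired.append('A'); return False

class B:
    _order = 0
    def begin_fit(self): fired.append('B'); return False

cbs = [A(), B()]

# Mirrors Runner.__call__: sort by _order, look the hook up by name,
# call it if present, and stop early if any hook returns True.
def call(cb_name):
    for cb in sorted(cbs, key=lambda x: x._order):
        f = getattr(cb, cb_name, None)
        if f and f(): return True
    return False

call('begin_fit')
call('after_epoch')  # neither class defines it, so nothing fires

print(fired)  # ['B', 'A'] -- B runs first despite appearing second in cbs
```

Note the inverted convention relative to `CallbackHandler`: here a hook returns True to cancel the rest of that step, so "do nothing" methods can simply return None.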
## AvgStats()

The second callback we create is also of core importance: it will calculate and display the average loss and evaluation metrics during training. Unlike how we'd been doing it up until now, this implementation will display the correct average loss/metrics regardless of whether the batch size is constant or varies across iterations:
#export
class AvgStats():
    def __init__(self, metrics, in_train): self.metrics,self.in_train = listify(metrics),in_train

    def reset(self):
        self.tot_loss,self.count = 0.,0
        self.tot_mets = [0.] * len(self.metrics)

    @property
    def all_stats(self): return [self.tot_loss.item()] + self.tot_mets
    @property
    def avg_stats(self): return [o/self.count for o in self.all_stats]

    def __repr__(self):
        if not self.count: return ""
        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"

    def accumulate(self, run):
        bn = run.xb.shape[0]
        self.tot_loss += run.loss * bn
        self.count    += bn
        for i,m in enumerate(self.metrics):
            self.tot_mets[i] += m(run.pred, run.yb) * bn
class AvgStatsCallback(Callback):
    def __init__(self, metrics):
        self.train_stats,self.valid_stats = AvgStats(metrics,in_train=True),AvgStats(metrics,in_train=False)

    def begin_epoch(self):
        self.train_stats.reset()
        self.valid_stats.reset()

    def after_loss(self):
        stats = self.train_stats if self.in_train else self.valid_stats
        with torch.no_grad(): stats.accumulate(self.run)

    def after_epoch(self):
        print(self.train_stats)
        print(self.valid_stats)
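To see why `accumulate()` weights each batch's loss by batch size before dividing by the total count, compare the weighted average with the naive mean-of-batch-means we used earlier, on made-up numbers where the last batch is smaller:

```python
# Two full batches of 64 and a final partial batch of 32 (hypothetical values).
batches = [[1.0]*64, [1.0]*64, [3.0]*32]

# The old approach: average the per-batch means directly.
per_batch_means = [sum(b)/len(b) for b in batches]
unweighted = sum(per_batch_means)/len(per_batch_means)   # 5/3 ~= 1.667

# The AvgStats approach: accumulate value*batch_size and a running count.
tot   = sum(sum(b) for b in batches)   # 224.0
count = sum(len(b) for b in batches)   # 160
weighted = tot/count                   # 1.4, the true per-example mean

print(unweighted, weighted)
```

The unweighted version over-counts the small final batch (each of its 32 examples carries as much weight as 64 examples elsewhere), while the weighted version matches the true per-example mean exactly.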
Let's try it out!
learn = Learner(*get_model(data), loss_func, data)
Let's suppose we use accuracy as our evaluation metric. Here's one way to tell our `AvgStatsCallback()` class that this is what we want to do:
stats = AvgStatsCallback([accuracy])
run = Runner(cbs=stats)
run.fit(2, learn)
train: [0.31132888671875, tensor(0.9030)]
valid: [0.18311549072265626, tensor(0.9440)]
train: [0.139973671875, tensor(0.9580)]
valid: [0.1413156494140625, tensor(0.9571)]
loss, acc = stats.valid_stats.avg_stats
assert acc > 0.9
loss, acc
(0.1413156494140625, tensor(0.9571))
We can also use the `partial` function to pass the function that calculates accuracy to our `AvgStatsCallback` class:
#export
from functools import partial
acc_cbf = partial(AvgStatsCallback, accuracy)
run = Runner(cb_funcs=acc_cbf)
run.fit(1, learn)
train: [0.10387189453125, tensor(0.9687)]
valid: [0.114280712890625, tensor(0.9661)]
Finally, if you try typing out the line below, you'll see that Jupyter gives us tab-completion even for this dynamic code: each time you type a `.`, press the Tab key to see a pop-up of all the possible attributes that could come next.
run.avg_stats.valid_stats.avg_stats
[0.114280712890625, tensor(0.9661)]
!python notebook2script_my_reimplementation.py 04_callbacks_my_reimplementation.ipynb
Converted 04_callbacks_my_reimplementation.ipynb to nb_04.py