# default_exp learner
#export
from local.test import *
from local.data.all import *
from local.optimizer import *
from local.notebook.showdoc import *
#export
# Names of the exception classes that `mk_class` creates at runtime from `_ex_docs` (they have no `class` statement).
# NOTE(review): presumably read by the notebook export tool to extend the module's `__all__` — confirm.
_all_ = ['CancelFitException', 'CancelEpochException', 'CancelTrainException', 'CancelValidException', 'CancelBatchException']
Basic class for handling the training loop
We'll use the following for testing purposes (a basic linear regression problem):
from torch.utils.data import TensorDataset
def synth_dbunch(a=2, b=3, bs=16, n_train=10, n_valid=2, cuda=False):
    "`DataBunch` with `n_train`/`n_valid` batches of size `bs` for the noisy linear problem `y = a*x + b`"
    def _noisy_line(n_batches):
        # `bs*n_batches` points x ~ N(0,1), with targets a*x + b plus gaussian noise of std 0.1
        n_items = int(bs*n_batches)
        x = torch.randn(n_items)
        noise = 0.1*torch.randn(n_items)
        return TensorDataset(x, a*x + b + noise)
    train_ds,valid_ds = _noisy_line(n_train),_noisy_line(n_valid)
    # Batches are only moved to the GPU when `cuda` is requested
    tfms = Cuda() if cuda else None
    train_dl = TfmdDL(train_ds, bs=bs, shuffle=True, after_batch=tfms, num_workers=0)
    valid_dl = TfmdDL(valid_ds, bs=bs, after_batch=tfms, num_workers=0)
    return DataBunch(train_dl, valid_dl)
class RegModel(Module):
    "Linear regression model `x*a + b` with two scalar parameters, both initialized from N(0,1)"
    def __init__(self):
        self.a = nn.Parameter(torch.randn(1))
        self.b = nn.Parameter(torch.randn(1))

    def forward(self, x):
        return x*self.a + self.b
#export
class Callback(GetAttr):
    "Basic class handling tweaks of the training loop by changing a `Learner` in various events"
    # Attribute reads that fail on the callback are forwarded to `self.learn` (the `_default` of `GetAttr`)
    _default,learn = 'learn',None

    def __repr__(self): return type(self).__name__

    def __call__(self, event_name):
        "Call `self.{event_name}` if it's defined"
        handler = getattr(self, event_name, noop)
        handler()

    @property
    def name(self):
        "Name of the `Callback`, snake-cased and with '*Callback*' removed"
        return class2attr(self, 'Callback')
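The training loop is defined in `Learner` a bit below and consists of a minimal set of instructions. Looping through the data, we:
- compute the output of the model from the input;
- calculate a loss between this output and the desired target;
- compute the gradients of this loss with respect to the model parameters;
- update the parameters accordingly;
- zero all the gradients.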
Any tweak of this training loop is defined in a `Callback`. This avoids over-complicating the code of the training loop and makes it easy to mix and match different techniques, since each one is defined in its own callback. A callback can implement actions on the following events:
- `begin_fit`: called before doing anything, ideal for initial setup.
- `begin_epoch`: called at the beginning of each epoch, useful for any behavior you need to reset at each epoch.
- `begin_train`: called at the beginning of the training part of an epoch.
- `begin_batch`: called at the beginning of each batch, just after drawing said batch. It can be used to do any setup necessary for the batch (like hyper-parameter scheduling) or to change the input/target before it goes in the model (for instance, changing the input with techniques like mixup).
- `after_pred`: called after computing the output of the model on the batch. It can be used to change that output before it's fed to the loss.
- `after_loss`: called after the loss has been computed, but before the backward pass. It can be used to add any penalty to the loss (AR or TAR in RNN training for instance).
- `after_backward`: called after the backward pass, but before the update of the parameters. It can be used to make any change to the gradients before said update (gradient clipping for instance).
- `after_step`: called after the step and before the gradients are zeroed.
- `after_batch`: called at the end of a batch, for any clean-up before the next one.
- `after_train`: called at the end of the training phase of an epoch.
- `begin_validate`: called at the beginning of the validation phase of an epoch, useful for any setup needed specifically for validation.
- `after_validate`: called at the end of the validation part of an epoch.
- `after_epoch`: called at the end of an epoch, for any clean-up before the next one.
- `after_fit`: called at the end of training, for final clean-up.

show_doc(Callback.__call__)
# `Callback.__call__` looks up the attribute named after the event and calls it
tst_cb = Callback()
tst_cb.call_me = lambda: print("maybe")
test_stdout(lambda: tst_cb("call_me"), "maybe")
show_doc(Callback.__getattr__)
GetAttr.__getattr__
[source]
GetAttr.__getattr__(k)
This is a shortcut: instead of writing `self.learn.bla` for any attribute `bla` we seek, we can just write `self.bla`.
# Stand-in learner class built with `mk_class`, whose only field is `a` (`TstLearner(1)` has `a == 1`)
mk_class('TstLearner', 'a')
class TstCallback(Callback):
def batch_begin(self): print(self.a)
learn,cb = TstLearner(1),TstCallback()
cb.learn = learn
# `self.a` is not found on the callback, so it is read from `cb.learn.a`
test_stdout(lambda: cb('batch_begin'), "1")
Note that this only works for reading the value of the attribute. If you want to change it, you have to access it explicitly with `self.learn.bla`. In the example below, `self.a += 1` creates an attribute `a` equal to 2 on the callback instead of setting the learner's `a` to 2:
# Pitfall: `self.a += 1` reads `learn.a` but writes a new `a` attribute on the callback itself
class TstCallback(Callback):
def batch_begin(self): self.a += 1
learn,cb = TstLearner(1),TstCallback()
cb.learn = learn
cb('batch_begin')
test_eq(cb.a, 2)
# The learner's value is untouched
test_eq(cb.learn.a, 1)
A proper version needs to write `self.learn.a = self.a + 1`:
# Correct pattern: read through `self.a`, write explicitly to `self.learn.a`
class TstCallback(Callback):
def batch_begin(self): self.learn.a = self.a + 1
learn,cb = TstLearner(1),TstCallback()
cb.learn = learn
cb('batch_begin')
test_eq(cb.learn.a, 2)
show_doc(Callback.name, name='Callback.name')
# `name` drops the `Callback` suffix and converts the class name to snake_case
test_eq(TstCallback().name, 'tst')
class ComplicatedNameCallback(Callback): pass
test_eq(ComplicatedNameCallback().name, 'complicated_name')
#export
class TrainEvalCallback(Callback):
    "`Callback` that tracks the number of iterations done and properly sets training/eval mode"
    def begin_fit(self):
        "Reset `train_iter` and `pct_train` to 0 and move the model to the device of the `dbunch`"
        self.learn.train_iter = 0
        self.learn.pct_train = 0.
        self.model.to(self.dbunch.device)

    def after_batch(self):
        "Advance `train_iter` and `pct_train` by one training batch (nothing happens during validation)"
        if self.training:
            # Each training batch is 1/(n_iter*n_epoch) of the whole training run
            self.learn.pct_train += 1./(self.n_iter*self.n_epoch)
            self.learn.train_iter += 1

    def begin_train(self):
        "Put the model and the `Learner` in training mode, with `pct_train` restarting at `epoch/n_epoch`"
        self.learn.pct_train = self.epoch/self.n_epoch
        self.model.train()
        self.learn.training = True

    def begin_validate(self):
        "Put the model and the `Learner` in evaluation mode"
        self.model.eval()
        self.learn.training = False
# Documentation cell: renders the signature and docstring of `TrainEvalCallback`
show_doc(TrainEvalCallback, title_level=3)
This `Callback` is automatically added to every `Learner` at initialization.
#hide
#test of the TrainEvalCallback below in Learner.fit
# Documentation cell: renders `TrainEvalCallback.begin_fit`
show_doc(TrainEvalCallback.begin_fit)
TrainEvalCallback.begin_fit
[source]
TrainEvalCallback.begin_fit
()
Set the iteration counter and `pct_train` to 0, and put the model on the right device
# Documentation cell: renders `TrainEvalCallback.after_batch`
show_doc(TrainEvalCallback.after_batch)
TrainEvalCallback.after_batch
[source]
TrainEvalCallback.after_batch
()
Update the iter counter (in training mode)
# Documentation cell: renders `TrainEvalCallback.begin_train`
show_doc(TrainEvalCallback.begin_train)
TrainEvalCallback.begin_train
[source]
TrainEvalCallback.begin_train
()
Set the model in training mode
# Documentation cell: renders `TrainEvalCallback.begin_validate`
show_doc(TrainEvalCallback.begin_validate)
TrainEvalCallback.begin_validate
[source]
TrainEvalCallback.begin_validate
()
Set the model in validation mode
#export
class GatherPredsCallback(Callback):
    "`Callback` that saves the predictions and targets, optionally `with_loss`"
    def __init__(self, with_input=False, with_loss=False): store_attr(self, "with_input,with_loss")

    def begin_batch(self):
        "Record the detached inputs of the batch when `with_input`"
        if not self.with_input: return
        self.inputs.append(to_detach(self.xb))

    def begin_validate(self):
        "Initialize containers"
        self.preds,self.targets = [],[]
        if self.with_input: self.inputs = []
        if self.with_loss: self.losses = []

    def after_batch(self):
        "Save predictions, targets and potentially losses"
        self.preds.append(to_detach(self.pred))
        self.targets.append(to_detach(self.yb))
        if not self.with_loss: return
        bs = find_bs(self.yb)
        # An unreduced loss with more than one value per item is averaged down to one value per item
        per_item = self.loss if self.loss.numel() == bs else self.loss.view(bs,-1).mean(1)
        self.losses.append(to_detach(per_item))
# Documentation cells: render `GatherPredsCallback` and its `begin_validate`
show_doc(GatherPredsCallback, title_level=3)
show_doc(GatherPredsCallback.begin_validate)
GatherPredsCallback.begin_validate
[source]
GatherPredsCallback.begin_validate
()
Initialize containers
# Documentation cell: renders `GatherPredsCallback.after_batch`
show_doc(GatherPredsCallback.after_batch)
GatherPredsCallback.after_batch
[source]
GatherPredsCallback.after_batch
()
Save predictions, targets and potentially losses
Sometimes we want to skip some of the steps of the training loop. In gradient accumulation, for instance, we don't always want to do the step/zeroing of the gradients. During an LR finder test, we don't want to do the validation phase of an epoch. And when training with an early-stopping strategy, we want to be able to completely interrupt the training loop.
This is made possible by raising specific exceptions the training loop will look for (and properly catch).
#export
_ex_docs = dict(
CancelFitException="Skip the rest of this batch and go to `after_batch`",
CancelEpochException="Skip the rest of the training part of the epoch and go to `after_train`",
CancelTrainException="Skip the rest of the validation part of the epoch and go to `after_validate`",
CancelValidException="Skip the rest of this epoch and go to `after_epoch`",
CancelBatchException="Interrupts training and go to `after_fit`")
# Create each exception at runtime: an `Exception` subclass named `c`, documented with `d`
for c,d in _ex_docs.items(): mk_class(c,sup=Exception,doc=d)
# Documentation cell: renders `CancelBatchException`
show_doc(CancelBatchException, title_level=3)
class CancelBatchException
[source]
CancelBatchException(*args, **kwargs) :: Exception
Skip the rest of this batch and go to `after_batch`
# Documentation cell: renders `CancelTrainException`
show_doc(CancelTrainException, title_level=3)
class CancelTrainException
[source]
CancelTrainException(*args, **kwargs) :: Exception
Skip the rest of the training part of the epoch and go to `after_train`
# Documentation cell: renders `CancelValidException`
show_doc(CancelValidException, title_level=3)
class CancelValidException
[source]
CancelValidException(*args, **kwargs) :: Exception
Skip the rest of the validation part of the epoch and go to `after_validate`
# Documentation cell: renders `CancelEpochException`
show_doc(CancelEpochException, title_level=3)
class CancelEpochException
[source]
CancelEpochException(*args, **kwargs) :: Exception
Skip the rest of this epoch and go to `after_epoch`
# Documentation cell: renders `CancelFitException`
show_doc(CancelFitException, title_level=3)
class CancelFitException
[source]
CancelFitException(*args, **kwargs) :: Exception
Interrupts training and go to `after_fit`
You can detect that one of those exceptions occurred, and add code that executes right after it, with the following events:
- `after_cancel_batch`: reached immediately after a `CancelBatchException`, before proceeding to `after_batch`
- `after_cancel_train`: reached immediately after a `CancelTrainException`, before proceeding to `after_train`
- `after_cancel_validate`: reached immediately after a `CancelValidException`, before proceeding to `after_validate`
- `after_cancel_epoch`: reached immediately after a `CancelEpochException`, before proceeding to `after_epoch`
- `after_cancel_fit`: reached immediately after a `CancelFitException`, before proceeding to `after_fit`
# export
_events = L.split(
    'begin_fit begin_epoch begin_train begin_batch after_pred after_loss '
    'after_backward after_step after_cancel_batch after_batch after_cancel_train '
    'after_train begin_validate after_cancel_validate after_validate after_cancel_epoch '
    'after_epoch after_cancel_fit after_fit')
# `event.<name> == '<name>'` for every event, so a misspelled event is an attribute error rather than a silent no-op
mk_class('event', **_events.map_dict(),
         doc="All possible events as attributes to get tab-completion and typo-proofing")
# Events fired around a standalone validation pass (`Learner.validate` / `Learner.get_preds`)
_before_epoch = [event.begin_fit, event.begin_epoch]
_after_epoch  = [event.after_epoch, event.after_fit]
# export
# `event` is created at runtime by `mk_class`, so it is listed explicitly for export.
# NOTE(review): presumably merged into `__all__` by the export tool — confirm.
_all_ = ['event']
# Documentation cell: renders the `event` namespace
show_doc(event, name='event', title_level=3)
class event
[source]
event(*args, **kwargs)
All possible events as attributes to get tab-completion and typo-proofing
# Each attribute of `event` is the event's own name as a string
test_eq(event.after_backward, 'after_backward')
Here's the full list: begin_fit begin_epoch begin_train begin_batch after_pred after_loss after_backward after_step after_cancel_batch after_batch after_cancel_train after_train begin_validate after_cancel_validate after_validate after_cancel_epoch after_epoch after_cancel_fit after_fit.
#hide
#Full test of the control flow below, after the Learner class
# export
# Library-wide defaults:
# - `lr` is the default argument of `Learner.__init__`;
# - `wd` is the default argument of `Learner.fit`;
# - `callbacks` lists the callback factories every `Learner` instantiates at init.
defaults.lr = slice(3e-3)
defaults.wd = 1e-2
defaults.callbacks = [TrainEvalCallback]
# export
def replacing_yield(o, attr, val):
    "Context manager to temporarily replace an attribute"
    # Generator protocol: `o.attr` is `val` while suspended at the yield and is restored however the generator ends.
    # Not decorated itself: callers wrap it (e.g. a `@contextmanager` method returning this generator).
    saved = getattr(o, attr)
    try:
        setattr(o, attr, val)
        yield
    finally:
        setattr(o, attr, saved)
#export
def mk_metric(m):
    "Convert `m` to an `AvgMetric`, unless it's already a `Metric`"
    if isinstance(m, Metric): return m
    # Plain functions get wrapped so they are averaged over batches
    return AvgMetric(m)
#export
def save_model(file, model, opt, with_opt=True):
    "Save `model` to `file` along with `opt` (if available, and if `with_opt`)"
    state = get_model(model).state_dict()
    # The optimizer state is saved only if it is requested and there is an optimizer;
    # in that case the file holds {'model': ..., 'opt': ...} instead of the bare model state dict
    if with_opt and opt is not None:
        state = {'model': state, 'opt': opt.state_dict()}
    torch.save(state, file)
# export
def load_model(file, model, opt, with_opt=None, device=None, strict=True):
    "Load `model` from `file` along with `opt` (if available, and if `with_opt`)"
    # An int `device` is a CUDA device index; with no `device`, tensors are loaded on the CPU
    if isinstance(device, int): device = torch.device('cuda', device)
    elif device is None: device = 'cpu'
    state = torch.load(file, map_location=device)
    # `save_model` writes {'model':..., 'opt':...} when the optimizer state was saved, a bare state dict otherwise
    hasopt = set(state)=={'model', 'opt'}
    model_state = state['model'] if hasopt else state
    get_model(model).load_state_dict(model_state, strict=strict)
    # with_opt=None: load the optimizer state if present, but only warn when it was explicitly requested
    if hasopt and ifnone(with_opt,True):
        # Best effort (e.g. `opt` is None or its param groups don't match); don't swallow KeyboardInterrupt/SystemExit
        try: opt.load_state_dict(state['opt'])
        except Exception:
            if with_opt: warn("Could not load the optimizer state.")
    elif with_opt: warn("Saved file doesn't contain an optimizer state.")
# export
def _try_concat(o):
try:
return torch.cat(o)
except:
return sum([L(o_[i,:] for i in range_of(o_)) for o_ in o], L())
# export
class Learner():
    # Method docstrings are attached after the class body by `add_docs(Learner, ...)`.
    def __init__(self, dbunch, model, loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=trainable_params, cbs=None,
                 cb_funcs=None, metrics=None, path=None, model_dir='models', wd_bn_bias=False, train_bn=True):
        store_attr(self, "dbunch,model,opt_func,lr,splitter,model_dir,wd_bn_bias,train_bn,metrics")
        self.training,self.logger,self.opt,self.cbs = False,print,None,L()
        # Without an explicit `loss_func`, fall back on one attached to the training dataset
        if loss_func is None:
            loss_func = getattr(dbunch.train_ds, 'loss_func', None)
            assert loss_func is not None, "Could not infer loss function from the data, please pass a loss function."
        self.loss_func = loss_func
        self.path = path if path is not None else getattr(dbunch, 'path', Path('.'))
        # `defaults.callbacks` and `cb_funcs` hold callback factories; `cbs` holds ready-made instances
        self.add_cbs(cbf() for cbf in L(defaults.callbacks)+L(cb_funcs))
        self.add_cbs(cbs)
        self.model.to(self.dbunch.device)

    @property
    def metrics(self): return self._metrics
    @metrics.setter
    def metrics(self,v): self._metrics = L(v).map(mk_metric)

    def add_cbs(self, cbs): L(cbs).map(self.add_cb)
    def remove_cbs(self, cbs): L(cbs).map(self.remove_cb)

    def add_cb(self, cb):
        old = getattr(self, cb.name, None)
        # The attribute name may only already be taken by a callback of the same type
        assert not old or isinstance(old, type(cb)), f"self.{cb.name} already registered"
        cb.learn = self
        setattr(self, cb.name, cb)
        self.cbs.append(cb)
        return self

    def remove_cb(self, cb):
        cb.learn = None
        if hasattr(self, cb.name): delattr(self, cb.name)
        if cb in self.cbs: self.cbs.remove(cb)

    @contextmanager
    def added_cbs(self, cbs):
        self.add_cbs(cbs)
        # Deregister even when the body raises (e.g. an exception escaping `fit`);
        # otherwise the temporary callbacks would stay attached to the `Learner`
        try: yield
        finally: self.remove_cbs(cbs)

    def ordered_cbs(self, cb_func:str): return [cb for cb in sort_by_run(self.cbs) if hasattr(cb, cb_func)]

    # `event_name` may be a single event or a list of events (e.g. `_before_epoch`)
    def __call__(self, event_name): L(event_name).map(self._call_one)

    def _call_one(self, event_name):
        assert hasattr(event, event_name), f"{event_name} is not a valid event"
        # Callbacks run in the order resolved by `sort_by_run` (run_before/run_after constraints)
        for cb in sort_by_run(self.cbs): cb(event_name)

    # Optimizer state entries for the parameters selected by `bn_bias_params(self.model, with_bias)`
    def _bn_bias_state(self, with_bias): return bn_bias_params(self.model, with_bias).map(self.opt.state)

    def create_opt(self):
        self.opt = self.opt_func(self.splitter(self.model), lr=self.lr)
        # Flag bn/bias params so the optimizer skips weight decay on them and/or always trains them
        if not self.wd_bn_bias:
            for p in self._bn_bias_state(False): p['do_wd'] = False
        if self.train_bn:
            for p in self._bn_bias_state(True ): p['force_train'] = True

    def _split(self, b):
        # The first `n_inp` elements are inputs (default: the only element, or all but the last), the rest targets
        i = getattr(self.dbunch, 'n_inp', 1 if len(b)==1 else len(b)-1)
        self.xb,self.yb = b[:i],b[i:]

    def all_batches(self):
        self.n_iter = len(self.dl)
        for o in enumerate(self.dl): self.one_batch(*o)

    def one_batch(self, i, b):
        self.iter = i
        try:
            self._split(b); self('begin_batch')
            self.pred = self.model(*self.xb); self('after_pred')
            # No targets in this batch: stop after the predictions
            if len(self.yb) == 0: return
            self.loss = self.loss_func(self.pred, *self.yb); self('after_loss')
            # Validation stops at the loss; only training does backward/step/zero_grad
            if not self.training: return
            self.loss.backward(); self('after_backward')
            self.opt.step(); self('after_step')
            self.opt.zero_grad()
        except CancelBatchException: self('after_cancel_batch')
        finally: self('after_batch')

    def _do_begin_fit(self, n_epoch):
        self.n_epoch,self.loss = n_epoch,tensor(0.); self('begin_fit')

    def _do_epoch_train(self):
        try:
            self.dl = self.dbunch.train_dl; self('begin_train')
            self.all_batches()
        except CancelTrainException: self('after_cancel_train')
        finally: self('after_train')

    def _do_epoch_validate(self, ds_idx=1, dl=None):
        if dl is None: dl = self.dbunch.dls[ds_idx]
        try:
            self.dl = dl; self('begin_validate')
            with torch.no_grad(): self.all_batches()
        except CancelValidException: self('after_cancel_validate')
        finally: self('after_validate')

    def fit(self, n_epoch, lr=None, wd=defaults.wd, cbs=None, reset_opt=False):
        with self.added_cbs(cbs):
            if reset_opt or not self.opt: self.create_opt()
            self.opt.set_hypers(wd=wd, lr=self.lr if lr is None else lr)
            try:
                self._do_begin_fit(n_epoch)
                for epoch in range(n_epoch):
                    try:
                        self.epoch=epoch; self('begin_epoch')
                        self._do_epoch_train()
                        self._do_epoch_validate()
                    except CancelEpochException: self('after_cancel_epoch')
                    finally: self('after_epoch')
            except CancelFitException: self('after_cancel_fit')
            finally: self('after_fit')

    def validate(self, ds_idx=1, dl=None, cbs=None):
        # Pretend to be in a one-epoch fit so callbacks relying on these attributes still work
        self.epoch,self.n_epoch,self.loss = 0,1,tensor(0.)
        if dl is None: dl = self.dbunch.dls[ds_idx]
        with self.added_cbs(cbs), self.no_logging():
            self(_before_epoch)
            self._do_epoch_validate(ds_idx, dl)
            self(_after_epoch)
        # Relies on a `Recorder` callback being registered as `self.recorder`
        return self.recorder.values[-1]

    def get_preds(self, ds_idx=1, dl=None, with_input=False, with_loss=False, with_decoded=False, act=None):
        self.epoch,self.n_epoch,self.loss = 0,1,tensor(0.)
        cb = GatherPredsCallback(with_input=with_input, with_loss=with_loss)
        # `loss_not_reduced` keeps one loss value per item so `cb` can record per-item losses
        with self.no_logging(), self.added_cbs(cb), self.loss_not_reduced():
            self(_before_epoch)
            self._do_epoch_validate(ds_idx, dl)
            self(_after_epoch)
            if act is None: act = getattr(self.loss_func, 'activation', noop)
            preds = act(torch.cat(cb.preds))
            res = (preds, detuplify(tuple(torch.cat(o) for o in zip(*cb.targets))))
            if with_decoded: res = res + (getattr(self.loss_func, 'decodes', noop)(preds),)
            if with_input: res = (tuple(_try_concat(o) for o in zip(*cb.inputs)),) + res
            if with_loss: res = res + (torch.cat(cb.losses),)
            return res

    def predict(self, item, rm_type_tfms=0):
        dl = test_dl(self.dbunch, [item], rm_type_tfms=rm_type_tfms)
        inp,preds,_ = self.get_preds(dl=dl, with_input=True)
        dec_preds = getattr(self.loss_func, 'decodes', noop)(preds)
        # Keep the decoded elements that come after the inputs
        i = getattr(self.dbunch, 'n_inp', -1)
        full_dec = self.dbunch.decode_batch((*inp,dec_preds))[0][i:]
        return detuplify(full_dec),dec_preds[0],preds[0]

    def show_results(self, ds_idx=0, dl=None, max_n=10, **kwargs):
        if dl is None: dl = self.dbunch.dls[ds_idx]
        b = dl.one_batch()
        _,_,preds = self.get_preds(dl=[b], with_decoded=True)
        self.dbunch.show_results(b, preds, max_n=max_n, **kwargs)

    def show_training_loop(self):
        loop = ['Start Fit', 'begin_fit', 'Start Epoch Loop', 'begin_epoch', 'Start Train', 'begin_train',
                'Start Batch Loop', 'begin_batch', 'after_pred', 'after_loss', 'after_backward',
                'after_step', 'after_cancel_batch', 'after_batch','End Batch Loop','End Train',
                'after_cancel_train', 'after_train', 'Start Valid', 'begin_validate','Start Batch Loop',
                '**CBs same as train batch**', 'End Batch Loop', 'End Valid', 'after_cancel_validate',
                'after_validate', 'End Epoch Loop', 'after_cancel_epoch', 'after_epoch', 'End Fit',
                'after_cancel_fit', 'after_fit']
        indent = 0
        for s in loop:
            if s.startswith('Start'): print(f'{" "*indent}{s}'); indent += 2
            elif s.startswith('End'): indent -= 2; print(f'{" "*indent}{s}')
            else: print(f'{" "*indent} - {s:15}:', self.ordered_cbs(s))

    # `replacing_yield` is a generator function: returning its generator lets `@contextmanager` drive it
    @contextmanager
    def no_logging(self): return replacing_yield(self, 'logger', noop)

    @contextmanager
    def loss_not_reduced(self):
        if hasattr(self.loss_func, 'reduction'): return replacing_yield(self.loss_func, 'reduction', 'none')
        else: return replacing_yield(self, 'loss_func', partial(self.loss_func, reduction='none'))

    def save(self, file, with_opt=True):
        if rank_distrib(): return # don't save from non-master processes in distributed training
        file = join_path_file(file, self.path/self.model_dir, ext='.pth')
        save_model(file, self.model, getattr(self,'opt',None), with_opt)

    def load(self, file, with_opt=None, device=None, strict=True):
        if device is None: device = self.dbunch.device
        # An optimizer must exist for its state to be loaded into
        if self.opt is None: self.create_opt()
        file = join_path_file(file, self.path/self.model_dir, ext='.pth')
        load_model(file, self.model, self.opt, with_opt=with_opt, device=device, strict=strict)
        return self
# `learn.x` / `learn.y` are the detuplified `xb` / `yb` (assign to `xb` / `yb` to change them)
Learner.x,Learner.y = add_props(lambda i,x: detuplify((x.xb,x.yb)[i]))
#export
# Attach the class docstring and one docstring per public method of `Learner`
add_docs(Learner, "Group together a `model`, some `dbunch` and a `loss_func` to handle training",
    add_cbs="Add `cbs` to the list of `Callback` and register `self` as their learner",
    add_cb="Add `cb` to the list of `Callback` and register `self` as their learner",
    remove_cbs="Remove `cbs` from the list of `Callback` and deregister `self` as their learner",
    remove_cb="Remove `cb` from the list of `Callback` and deregister `self` as their learner",
    added_cbs="Context manager that temporarily adds `cbs`",
    ordered_cbs="Return a list of `Callback` for one step `cb_func` in the training loop",
    create_opt="Create an optimizer with `lr`",
    one_batch="Train or evaluate `self.model` on batch `b`, the `i`-th of the current `dl`",
    all_batches="Train or evaluate `self.model` on all batches of `self.dl`",
    fit="Fit `self.model` for `n_epoch` using `cbs`. Optionally `reset_opt`.",
    validate="Validate on `dl` with potential new `cbs`.",
    get_preds="Get the predictions and targets on the `ds_idx`-th dataset or `dl`, optionally `with_input`, `with_loss` and `with_decoded`",
    predict="Return the prediction on `item`, fully decoded, loss function decoded and probabilities",
    show_results="Show some predictions on `ds_idx`-th dataset or `dl`",
    show_training_loop="Show each step in the training loop",
    no_logging="Context manager to temporarily remove `logger`",
    loss_not_reduced="A context manager to evaluate `loss_func` with reduction set to none.",
    save="Save model and optimizer state (if `with_opt`) to `self.path/self.model_dir/file`",
    load="Load model and optimizer state (if `with_opt`) from `self.path/self.model_dir/file` using `device`"
)
`opt_func` will be used to create an optimizer when `Learner.fit` is called, with `lr` as the learning rate. `splitter` is a function that takes `self.model` and returns a list of parameter groups (or just one parameter group if there are no different parameter groups). The default is `trainable_params`, which returns all trainable parameters of the model.
`cbs` is one `Callback` or a list of `Callback`s to pass to the `Learner`. `cb_funcs` is one function or a list of functions returning a `Callback`; they are called at init. Each `Callback` is registered as an attribute of the `Learner`, under its snake_case name (e.g. `TrainEvalCallback` becomes `learn.train_eval`). At creation, all the callbacks in `defaults.callbacks` are associated with the `Learner`: `TrainEvalCallback`, plus `Recorder` once it is added to `defaults.callbacks`.
`metrics` is an optional list of metrics, which can be either functions or `Metric`s (see below).
#Test init with callbacks
def synth_learner(n_train=10, n_valid=2, cuda=False, lr=defaults.lr, **kwargs):
    "`Learner` fitting a `RegModel` with `MSELossFlat` on `synth_dbunch` data; `kwargs` go to `Learner`"
    dbunch = synth_dbunch(n_train=n_train, n_valid=n_valid, cuda=cuda)
    model = RegModel()
    return Learner(dbunch, model, loss_func=MSELossFlat(), lr=lr, **kwargs)
# Default: only the `TrainEvalCallback`, registered as `learn.train_eval`
tst_learn = synth_learner()
test_eq(len(tst_learn.cbs), 1)
assert isinstance(tst_learn.cbs[0], TrainEvalCallback)
assert hasattr(tst_learn, ('train_eval'))
# `cbs` takes callback instances
tst_learn = synth_learner(cbs=TstCallback())
test_eq(len(tst_learn.cbs), 2)
assert isinstance(tst_learn.cbs[1], TstCallback)
assert hasattr(tst_learn, ('tst'))
# `cb_funcs` takes callback factories (here the class itself)
tst_learn = synth_learner(cb_funcs=TstCallback)
test_eq(len(tst_learn.cbs), 2)
assert isinstance(tst_learn.cbs[1], TstCallback)
assert hasattr(tst_learn, ('tst'))
#A name that becomes an existing attribute of the Learner will throw an exception (here add_cb)
class AddCbCallback(Callback): pass
test_fail(lambda: synth_learner(cbs=AddCbCallback()))
show_doc(Learner.fit)
Learner.fit
[source]
Learner.fit(n_epoch, lr=None, wd=0.01, cbs=None, reset_opt=False)
Fit `self.model` for `n_epoch` using `cbs`. Optionally `reset_opt`.
#Training a few epochs should make the model better
learn = synth_learner(cb_funcs=TstCallback, lr=1e-2)
# Loss of the untrained model on one batch, for comparison after fitting
xb,yb = learn.dbunch.one_batch()
init_loss = learn.loss_func(learn.model(xb), yb)
learn.fit(2)
# `learn.loss` holds the loss of the last batch processed by `fit` (a validation batch)
assert learn.loss < init_loss
#hide
#Test of TrainEvalCallback
class TestTrainEvalCallback(Callback):
    "Check, at each event, the counters and train/eval modes maintained by `TrainEvalCallback`"
    run_after=TrainEvalCallback

    def begin_fit(self):
        test_eq([self.pct_train,self.train_iter], [0., 0])
        self.old_pct_train = self.pct_train
        self.old_train_iter = self.train_iter

    def begin_batch(self):
        # The model must already live on the device of the batch
        param_device = next(self.model.parameters()).device
        test_eq(param_device, find_device(self.xb))

    def after_batch(self):
        if not self.training: return
        test_eq(self.pct_train , self.old_pct_train+1/(self.n_iter*self.n_epoch))
        test_eq(self.train_iter, self.old_train_iter+1)
        self.old_pct_train,self.old_train_iter = self.pct_train,self.train_iter

    def begin_train(self):
        assert self.training and self.model.training
        test_eq(self.pct_train, self.epoch/self.n_epoch)
        self.old_pct_train = self.pct_train

    def begin_validate(self):
        assert not self.training and not self.model.training
learn = synth_learner(cb_funcs=TestTrainEvalCallback)
learn.fit(1)
#Check order is properly taken into account: with the callbacks stored in reverse order,
#`run_after` must still make TestTrainEvalCallback see the state set up by TrainEvalCallback
learn.cbs = L(reversed(learn.cbs))
learn.fit(1)
#hide
#cuda
#Check model is put on the GPU if needed
# Same checks with `cuda=True`: batches go through the `Cuda()` transform and the model must follow
learn = synth_learner(cb_funcs=TestTrainEvalCallback, cuda=True)
learn.fit(1)
#hide
#Check wd is not applied on bn/bias when option wd_bn_bias=False
class _TstModel(nn.Module):
def __init__(self):
super().__init__()
self.a,self.b = nn.Parameter(torch.randn(1)),nn.Parameter(torch.randn(1))
self.tst = nn.Sequential(nn.Linear(4,5), nn.BatchNorm1d(3))
self.tst[0].bias.data,self.tst[1].bias.data = torch.randn(5),torch.randn(3)
def forward(self, x): return x * self.a + self.b
class _PutGrad(Callback):
    "Overwrite every gradient of `model.tst` with ones right after the backward pass"
    def after_backward(self):
        params = list(self.learn.model.tst.parameters())
        for param in params:
            param.grad = torch.ones_like(param.data)
# SGD with decoupled weight decay; `_PutGrad` forces every grad of `model.tst` to 1
learn = synth_learner(n_train=5, opt_func = partial(SGD, wd=1, decouple_wd=True), cb_funcs=_PutGrad)
learn.model = _TstModel()
init = [p.clone() for p in learn.model.tst.parameters()]
learn.fit(1, lr=1e-2)
end = list(learn.model.tst.parameters())
# 5 training steps of lr=1e-2 with grad 1 move a parameter by exactly -0.05 unless weight decay also applies:
# the linear weight (end[0]) is decayed; the linear bias and batchnorm weight/bias (end[1:]) are not
assert not torch.allclose(end[0]-init[0], -0.05 * torch.ones_like(end[0]))
for i in [1,2,3]: test_close(end[i]-init[i], -0.05 * torch.ones_like(end[i]))
# Documentation cell: renders `Learner.one_batch`
show_doc(Learner.one_batch)
This is an internal method called by `Learner.fit`. `i` is the index of this iteration in the epoch. In training mode, this does a full training step on the batch: compute predictions, loss and gradients, update the model parameters and zero the gradients. In validation mode, it stops at the loss computation.
# export
class VerboseCallback(Callback):
    "Callback that prints the name of each event called"
    def __call__(self, event_name):
        "Print `event_name`, then dispatch it exactly like `Callback.__call__`"
        print(event_name)
        return super().__call__(event_name)
#hide
# Verbose callback checking, at every step of `Learner.one_batch`, that the learner state matches a hand-computed
# SGD step on the regression model (`model.a`, `model.b`) for the expected batch (`xb`, `yb`) with index `i`
class TestOneBatch(VerboseCallback):
def __init__(self, xb, yb, i):
self.save_xb,self.save_yb,self.i = xb,yb,i
self.old_pred,self.old_loss = None,tensor(0.)
def begin_batch(self):
# Snapshot the parameters so later events can check they only change in `after_step`
self.old_a,self.old_b = self.model.a.data.clone(),self.model.b.data.clone()
test_eq(self.iter, self.i)
test_eq(self.save_xb, *self.xb)
test_eq(self.save_yb, *self.yb)
if hasattr(self.learn, 'pred'): test_eq(self.pred, self.old_pred)
def after_pred(self):
self.old_pred = self.pred
test_eq(self.pred, self.model.a.data * self.x + self.model.b.data)
# The loss has not been recomputed yet at this point
test_eq(self.loss, self.old_loss)
def after_loss(self):
self.old_loss = self.loss
test_eq(self.loss, self.loss_func(self.old_pred, self.save_yb))
# Gradients are either unset (first batch) or zeroed by the previous batch's `zero_grad`
for p in self.model.parameters():
if not hasattr(p, 'grad') or p.grad is not None: test_eq(p.grad, tensor([0.]))
def after_backward(self):
# Hand-computed gradients of the mean squared error w.r.t. `a` and `b`
self.grad_a = (2 * self.x * (self.pred.data - self.y)).mean()
self.grad_b = 2 * (self.pred.data - self.y).mean()
test_close(self.model.a.grad.data, self.grad_a)
test_close(self.model.b.grad.data, self.grad_b)
test_eq(self.model.a.data, self.old_a)
test_eq(self.model.b.data, self.old_b)
def after_step(self):
# Plain SGD update param <- param - lr*grad; the gradients are still present until `zero_grad`
test_close(self.model.a.data, self.old_a - self.lr * self.grad_a)
test_close(self.model.b.data, self.old_b - self.lr * self.grad_b)
self.old_a,self.old_b = self.model.a.data.clone(),self.model.b.data.clone()
test_close(self.model.a.grad.data, self.grad_a)
test_close(self.model.b.grad.data, self.grad_b)
def after_batch(self):
# `zero_grad` has run by now
for p in self.model.parameters(): test_eq(p.grad, tensor([0.]))
#hide
# Draw one batch, then feed that same batch (as iteration 42) to a learner whose `TestOneBatch` expects it
learn = synth_learner()
b = learn.dbunch.one_batch()
learn = synth_learner(cbs=TestOneBatch(*b, 42), lr=1e-2)
#Remove train/eval
learn.cbs = learn.cbs[1:]
#Setup
learn.loss,learn.training = tensor(0.),True
learn.opt = SGD(learn.model.parameters(), lr=learn.lr)
learn.model.train()
batch_events = ['begin_batch', 'after_pred', 'after_loss', 'after_backward', 'after_step', 'after_batch']
test_stdout(lambda: learn.one_batch(42, b), '\n'.join(batch_events))
test_stdout(lambda: learn.one_batch(42, b), '\n'.join(batch_events)) #Check it works for a second batch
# Documentation cell: renders `Learner.all_batches`
show_doc(Learner.all_batches)
Learner.all_batches
[source]
Learner.all_batches
()
Train or evaluate self.model
on all batches of self.dl
#hide
# Set up the training phase by hand (its event output is silenced), then check `all_batches`
# fires the full batch events for each of the 5 training batches
learn = synth_learner(n_train=5, cbs=VerboseCallback())
learn.opt = SGD(learn.model.parameters(), lr=learn.lr)
with redirect_stdout(io.StringIO()):
learn._do_begin_fit(1)
learn.epoch,learn.dl = 0,learn.dbunch.train_dl
learn('begin_epoch')
learn('begin_train')
test_stdout(learn.all_batches, '\n'.join(batch_events * 5))
test_eq(learn.train_iter, 5)
# In validation, batches stop after the loss and `train_iter` does not move
valid_events = ['begin_batch', 'after_pred', 'after_loss', 'after_batch']
with redirect_stdout(io.StringIO()):
learn.dl = learn.dbunch.valid_dl
learn('begin_validate')
test_stdout(learn.all_batches, '\n'.join(valid_events * 2))
test_eq(learn.train_iter, 5)
#hide
# `_do_begin_fit` stores `n_epoch`, resets `loss` and fires only `begin_fit`
learn = synth_learner(n_train=5, cbs=VerboseCallback())
test_stdout(lambda: learn._do_begin_fit(42), 'begin_fit')
test_eq(learn.n_epoch, 42)
test_eq(learn.loss, tensor(0.))
#hide
# One training phase: begin_train, the batch events for each of the 5 batches, after_train
learn.opt = SGD(learn.model.parameters(), lr=learn.lr)
learn.epoch = 0
test_stdout(lambda: learn._do_epoch_train(), '\n'.join(['begin_train'] + batch_events * 5 + ['after_train']))
#hide
# One validation phase over the 2 validation batches
test_stdout(learn._do_epoch_validate, '\n'.join(['begin_validate'] + valid_events * 2+ ['after_validate']))
# Documentation cell: renders `Learner.save`
show_doc(Learner.save)
Learner.save
[source]
Learner.save
(file
,with_opt
=True
)
Save model and optimizer state (if with_opt
) to self.path/self.model_dir/file
file
can be a Path
, a string
or a buffer.
# Documentation cell: renders `Learner.load`
show_doc(Learner.load)
`file` can be a `Path`, a string or a buffer. Use `device` to load the model/optimizer state on a device different from the one it was saved on.
# Round trip with the optimizer state (SGD with momentum, so the state is non-trivial after one epoch)
learn = synth_learner(cb_funcs=TstCallback, opt_func=partial(SGD, mom=0.9))
xb,yb = learn.dbunch.one_batch()
init_loss = learn.loss_func(learn.model(xb), yb)
learn.fit(1)
learn.save('tmp')
assert (Path.cwd()/'models/tmp.pth').exists()
learn1 = synth_learner(cb_funcs=TstCallback, opt_func=partial(SGD, mom=0.9))
learn1 = learn1.load('tmp')
test_eq(learn.model.a, learn1.model.a)
test_eq(learn.model.b, learn1.model.b)
test_eq(learn.opt.state_dict(), learn1.opt.state_dict())
# Without the optimizer: the weights match but the fresh optimizer state differs
learn.save('tmp1', with_opt=False)
learn1 = synth_learner(cb_funcs=TstCallback, opt_func=partial(SGD, mom=0.9))
learn1 = learn1.load('tmp1')
test_eq(learn.model.a, learn1.model.a)
test_eq(learn.model.b, learn1.model.b)
test_ne(learn.opt.state_dict(), learn1.opt.state_dict())
shutil.rmtree('models')
show_doc(Learner.__call__)
show_doc(Learner.add_cb)
learn = synth_learner()
# `add_cb` appends the callback, registers it under its snake_case name and points `cb.learn` at the learner
learn.add_cb(TestTrainEvalCallback())
test_eq(len(learn.cbs), 2)
assert isinstance(learn.cbs[1], TestTrainEvalCallback)
test_eq(learn.train_eval.learn, learn)
show_doc(Learner.add_cbs)
learn.add_cbs([TestTrainEvalCallback(), TestTrainEvalCallback()])
test_eq(len(learn.cbs), 4)
show_doc(Learner.remove_cb)
cb = learn.cbs[1]
learn.remove_cb(learn.cbs[1])
test_eq(len(learn.cbs), 3)
assert cb.learn is None
assert not getattr(learn,'test_train_eval',None)
show_doc(Learner.remove_cbs)
cb = learn.cbs[1]
learn.remove_cbs(learn.cbs[1:])
test_eq(len(learn.cbs), 1)
assert isinstance(learn.cbs[0], TrainEvalCallback)
# `remove_cbs` must also deregister the learner from every removed callback
assert cb.learn is None
When writing a callback, the following attributes of `Learner` are available:
- `model`: the model used for training/validation
- `dbunch`: the underlying `DataBunch`
- `loss_func`: the loss function used
- `opt`: the optimizer used to update the model parameters
- `opt_func`: the function used to create the optimizer
- `cbs`: the list containing all `Callback`s
- `dl`: current `DataLoader` used for iteration
- `x`/`xb`: last input drawn from `self.dl` (potentially modified by callbacks). `xb` is always a tuple (potentially with one element) and `x` is detuplified. You can only assign to `xb`.
- `y`/`yb`: last target drawn from `self.dl` (potentially modified by callbacks). `yb` is always a tuple (potentially with one element) and `y` is detuplified. You can only assign to `yb`.
- `pred`: last predictions from `self.model` (potentially modified by callbacks)
- `loss`: last computed loss (potentially modified by callbacks)
- `n_epoch`: the number of epochs in this training
- `n_iter`: the number of iterations in the current `self.dl`
- `epoch`: the current epoch index (from 0 to `n_epoch-1`)
- `iter`: the current iteration index in `self.dl` (from 0 to `n_iter-1`)

The following attributes are added by `TrainEvalCallback` and should be available unless you went out of your way to remove that callback:
- `train_iter`: the number of training iterations done since the beginning of this training
- `pct_train`: from 0. to 1., the percentage of training iterations completed
- `training`: flag to indicate if we're in training mode or not

The following attribute is added by `Recorder` and should be available unless you went out of your way to remove that callback:
- `smooth_loss`: an exponentially-averaged version of the training loss
#hide
batch_events = ['begin_batch', 'after_pred', 'after_loss', 'after_backward', 'after_step', 'after_batch']
# Validation batches skip the backward pass and the optimizer step
batchv_events = [e for e in batch_events if e not in ('after_backward', 'after_step')]
train_events = ['begin_train', *batch_events, 'after_train']
valid_events = ['begin_validate', *batchv_events, 'after_validate']
epoch_events = ['begin_epoch', *train_events, *valid_events, 'after_epoch']
cycle_events = ['begin_fit', *epoch_events, 'after_fit']
#hide
# Check that VerboseCallback prints the name of every event of one fit, in the order given by `cycle_events`
# (n_train=1/n_valid=1 presumably gives a single training and a single validation batch)
learn = synth_learner(n_train=1, n_valid=1)
test_stdout(lambda: learn.fit(1, cbs=VerboseCallback()), '\n'.join(cycle_events))
#hide
class TestCancelCallback(VerboseCallback):
    "VerboseCallback that raises `exception` when the event `cancel_at` fires (optionally only in train or valid mode)"
    def __init__(self, cancel_at=event.begin_batch, exception=CancelBatchException, train=None):
        def _maybe_raise():
            # `train=None` raises in both modes; otherwise only when `self.training` matches `train`
            if train is not None and train != self.training: return
            raise exception()
        # Replace the handler for `cancel_at` on this instance with the raising one
        setattr(self, cancel_at, _maybe_raise)
#hide
#test cancel batch
# CancelBatchException raised at any batch event except `after_batch` skips the rest of that batch,
# then `after_cancel_batch` and `after_batch` are called. The callback is active in training and validation.
for i,e in enumerate(batch_events[:-1]):
    be = batch_events[:i+1] + ['after_cancel_batch', 'after_batch']
    # after_backward/after_step (i >= 3) never fire in validation, so there the valid batch runs normally
    bev = be if i <3 else batchv_events
    cycle = cycle_events[:3] + be + ['after_train', 'begin_validate'] + bev + cycle_events[-3:]
    test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(cancel_at=e)), '\n'.join(cycle))
#CancelBatchException not caught if thrown in any other event
for e in cycle_events:
    if e not in batch_events[:-1]:
        with redirect_stdout(io.StringIO()):
            cb = TestCancelCallback(cancel_at=e)
            test_fail(lambda: learn.fit(1, cbs=cb))
            learn.remove_cb(cb) #Have to remove it manually
#hide
#test cancel train
# CancelTrainException raised (in training mode only) at `begin_train` or any batch event stops the training
# phase, calls `after_cancel_train` and `after_train`, and validation then runs normally.
for i,e in enumerate(['begin_train'] + batch_events):
    # `after_batch` still fires when raised inside a batch, but not from begin_train (i=0) or from after_batch itself
    be = batch_events[:i] + (['after_batch'] if i >=1 and i < len(batch_events) else [])
    be += ['after_cancel_train', 'after_train']
    cycle = cycle_events[:3] + be + ['begin_validate'] + batchv_events + cycle_events[-3:]
    test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(e, CancelTrainException, True)), '\n'.join(cycle))
#CancelTrainException not caught if thrown in any other event
for e in cycle_events:
    if e not in ['begin_train'] + batch_events[:-1]:
        with redirect_stdout(io.StringIO()):
            cb = TestCancelCallback(e, CancelTrainException)
            test_fail(lambda: learn.fit(1, cbs=cb))
            learn.remove_cb(cb) #Have to remove it manually
#hide
#test cancel valid
# CancelValidException raised (in validation mode only) at `begin_validate` or any validation batch event stops
# the validation phase and calls `after_cancel_validate`; the rest of the epoch (after_validate...) still runs.
for i,e in enumerate(['begin_validate'] + batchv_events):
    bev = batchv_events[:i] + (['after_batch'] if i >=1 and i < len(batchv_events) else []) + ['after_cancel_validate']
    cycle = cycle_events[:3] + batch_events + ['after_train', 'begin_validate'] + bev + cycle_events[-3:]
    test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(e, CancelValidException, False)), '\n'.join(cycle))
#CancelValidException not caught if thrown in any other event
# (batch_events[:3] are the same three events as batchv_events[:-1])
for e in cycle_events:
    if e not in ['begin_validate'] + batch_events[:3]:
        with redirect_stdout(io.StringIO()):
            cb = TestCancelCallback(e, CancelValidException)
            test_fail(lambda: learn.fit(1, cbs=cb))
            learn.remove_cb(cb) #Have to remove it manually
#hide
#test cancel epoch
# CancelEpochException stops the current epoch, calls `after_cancel_epoch`, and then after_epoch/after_fit still run.
#In train
for i,e in enumerate(['begin_train'] + batch_events):
    be = batch_events[:i] + (['after_batch'] if i >=1 and i<len(batch_events) else [])
    # after_train still fires before the epoch is cancelled; validation is skipped entirely
    cycle = cycle_events[:3] + be + ['after_train', 'after_cancel_epoch'] + cycle_events[-2:]
    test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(e, CancelEpochException, True)), '\n'.join(cycle))
#In valid
for i,e in enumerate(['begin_validate'] + batchv_events):
    bev = batchv_events[:i] + (['after_batch'] if i >=1 and i<len(batchv_events) else [])
    cycle = cycle_events[:3] + batch_events + ['after_train', 'begin_validate'] + bev
    cycle += ['after_validate', 'after_cancel_epoch'] + cycle_events[-2:]
    test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(e, CancelEpochException, False)), '\n'.join(cycle))
#In begin epoch
test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback('begin_epoch', CancelEpochException, False)),
            '\n'.join(cycle_events[:2] + ['after_cancel_epoch'] + cycle_events[-2:]))
#CancelEpochException not caught if thrown in any other event
# (the events below are the only ones that are not part of an epoch's train/validate phases or begin_epoch)
for e in ['begin_fit', 'after_epoch', 'after_fit']:
    with redirect_stdout(io.StringIO()):
        cb = TestCancelCallback(e, CancelEpochException)
        test_fail(lambda: learn.fit(1, cbs=cb))
        learn.remove_cb(cb) #Have to remove it manually
#hide
#test cancel fit
# CancelFitException stops the whole training, calls `after_cancel_fit`, and then after_fit still runs.
#In begin fit
test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback('begin_fit', CancelFitException)),
            '\n'.join(['begin_fit', 'after_cancel_fit', 'after_fit']))
#In begin epoch
test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback('begin_epoch', CancelFitException, False)),
            '\n'.join(cycle_events[:2] + ['after_epoch', 'after_cancel_fit', 'after_fit']))
#In train
for i,e in enumerate(['begin_train'] + batch_events):
    be = batch_events[:i] + (['after_batch'] if i >=1 and i<len(batch_events) else [])
    cycle = cycle_events[:3] + be + ['after_train', 'after_epoch', 'after_cancel_fit', 'after_fit']
    test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(e, CancelFitException, True)), '\n'.join(cycle))
#In valid
for i,e in enumerate(['begin_validate'] + batchv_events):
    bev = batchv_events[:i] + (['after_batch'] if i >=1 and i<len(batchv_events) else [])
    cycle = cycle_events[:3] + batch_events + ['after_train', 'begin_validate'] + bev
    cycle += ['after_validate', 'after_epoch', 'after_cancel_fit', 'after_fit']
    test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(e, CancelFitException, False)), '\n'.join(cycle))
#CancelFitException not caught if thrown in after_fit
# (CancelEpochException raised in after_fit is already covered by the cancel-epoch tests)
with redirect_stdout(io.StringIO()):
    cb = TestCancelCallback('after_fit', CancelFitException)
    test_fail(lambda: learn.fit(1, cbs=cb))
    learn.remove_cb(cb) #Have to remove it manually
#export
@docs
class Metric():
    "Blueprint for defining a metric"
    # Subclasses override these hooks; by default they do nothing
    def reset(self): return None
    def accumulate(self, learn): return None

    @property
    def value(self):
        # Every concrete metric must say how to turn its accumulated state into a result
        raise NotImplementedError

    @property
    def name(self): return class2attr(self, 'Metric')

    _docs = dict(
        reset="Reset inner state to prepare for new computation",
        name="Name of the `Metric`, camel-cased and with Metric removed",
        accumulate="Use `learn` to update the state with new results",
        value="The value of the metric")
# Render the documentation of the `Metric` class (with title level 3)
show_doc(Metric, title_level=3)
Metrics can be simple averages (like accuracy), but sometimes their computation is a little more complex and can't be averaged over batches (like precision or recall), which is why we need a special class for them. For simple functions that can be computed as averages over batches, we can use the class `AvgMetric`; otherwise you'll need to implement the following methods.

> Note: If your `Metric` has state depending on tensors, don't forget to store it on the CPU to avoid any potential memory leaks.
show_doc(Metric.reset)
show_doc(Metric.accumulate)
# `value` is a property, so its display name is passed explicitly (presumably show_doc can't infer it from the property object)
show_doc(Metric.value, name='Metric.value')
Metric.value
[source]The value of the metric
# `name` is also a property, hence the explicit display name
show_doc(Metric.name, name='Metric.name')
#export
def _maybe_reduce(val):
    "Average `val` over all processes when `num_distrib()` reports more than one; otherwise return `val` itself"
    if not num_distrib() > 1: return val
    # Work on a copy: all_reduce is in-place and must not touch the caller's tensor
    reduced = val.clone()
    torch.distributed.all_reduce(reduced, op=torch.distributed.ReduceOp.SUM)
    reduced /= num_distrib()
    return reduced
#export
class AvgMetric(Metric):
    "Average the values of `func` taking into account potential different batch sizes"
    def __init__(self, func): self.func = func
    def reset(self): self.total,self.count = 0.,0
    def accumulate(self, learn):
        bs = find_bs(learn.yb)
        # Weight each batch's value by its size so the final result is a per-sample (not per-batch) average
        self.total += to_detach(_maybe_reduce(self.func(learn.pred, *learn.yb)))*bs
        self.count += bs
    @property
    def value(self): return self.total/self.count if self.count != 0 else None
    @property
    def name(self):
        # `functools.partial` objects (and other wrappers exposing `.func`) have no `__name__`: unwrap them.
        # Callable objects without `__name__` fall back to their class name instead of raising.
        f = self.func
        while not hasattr(f, '__name__') and hasattr(f, 'func'): f = f.func
        return getattr(f, '__name__', type(f).__name__)
show_doc(AvgMetric, title_level=3)
# Mean-absolute-error fed as 4 equal batches of 25 must equal the MAE over all 100 values
learn = synth_learner()
tst = AvgMetric(lambda x,y: (x-y).abs().mean())
t,u = torch.randn(100),torch.randn(100)
tst.reset()
for i in range(0,100,25):
    learn.pred,learn.yb = t[i:i+25],(u[i:i+25],)
    tst.accumulate(learn)
test_close(tst.value, (t-u).abs().mean())
#hide
#With varying batch size
# Uneven batches (30, 20, 10, 40) check that each batch is weighted by its size
tst.reset()
splits = [0, 30, 50, 60, 100]
for i in range(len(splits )-1):
    learn.pred,learn.yb = t[splits[i]:splits[i+1]],(u[splits[i]:splits[i+1]],)
    tst.accumulate(learn)
test_close(tst.value, (t-u).abs().mean())
#export
class AvgLoss(Metric):
    "Average the losses taking into account potential different batch sizes"
    def reset(self):
        self.total = 0.
        self.count = 0

    def accumulate(self, learn):
        n = find_bs(learn.yb)
        # `.mean()` collapses a possibly unreduced loss to one value, which is then weighted by the batch size
        batch_loss = to_detach(_maybe_reduce(learn.loss.mean()))
        self.total += batch_loss*n
        self.count += n

    @property
    def value(self):
        if self.count == 0: return None
        return self.total/self.count

    @property
    def name(self): return "loss"
show_doc(AvgLoss, title_level=3)
# With each batch's loss set to that batch's mean, AvgLoss must recover the mean of all 100 values
# (`yb` is set to a bare tensor here; find_bs presumably accepts that as well — confirm)
tst = AvgLoss()
t = torch.randn(100)
tst.reset()
for i in range(0,100,25):
    learn.yb,learn.loss = t[i:i+25],t[i:i+25].mean()
    tst.accumulate(learn)
test_close(tst.value, t.mean())
#hide
#With varying batch size
tst.reset()
splits = [0, 30, 50, 60, 100]
for i in range(len(splits )-1):
    learn.yb,learn.loss = t[splits[i]:splits[i+1]],t[splits[i]:splits[i+1]].mean()
    tst.accumulate(learn)
test_close(tst.value, t.mean())
#export
class AvgSmoothLoss(Metric):
    "Smooth average of the losses (exponentially weighted with `beta`)"
    def __init__(self, beta=0.98): self.beta = beta
    def reset(self): self.count,self.val = 0,tensor(0.)
    def accumulate(self, learn):
        self.count += 1
        # lerp(loss, val, beta) = beta*val + (1-beta)*loss
        self.val = torch.lerp(to_detach(learn.loss.mean()), self.val, self.beta)
    @property
    def value(self):
        # The running average starts at 0, so debias it by 1-beta**count. Before any batch there is nothing to
        # average: return None (as AvgLoss/AvgMetric do) instead of the 0/0 NaN of the debiasing formula.
        return self.val/(1-self.beta**self.count) if self.count != 0 else None
show_doc(AvgSmoothLoss, title_level=3)
# Compare with a hand-computed EMA (beta=0.98) over 4 batch losses, debiased by 1-beta**n after each step
tst = AvgSmoothLoss()
t = torch.randn(100)
tst.reset()
val = tensor(0.)
for i in range(4):
    learn.loss = t[i*25:(i+1)*25].mean()
    tst.accumulate(learn)
    val = val*0.98 + t[i*25:(i+1)*25].mean()*(1-0.98)
    test_close(val/(1-0.98**(i+1)), tst.value)
#export
from fastprogress.fastprogress import format_time
def _maybe_item(t):
    "Return the metric value `t.value`, converted to a Python number when it is a single-element tensor"
    v = t.value
    if not isinstance(v, Tensor) or v.numel() != 1: return v
    return v.item()
#export
class Recorder(Callback):
    "Callback that registers statistics (lr, loss and metrics) during training"
    run_after = TrainEvalCallback

    def __init__(self, add_time=True, train_metrics=False, beta=0.98):
        self.add_time,self.train_metrics = add_time,train_metrics
        self.loss,self.smooth_loss = AvgLoss(),AvgSmoothLoss(beta=beta)

    def begin_fit(self):
        "Prepare state for training"
        # Flags left over from a cancelled phase in a previous fit would otherwise empty `_train_mets`/`_valid_mets`
        # and corrupt `metric_names` below (they are also reset at each `begin_epoch`)
        self.cancel_train,self.cancel_valid = False,False
        self.lrs,self.iters,self.losses,self.values = [],[],[],[]
        names = self._valid_mets.attrgot('name')
        if self.train_metrics: names = names.map('train_{}') + names.map('valid_{}')
        else: names = L('train_loss', 'valid_loss') + names[1:]
        if self.add_time: names.append('time')
        self.metric_names = 'epoch'+names
        self.smooth_loss.reset()

    def after_batch(self):
        "Update all metrics and records lr and smooth loss in training"
        if len(self.yb) == 0: return
        mets = self._train_mets if self.training else self._valid_mets
        for met in mets: met.accumulate(self.learn)
        if not self.training: return
        self.lrs.append(self.opt.hypers[-1]['lr'])
        self.losses.append(self.smooth_loss.value)
        self.learn.smooth_loss = self.smooth_loss.value

    def begin_epoch(self):
        "Set timer if `self.add_time=True`"
        self.cancel_train,self.cancel_valid = False,False
        if self.add_time: self.start_epoch = time.time()
        self.log = L(getattr(self, 'epoch', 0))

    # `[1:]` skips the smooth loss: it is only reset in `begin_fit`, so it keeps smoothing across epochs
    def begin_train   (self): self._train_mets[1:].map(Self.reset())
    def begin_validate(self): self._valid_mets.map(Self.reset())
    def after_train   (self): self.log += self._train_mets.map(_maybe_item)
    def after_validate(self): self.log += self._valid_mets.map(_maybe_item)
    def after_cancel_train(self):    self.cancel_train = True
    def after_cancel_validate(self): self.cancel_valid = True

    def after_epoch(self):
        "Store and log the loss/metric values"
        # `values` rows line up with `metric_names[1:]` (no epoch index, no time)
        self.values.append(self.log[1:].copy())
        if self.add_time: self.log.append(format_time(time.time() - self.start_epoch))
        self.logger(self.log)
        self.iters.append(self.smooth_loss.count)

    @property
    def _train_mets(self):
        if getattr(self, 'cancel_train', False): return L()
        return L(self.smooth_loss) + (self.metrics if self.train_metrics else L())

    @property
    def _valid_mets(self):
        if getattr(self, 'cancel_valid', False): return L()
        return L(self.loss) + self.metrics

    def plot_loss(self, skip_start=5, with_valid=True):
        plt.plot(self.losses[skip_start:], label='train')
        if with_valid:
            # With `train_metrics=True` the train metrics come before the validation loss in each row of `values`,
            # so find its column by name instead of assuming it is column 1
            idx = list(self.metric_names).index('valid_loss') - 1
            plt.plot(self.iters, L(self.values).itemgot(idx), label='valid')
        plt.legend()
#export
add_docs(Recorder,
         begin_train = "Reset loss and metrics state",
         after_train = "Log the smooth training loss, plus metric values on the training set if `self.train_metrics=True`",
         begin_validate = "Reset loss and metrics state",
         after_validate = "Log loss and metric values on the validation set",
         after_cancel_train = "Ignore training metrics for this epoch",
         after_cancel_validate = "Ignore validation metrics for this epoch",
         plot_loss = "Plot the losses from `skip_start` and onward")
# Default callback classes (presumably instantiated by every `Learner`)
defaults.callbacks = [TrainEvalCallback, Recorder]
By default, metrics are computed on the validation set only, although that can be changed with `train_metrics=True`. `beta` is the weight used to compute the exponentially weighted average of the losses (which gives the `smooth_loss` attribute to `Learner`).
#Test printed output
def tst_metric(out, targ): return F.mse_loss(out, targ)
learn = synth_learner(n_train=5, metrics=tst_metric)
# The logged line looks like `(#5) [0,10.24...,9.14...,9.14...,00:00]`: epoch, train loss, valid loss, metric, time
# (values are plain floats since `_maybe_item` unwraps one-element tensors). Brackets must be escaped: an
# unescaped `[...]` is a character class that matches almost any output.
pat = r"\(#5\) \[0,[\d.e+-]+,[\d.e+-]+,[\d.e+-]+,\d\d:\d\d\]"
test_stdout(lambda: learn.fit(1), pat, regex=True)
#hide
class TestRecorderCallback(Callback):
    # Runs after Recorder and recomputes independently what Recorder should have stored/logged, asserting they match
    run_after=Recorder
    def begin_fit(self):
        self.train_metrics,self.add_time = self.recorder.train_metrics,self.recorder.add_time
        self.beta = self.recorder.smooth_loss.beta
        for m in self.metrics: assert isinstance(m, Metric)
        test_eq(self.recorder.smooth_loss.val, 0.)
        #To test what the recorder logs, we use a custom logger function.
        self.learn.logger = self.test_log
        self.old_smooth,self.count = tensor(0.),0
    def after_batch(self):
        if self.training:
            self.count += 1
            test_eq(len(self.recorder.lrs), self.count)
            test_eq(self.recorder.lrs[-1], self.opt.hypers[-1]['lr'])
            test_eq(len(self.recorder.losses), self.count)
            # One EMA step from the previous debiased value: undo its debias factor, apply beta, re-debias
            smooth = (1 - self.beta**(self.count-1)) * self.old_smooth * self.beta + self.loss * (1-self.beta)
            smooth /= 1 - self.beta**self.count
            test_close(self.recorder.losses[-1], smooth, eps=1e-4)
            test_close(self.smooth_loss, smooth, eps=1e-4)
            self.old_smooth = self.smooth_loss
        # Running sample count for the current phase, reset in begin_train/begin_validate
        self.bs += find_bs(self.yb)
        if not self.training: test_eq(self.recorder.loss.count, self.bs)
        if self.train_metrics or not self.training:
            for m in self.metrics: test_eq(m.count, self.bs)
        self.losses.append(self.loss.detach().cpu())
    def begin_epoch(self):
        if self.add_time: self.start_epoch = time.time()
        self.log = [self.epoch]
    def begin_train(self):
        self.bs = 0
        self.losses = []
        for m in self.recorder._train_mets: test_eq(m.count, self.bs)
    def after_train(self):
        mean = tensor(self.losses).mean()
        # With train metrics, the MSE test metric over the training set is the mean of the batch losses
        self.log += [self.smooth_loss, mean] if self.train_metrics else [self.smooth_loss]
        test_eq(self.log, self.recorder.log)
        self.losses = []
    def begin_validate(self):
        self.bs = 0
        self.losses = []
        for m in [self.recorder.loss] + self.metrics: test_eq(m.count, self.bs)
    def test_log(self, log):
        # Validation loss and the MSE test metric are both expected to equal the mean validation batch loss
        # (assumes the synthetic learner's loss is MSE as well — confirm in synth_learner)
        res = tensor(self.losses).mean()
        self.log += [res, res]
        if self.add_time: self.log.append(format_time(time.time() - self.start_epoch))
        test_eq(log, self.log)
#hide
# Run TestRecorderCallback with the default Recorder, then with train_metrics=True, then with add_time=False,
# checking the column names Recorder builds in each case
learn = synth_learner(n_train=5, metrics = tst_metric, cb_funcs = TestRecorderCallback)
learn.fit(1)
test_eq(learn.recorder.metric_names, ['epoch', 'train_loss', 'valid_loss', 'tst_metric', 'time'])
learn = synth_learner(n_train=5, metrics = tst_metric, cb_funcs = TestRecorderCallback)
learn.recorder.train_metrics=True
learn.fit(1)
test_eq(learn.recorder.metric_names,
        ['epoch', 'train_loss', 'train_tst_metric', 'valid_loss', 'valid_tst_metric', 'time'])
learn = synth_learner(n_train=5, metrics = tst_metric, cb_funcs = TestRecorderCallback)
learn.recorder.add_time=False
learn.fit(1)
test_eq(learn.recorder.metric_names, ['epoch', 'train_loss', 'valid_loss', 'tst_metric'])
#hide
#Test numpy metric
# A metric returning a numpy array instead of a tensor must still go through a full fit
def tst_metric_np(out, targ): return F.mse_loss(out, targ).numpy()
learn = synth_learner(n_train=5, metrics=tst_metric_np)
learn.fit(1)
(#5) [0,10.249631881713867,9.148826599121094,9.148827075958252,00:00]
# Documentation of the Recorder event handlers
show_doc(Recorder.begin_fit)
show_doc(Recorder.begin_epoch)
show_doc(Recorder.begin_validate)
show_doc(Recorder.after_batch)
Recorder.after_batch
[source]
Recorder.after_batch
()
Update all metrics and records lr and smooth loss in training
show_doc(Recorder.after_epoch)
show_doc(Recorder.plot_loss)
Recorder.plot_loss
[source]
Recorder.plot_loss
(skip_start
=5
,with_valid
=True
)
Plot the losses from skip_start
and onward
#hide
learn.recorder.plot_loss(skip_start=1)
show_doc(Learner.no_logging)
# Inside `no_logging` a fit prints nothing; afterwards the logger is back to `print`
learn = synth_learner(n_train=5, metrics=tst_metric)
with learn.no_logging():
    test_stdout(lambda: learn.fit(1), '')
test_eq(learn.logger, print)
show_doc(Learner.validate)
Learner.validate
[source]
Learner.validate
(ds_idx
=1
,dl
=None
,cbs
=None
)
Validate on dl
with potential new cbs
.
#Test result
learn = synth_learner(n_train=5, metrics=tst_metric)
res = learn.validate()
# res[0] is the validation loss (checked below against MSE) and res[1] the MSE test metric, so they coincide
test_eq(res[0], res[1])
x,y = learn.dbunch.valid_ds.tensors
test_close(res[0], F.mse_loss(learn.model(x), y))
#hide
#Test other dl
res = learn.validate(dl=learn.dbunch.train_dl)
test_eq(res[0], res[1])
x,y = learn.dbunch.train_ds.tensors
test_close(res[0], F.mse_loss(learn.model(x), y))
#Test additional callback is executed.
# Two validation batches, presumably from synth_dbunch's default n_valid=2
cycle = cycle_events[:2] + ['begin_validate'] + batchv_events * 2 + cycle_events[-3:]
test_stdout(lambda: learn.validate(cbs=VerboseCallback()), '\n'.join(cycle))
show_doc(Learner.loss_not_reduced)
Learner.loss_not_reduced
[source]
Learner.loss_not_reduced
()
A context manager to evaluate loss_func
with reduction set to none.
#hide
# The loss reduction is 'none' only inside the context manager and is restored to 'mean' afterwards
test_eq(learn.loss_func.reduction, 'mean')
with learn.loss_not_reduced():
    test_eq(learn.loss_func.reduction, 'none')
    x,y = learn.dbunch.one_batch()
    p = learn.model(x)
    losses = learn.loss_func(p, y)
    test_eq(losses.shape, y.shape)
    test_eq(losses, F.mse_loss(p,y, reduction='none'))
test_eq(learn.loss_func.reduction, 'mean')
show_doc(Learner.get_preds)
Learner.get_preds
[source]
Learner.get_preds
(ds_idx
=1
,dl
=None
,with_input
=False
,with_loss
=False
,with_decoded
=False
,act
=None
)
Get the predictions and targets on the `ds_idx`-th dataset or `dl`, optionally `with_input` and `with_loss`
Depending on the `loss_func` attribute of `Learner`, an activation function will be picked automatically so that the predictions make sense. For instance, if the loss is a case of cross-entropy, a softmax will be applied, and if the loss is binary cross entropy with logits, a sigmoid will be applied. If you want to make sure a certain activation function is applied, you can pass it with `act`.
Note: If you want to use the option `with_loss=True` on a custom loss function, make sure you have implemented a `reduction` attribute that supports 'none'.
#Test result
# Predictions/targets on the validation set must match a direct forward pass of the model
learn = synth_learner(n_train=5, metrics=tst_metric)
preds,targs = learn.get_preds()
x,y = learn.dbunch.valid_ds.tensors
test_eq(targs, y)
test_close(preds, learn.model(x))
# An explicit `act` is applied to the predictions
preds,targs = learn.get_preds(act = torch.sigmoid)
test_eq(targs, y)
test_close(preds, torch.sigmoid(learn.model(x)))
#Test get_preds work with ds not evenly dividble by bs
learn = synth_learner(n_train=2.5, metrics=tst_metric)
preds,targs = learn.get_preds(ds_idx=0)
#hide
#Test other dataset
# An arbitrary DataLoader can be passed with `dl`
x = torch.randn(16*5)
y = 2*x + 3 + 0.1*torch.randn(16*5)
dl = TfmdDL(TensorDataset(x, y), bs=16)
preds,targs = learn.get_preds(dl=dl)
test_eq(targs, y)
test_close(preds, learn.model(x))
#Test with loss
# `with_loss=True` appends the unreduced per-item losses
preds,targs,losses = learn.get_preds(dl=dl, with_loss=True)
test_eq(targs, y)
test_close(preds, learn.model(x))
test_close(losses, F.mse_loss(preds, targs, reduction='none'))
#Test with inputs
# `with_input=True` prepends the (tuple of) inputs
inps,preds,targs = learn.get_preds(dl=dl, with_input=True)
test_eq(*inps,x)
test_eq(targs, y)
test_close(preds, learn.model(x))
#hide
#Test with no target
# A dataset with only inputs yields `None` as targets
learn = synth_learner(n_train=5)
x = torch.randn(16*5)
dl = TfmdDL(TensorDataset(x), bs=16)
preds,targs = learn.get_preds(dl=dl)
assert targs is None
#hide
#Test with targets that are tuples
# With n_inp=1 and a 3-tensor dataset there are two targets; the fake loss accepts both (plus `reduction`)
def _fake_loss(x,y,z,reduction=None): return F.mse_loss(x,y)
learn = synth_learner(n_train=5)
x = torch.randn(16*5)
y = 2*x + 3 + 0.1*torch.randn(16*5)
learn.dbunch.n_inp=1
learn.loss_func = _fake_loss
dl = TfmdDL(TensorDataset(x, y, y), bs=16)
preds,targs = learn.get_preds(dl=dl)
test_eq(targs, [y,y])
#hide
#Test with inputs that are tuples
# Wrapper model taking two inputs and ignoring the second one
class _TupleModel(Module):
    def __init__(self, model): self.model=model
    def forward(self, x1, x2): return self.model(x1)
learn = synth_learner(n_train=5)
# NOTE(review): n_inp is not set here; the DataBunch is replaced below anyway. Presumably the new one infers
# two inputs for a 3-tensor dataset — confirm.
x = torch.randn(16*5)
y = 2*x + 3 + 0.1*torch.randn(16*5)
learn.model = _TupleModel(learn.model)
learn.dbunch = DataBunch(TfmdDL(TensorDataset(x, x, y), bs=16),TfmdDL(TensorDataset(x, x, y), bs=16))
inps,preds,targs = learn.get_preds(ds_idx=0, with_input=True)
test_eq(inps, [x,x])
#hide
#Test auto activation function is picked
# With BCEWithLogitsLossFlat as loss, a sigmoid is applied to the predictions automatically
learn = synth_learner(n_train=5)
learn.loss_func = BCEWithLogitsLossFlat()
x = torch.randn(16*5)
y = 2*x + 3 + 0.1*torch.randn(16*5)
dl = TfmdDL(TensorDataset(x, y), bs=16)
preds,targs = learn.get_preds(dl=dl)
test_close(preds, torch.sigmoid(learn.model(x)))
show_doc(Learner.predict)
Learner.predict
[source]
Learner.predict
(item
,rm_type_tfms
=0
)
Return the prediction on item
, fully decoded, loss function decoded and probabilities
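It returns a tuple of three elements with, in reverse order: the prediction from the model, potentially passed through the activation of the loss function (if it has one); the decoded prediction, using the potential `decodes` method of the loss function; and the fully decoded prediction, using the transforms used to build the `DataSource`/`DataBunch`.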
# Fake loss with a recognizable activation (+1) and decodes (*2), so each stage of `predict` can be traced
class _FakeLossFunc(Module):
    reduction = 'none'
    def forward(self, x, y): return F.mse_loss(x,y)
    def activation(self, x): return x+1
    def decodes(self, x): return 2*x
# Transform whose decodes (-1) marks the last, data-level decoding step
class _Add1(Transform):
    def encodes(self, x): return x+1
    def decodes(self, x): return x-1
learn = synth_learner(n_train=5)
dl = TfmdDL(DataSource(torch.arange(50), tfms = [L(), [_Add1()]]))
learn.dbunch = DataBunch(dl, dl)
learn.loss_func = _FakeLossFunc()
inp = tensor([2.])
out = learn.model(inp).detach()+1 #applying model + activation
dec = 2*out #decodes from loss function
full_dec = dec-1 #decodes from _Add1
test_eq(learn.predict(tensor([2.])), [full_dec, dec, out])
#export
@patch
def freeze_to(self:Learner, n):
    # The optimizer owns the parameter groups, so it has to exist before anything can be frozen
    if self.opt is None:
        self.create_opt()
    self.opt.freeze_to(n)

@patch
def freeze(self:Learner):
    # Only the last parameter group stays trainable
    self.freeze_to(-1)

@patch
def unfreeze(self:Learner):
    # Starting the frozen range at 0 freezes nothing
    self.freeze_to(0)

add_docs(Learner,
         freeze_to="Freeze parameter groups up to `n`",
         freeze="Freeze up to last parameter group",
         unfreeze="Unfreeze the entire model")
#hide
class _TstModel(nn.Module):
    "Regression model `x*a + b` with an extra, never-called `tst` submodule (Linear + BatchNorm) used to test freezing"
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.randn(1))
        self.b = nn.Parameter(torch.randn(1))
        self.tst = nn.Sequential(nn.Linear(4,5), nn.BatchNorm1d(3))
        lin_bias = torch.randn(5)
        bn_bias = torch.randn(3)
        self.tst[0].bias.data = lin_bias
        self.tst[1].bias.data = bn_bias
    def forward(self, x): return x * self.a + self.b

class _PutGrad(Callback):
    "After the backward pass, set the gradient of every trainable `tst` parameter to ones"
    def after_backward(self):
        trainable = [p for p in self.learn.model.tst.parameters() if p.requires_grad]
        for p in trainable: p.grad = torch.ones_like(p.data)

def _splitter(m):
    # Three parameter groups: the Linear layer, the BatchNorm layer, then the regression parameters a and b
    lin,bn = m.tst[0],m.tst[1]
    return [list(lin.parameters()), list(bn.parameters()), [m.a,m.b]]
learn = synth_learner(n_train=5, opt_func = partial(SGD), cb_funcs=_PutGrad, splitter=_splitter, lr=1e-2)
learn.model = _TstModel()
# freeze() creates the optimizer if none exists yet, so the splitter is applied to the new model
learn.freeze()
init = [p.clone() for p in learn.model.tst.parameters()]
learn.fit(1)
end = list(learn.model.tst.parameters())
#linear was not trained
for i in [0,1]: test_close(end[i],init[i])
#bn was trained even frozen since `train_bn=True` by default
# -0.05 is presumably 5 SGD steps (n_train=5 batches) of lr=1e-2 on the all-ones gradients set by _PutGrad
for i in [2,3]: test_close(end[i]-init[i], -0.05 * torch.ones_like(end[i]))
(#4) [0,10.893106460571289,7.8781023025512695,00:00]
#hide
# Same setup with train_bn=False: freezing now also applies to BatchNorm, then freeze_to(-2)/unfreeze release groups
learn = synth_learner(n_train=5, opt_func = partial(SGD), cb_funcs=_PutGrad, splitter=_splitter, train_bn=False, lr=1e-2)
learn.model = _TstModel()
learn.freeze()
init = [p.clone() for p in learn.model.tst.parameters()]
learn.fit(1)
end = list(learn.model.tst.parameters())
#linear and bn were not trained
for i in range(4): test_close(end[i],init[i])
# Leave the last two groups (BatchNorm and a/b) trainable
learn.freeze_to(-2)
init = [p.clone() for p in learn.model.tst.parameters()]
learn.fit(1)
end = list(learn.model.tst.parameters())
#linear was not trained
for i in [0,1]: test_close(end[i],init[i])
#bn was trained
for i in [2,3]: test_close(end[i]-init[i], -0.05 * torch.ones_like(end[i]))
learn.unfreeze()
init = [p.clone() for p in learn.model.tst.parameters()]
learn.fit(1)
end = list(learn.model.tst.parameters())
#linear and bn were trained
for i in range(4): test_close(end[i]-init[i], -0.05 * torch.ones_like(end[i]), 1e-3)
(#4) [0,30.35057258605957,27.175193786621094,00:00] (#4) [0,23.77756690979004,21.27766227722168,00:00] (#4) [0,18.555871963500977,16.66706085205078,00:00]
Learner
¶#export
@patch
def export(self:Learner, fname='export.pkl'):
    "Export the content of `self` without the items and the optimizer state for inference"
    if rank_distrib(): return # don't export if slave proc
    old_dbunch = self.dbunch
    # Pickle an empty copy of the DataBunch so the items are not saved
    self.dbunch = old_dbunch.new_empty()
    # The optimizer is not saved; stash its state (if one exists yet) to rebuild it afterwards
    state = self.opt.state_dict() if self.opt is not None else None
    self.opt = None
    try:
        with warnings.catch_warnings():
            #To avoid the warning that come from PyTorch about model not being checked
            warnings.simplefilter("ignore")
            with open(self.path/fname, 'wb') as f: torch.save(self, f)
    finally:
        # Restore the learner as it was, even if saving failed
        if state is not None:
            self.create_opt()
            self.opt.load_state_dict(state)
        self.dbunch = old_dbunch
#hide
# Regenerate the library modules from the notebooks (prints one "Converted ..." line per notebook)
from local.notebook.export import notebook2script
notebook2script(all_fs=True)
Converted 00_test.ipynb. Converted 01_core.ipynb. Converted 01a_utils.ipynb. Converted 01b_dispatch.ipynb. Converted 01c_transform.ipynb. Converted 02_script.ipynb. Converted 03_torch_core.ipynb. Converted 03a_layers.ipynb. Converted 04_dataloader.ipynb. Converted 05_data_core.ipynb. Converted 06_data_transforms.ipynb. Converted 07_data_block.ipynb. Converted 08_vision_core.ipynb. Converted 09_vision_augment.ipynb. Converted 09a_vision_data.ipynb. Converted 10_pets_tutorial.ipynb. Converted 11_vision_models_xresnet.ipynb. Converted 12_optimizer.ipynb. Converted 13_learner.ipynb. Converted 13a_metrics.ipynb. Converted 14_callback_schedule.ipynb. Converted 14a_callback_data.ipynb. Converted 15_callback_hook.ipynb. Converted 15a_vision_models_unet.ipynb. Converted 16_callback_progress.ipynb. Converted 17_callback_tracker.ipynb. Converted 18_callback_fp16.ipynb. Converted 19_callback_mixup.ipynb. Converted 20_interpret.ipynb. Converted 20a_distributed.ipynb. Converted 21_vision_learner.ipynb. Converted 22_tutorial_imagenette.ipynb. Converted 23_tutorial_transfer_learning.ipynb. Converted 30_text_core.ipynb. Converted 31_text_data.ipynb. Converted 32_text_models_awdlstm.ipynb. Converted 33_text_models_core.ipynb. Converted 34_callback_rnn.ipynb. Converted 35_tutorial_wikitext.ipynb. Converted 36_text_models_qrnn.ipynb. Converted 37_text_learner.ipynb. Converted 38_tutorial_ulmfit.ipynb. Converted 40_tabular_core.ipynb. Converted 41_tabular_model.ipynb. Converted 42_tabular_rapids.ipynb. Converted 50_data_block_examples.ipynb. Converted 60_medical_imaging.ipynb. Converted 65_medical_text.ipynb. Converted 70_callback_wandb.ipynb. Converted 90_notebook_core.ipynb. Converted 91_notebook_export.ipynb. Converted 92_notebook_showdoc.ipynb. Converted 93_notebook_export2html.ipynb. Converted 94_notebook_test.ipynb. Converted 95_index.ipynb. Converted 96_data_external.ipynb. Converted 97_utils_test.ipynb. Converted notebook2jekyll.ipynb.