#default_exp callback.mixup
#export
from local.test import *
from local.basics import *
from local.callback.progress import *
from local.vision.core import *
from torch.distributions.beta import Beta
from local.notebook.showdoc import *
from local.test_utils import *
Callback to apply MixUp data augmentation to your training
# export
def reduce_loss(loss, reduction='mean'):
    "Average `loss` for `'mean'`, sum it for `'sum'`, and return it untouched for any other `reduction`"
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    # Anything else (e.g. 'none') leaves the per-item losses as they are
    return loss
# export
class MixUp(Callback):
    "Callback that blends every training batch with a randomly shuffled copy of itself (MixUp)"
    run_after=[Normalize, Cuda]

    def __init__(self, alpha=0.4):
        "Mixing coefficients are drawn from a symmetric `Beta(alpha, alpha)` distribution"
        self.distrib = Beta(tensor(alpha), tensor(alpha))

    def begin_fit(self):
        "Install `self.lf` as the loss function when the current loss has a truthy `y_int` attribute"
        # With `y_int` set, the targets are left unmixed in `begin_batch` and the two
        # losses are interpolated in `lf` instead.
        self.stack_y = getattr(self.learn.loss_func, 'y_int', False)
        if self.stack_y: self.old_lf,self.learn.loss_func = self.learn.loss_func,self.lf

    def after_fit(self):
        "Put back the original loss function if `begin_fit` replaced it"
        if self.stack_y: self.learn.loss_func = self.old_lf

    def begin_batch(self):
        "Mix the inputs (and the targets, unless `stack_y`) with a random permutation of the batch"
        if not self.training: return
        bs = self.y.size(0)
        # `reshape(bs)` instead of `squeeze()`: squeezing a batch of one yields a 0-dim
        # tensor, which then cannot be combined/broadcast per item below.
        lam = self.distrib.sample((bs,)).reshape(bs).to(self.x.device)
        # Keep the larger of (lam, 1-lam) so the un-shuffled item always gets weight >= 0.5
        self.lam = torch.max(lam, 1-lam)
        shuffle = torch.randperm(bs).to(self.x.device)
        # Shuffled copies of inputs and targets; `yb1` is stored because `lf` needs it later
        xb1,self.yb1 = tuple(L(self.xb).itemgot(shuffle)),tuple(L(self.yb).itemgot(shuffle))
        nx_dims = len(self.x.size())
        # lerp(shuffled, original, lam) == lam*original + (1-lam)*shuffled.
        # NOTE(review): `unsqueeze(..., n=k)` presumably adds k singleton dims so `lam`
        # broadcasts against the batch -- confirm against the helper's definition.
        self.learn.xb = tuple(L(xb1,self.xb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=nx_dims-1)))
        if not self.stack_y:
            ny_dims = len(self.y.size())
            self.learn.yb = tuple(L(self.yb1,self.yb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=ny_dims-1)))

    def lf(self, pred, *yb):
        "Loss used when `stack_y`: interpolate the losses against shuffled and original targets"
        # Outside training nothing was mixed, so defer to the original loss unchanged
        if not self.training: return self.old_lf(pred, *yb)
        # NOTE(review): `NoneReduce` presumably yields the loss without reduction so it can
        # be weighted per item by `self.lam` -- confirm against its definition.
        with NoneReduce(self.old_lf) as lf:
            loss = torch.lerp(lf(pred,*self.yb1), lf(pred,*yb), self.lam)
        # Re-apply the reduction the original loss would have used ('mean' if it declares none)
        return reduce_loss(loss, getattr(self.old_lf, 'reduction', 'mean'))
# Demo cell: run MixUp's `begin_batch` by hand on one MNIST_TINY batch and plot the result.
from local.vision.core import *
# NOTE(review): `untar_data` presumably downloads/extracts the dataset and returns its path -- confirm.
path = untar_data(URLs.MNIST_TINY)
items = get_image_files(path)
# Two transform pipelines: image via `PILImageBW.create`, label via `parent_label` then `Categorize`;
# the splits come from applying `GrandparentSplitter` to the items.
tds = DataSource(items, [PILImageBW.create, [parent_label, Categorize()]], splits=GrandparentSplitter()(items))
dbunch = tds.databunch(after_item=[ToTensor(), IntToFloatTensor()])
mixup = MixUp(0.5)
# `nn.Linear(3,4)` is only a placeholder: this cell never calls the model.
learn = Learner(dbunch, nn.Linear(3,4), loss_func=CrossEntropyLossFlat(), cbs=mixup)
# Drive the learner manually instead of calling `fit`.
# NOTE(review): `_do_begin_fit` presumably fires the `begin_fit` event so `MixUp.begin_fit` runs -- confirm.
learn._do_begin_fit(1)
# `training` must be True, otherwise `MixUp.begin_batch` returns without mixing.
learn.epoch,learn.training = 0,True
learn.dl = dbunch.train_dl
b = dbunch.one_batch()
# NOTE(review): `_split` presumably populates `learn.xb`/`learn.yb` from the batch -- confirm.
learn._split(b)
# Trigger the `begin_batch` event, which is where MixUp rewrites the batch.
learn('begin_batch')
# Show the (mixed) inputs and targets, read back through the callback, on a 3x3 grid.
_,axs = plt.subplots(3,3, figsize=(9,9))
dbunch.show_batch(b=(mixup.x,mixup.y), ctxs=axs.flatten())
#hide
# NOTE(review): presumably converts the project's notebooks into library modules, with
# `all_fs=True` selecting every notebook rather than just this one -- confirm.
from local.notebook.export import notebook2script
notebook2script(all_fs=True)
Converted 00_test.ipynb. Converted 01_core.ipynb. Converted 01a_utils.ipynb. Converted 01b_dispatch.ipynb. Converted 01c_transform.ipynb. Converted 02_script.ipynb. Converted 03_torch_core.ipynb. Converted 03a_layers.ipynb. Converted 04_dataloader.ipynb. Converted 05_data_core.ipynb. Converted 06_data_transforms.ipynb. Converted 07_data_block.ipynb. Converted 08_vision_core.ipynb. Converted 09_vision_augment.ipynb. Converted 09a_vision_data.ipynb. Converted 10_pets_tutorial.ipynb. Converted 11_vision_models_xresnet.ipynb. Converted 12_optimizer.ipynb. Converted 13_learner.ipynb. Converted 13a_metrics.ipynb. Converted 14_callback_schedule.ipynb. Converted 14a_callback_data.ipynb. Converted 15_callback_hook.ipynb. Converted 15a_vision_models_unet.ipynb. Converted 16_callback_progress.ipynb. Converted 17_callback_tracker.ipynb. Converted 18_callback_fp16.ipynb. Converted 19_callback_mixup.ipynb. Converted 20_interpret.ipynb. Converted 21_vision_learner.ipynb. Converted 22_tutorial_imagenette.ipynb. Converted 23_tutorial_transfer_learning.ipynb. Converted 30_text_core.ipynb. Converted 31_text_data.ipynb. Converted 32_text_models_awdlstm.ipynb. Converted 33_text_models_core.ipynb. Converted 34_callback_rnn.ipynb. Converted 35_tutorial_wikitext.ipynb. Converted 36_text_models_qrnn.ipynb. Converted 37_text_learner.ipynb. Converted 38_tutorial_ulmfit.ipynb. Converted 40_tabular_core.ipynb. Converted 41_tabular_model.ipynb. Converted 42_tabular_rapids.ipynb. Converted 50_data_block_examples.ipynb. Converted 60_medical_imaging.ipynb. Converted 65_medical_text.ipynb. Converted 90_notebook_core.ipynb. Converted 91_notebook_export.ipynb. Converted 92_notebook_showdoc.ipynb. Converted 93_notebook_export2html.ipynb. Converted 94_notebook_test.ipynb. Converted 95_index.ipynb. Converted 96_data_external.ipynb. Converted 97_utils_test.ipynb. Converted notebook2jekyll.ipynb.