#hide
#skip
! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab
#default_exp callback.data
Callbacks which work with a learner's data
#export
from fastai.basics import *
#hide
from nbdev.showdoc import *
from fastai.test_utils import *
#export
class CollectDataCallback(Callback):
    "Gather every batch (inputs, targets, `pred` and `loss`) into `self.data`; intended mainly for tests"
    def before_fit(self):
        # Each call to `fit` starts from an empty collection
        self.data = L()

    def after_batch(self):
        # Record the batch tuple after passing it through `learn.to_detach`
        batch_record = (self.xb, self.yb, self.pred, self.loss)
        self.data.append(self.learn.to_detach(batch_record))
# export
class CudaCallback(Callback):
    "Move the model and every batch to `device` (defaults to `default_device()`)"
    def __init__(self, device=None): self.device = ifnone(device, default_device())
    def before_batch(self):
        # Pass `self.device` explicitly so inputs/targets land on the same device as the model,
        # even when a non-default `device` was given to the callback
        self.learn.xb,self.learn.yb = to_device(self.xb, self.device),to_device(self.yb, self.device)
    def before_fit(self): self.model.to(self.device)
You don't normally need to use this callback, because fastai's `DataLoader` will handle passing data to a device for you. However, if you already have a plain PyTorch `DataLoader` and can't change it for some reason, you can use this callback instead.
#cuda
# Check that `CudaCallback` puts the model's parameters on a CUDA device during `fit`
learn = synth_learner(cbs=CudaCallback)
# Displays the model in the notebook output
learn.model
learn.fit(1)
# After fitting, the first parameter must live on a 'cuda' device
test_eq(next(learn.model.parameters()).device.type, 'cuda')
[0, 6.35821008682251, 4.982691287994385, '00:00']
#export
@delegates()
class WeightedDL(TfmdDL):
    "`TfmdDL` that, when shuffling, draws `n` indices with replacement, each with probability proportional to `wgts`"
    def __init__(self, dataset=None, bs=None, wgts=None, **kwargs):
        super().__init__(dataset=dataset, bs=bs, **kwargs)
        # Without explicit weights every item is equally likely
        raw = array(wgts if wgts is not None else [1.]*len(dataset))
        # Normalise so the weights form the probability vector `p` for `np.random.choice`
        self.wgts = raw/raw.sum()

    def get_idxs(self):
        if self.n == 0: return []
        if self.shuffle: return list(np.random.choice(self.n, self.n, p=self.wgts))
        # Not shuffling: fall back to the parent's index order
        return super().get_idxs()
#export
@patch
@delegates(Datasets.dataloaders)
def weighted_dataloaders(self:FilteredBase, wgts, bs=64, **kwargs):
    "Create dataloaders of type `WeightedDL` whose training set samples items according to `wgts`"
    # Only the first (training) subset receives `wgts`; every other subset gets no extra kwargs
    xtra_kwargs = [{}] * (self.n_subsets-1)
    return self.dataloaders(bs=bs, dl_type=WeightedDL, dl_kwargs=({'wgts':wgts}, *xtra_kwargs), **kwargs)
# Each item's value equals its index, and its weight is also its index (0..n-1),
# so larger values should be drawn more often by the training `WeightedDL`
n = 160
dsets = Datasets(torch.arange(n).float())
dls = dsets.weighted_dataloaders(wgts=range(n), bs=16)
learn = synth_learner(data=dls, cbs=CollectDataCallback)
learn.fit(1)
# Concatenate the inputs (element 0 of `xb`) of every collected batch and plot their histogram;
# it is expected to be skewed toward larger values
t = concat(*learn.collect_data.data.itemgot(0,0))
plt.hist(t.numpy());
[0, nan, None, '00:01']
#export
@delegates()
class PartialDL(TfmdDL):
    "Randomly select a subset of `partial_n` items from the data at each epoch"
    def __init__(self, dataset=None, bs=None, partial_n=None, **kwargs):
        super().__init__(dataset=dataset, bs=bs, **kwargs)
        # Cap at the dataset size because indices are drawn without replacement below;
        # a falsy `partial_n` (None or 0) disables partial sampling
        self.partial_n = min(partial_n, self.n) if partial_n else None

    def get_idxs(self):
        if self.partial_n is None: return super().get_idxs()
        # Each call draws a new random subset of distinct indices
        return list(np.random.choice(self.n, self.partial_n, replace=False))

    def __len__(self):
        if self.partial_n is None: return super().__len__()
        # Unbatched loader (`bs=None`): one item per index, so the length is just `partial_n`
        # (previously `partial_n//None` raised a TypeError here)
        if self.bs is None: return self.partial_n
        return self.partial_n//self.bs + (0 if self.drop_last or self.partial_n%self.bs==0 else 1)
#export
@patch
@delegates(Datasets.dataloaders)
def partial_dataloaders(self:FilteredBase, partial_n, bs=64, **kwargs):
    "Build dataloaders of type `PartialDL` whose training set samples only `partial_n` items per epoch"
    # The training subset gets `partial_n`; each remaining subset gets an empty kwargs dict
    per_subset = ({'partial_n':partial_n},) + tuple({} for _ in range(self.n_subsets-1))
    return self.dataloaders(bs=bs, dl_type=PartialDL, dl_kwargs=per_subset, **kwargs)
# With partial_n=32 and bs=16 the training loader must yield exactly 2 full batches per epoch
dls = dsets.partial_dataloaders(partial_n=32, bs=16)
assert len(dls[0])==2
for batch in dls[0]:
    assert len(batch[0])==16
#hide
from nbdev.export import notebook2script
# Export the `#export` cells of the project's notebooks into the library modules
notebook2script()
Converted 00_torch_core.ipynb. Converted 01_layers.ipynb. Converted 01a_losses.ipynb. Converted 02_data.load.ipynb. Converted 03_data.core.ipynb. Converted 04_data.external.ipynb. Converted 05_data.transforms.ipynb. Converted 06_data.block.ipynb. Converted 07_vision.core.ipynb. Converted 08_vision.data.ipynb. Converted 09_vision.augment.ipynb. Converted 09b_vision.utils.ipynb. Converted 09c_vision.widgets.ipynb. Converted 10_tutorial.pets.ipynb. Converted 10b_tutorial.albumentations.ipynb. Converted 11_vision.models.xresnet.ipynb. Converted 12_optimizer.ipynb. Converted 13_callback.core.ipynb. Converted 13a_learner.ipynb. Converted 13b_metrics.ipynb. Converted 14_callback.schedule.ipynb. Converted 14a_callback.data.ipynb. Converted 15_callback.hook.ipynb. Converted 15a_vision.models.unet.ipynb. Converted 16_callback.progress.ipynb. Converted 17_callback.tracker.ipynb. Converted 18_callback.fp16.ipynb. Converted 18a_callback.training.ipynb. Converted 18b_callback.preds.ipynb. Converted 19_callback.mixup.ipynb. Converted 20_interpret.ipynb. Converted 20a_distributed.ipynb. Converted 21_vision.learner.ipynb. Converted 22_tutorial.imagenette.ipynb. Converted 23_tutorial.vision.ipynb. Converted 24_tutorial.siamese.ipynb. Converted 24_vision.gan.ipynb. Converted 30_text.core.ipynb. Converted 31_text.data.ipynb. Converted 32_text.models.awdlstm.ipynb. Converted 33_text.models.core.ipynb. Converted 34_callback.rnn.ipynb. Converted 35_tutorial.wikitext.ipynb. Converted 36_text.models.qrnn.ipynb. Converted 37_text.learner.ipynb. Converted 38_tutorial.text.ipynb. Converted 39_tutorial.transformers.ipynb. Converted 40_tabular.core.ipynb. Converted 41_tabular.data.ipynb. Converted 42_tabular.model.ipynb. Converted 43_tabular.learner.ipynb. Converted 44_tutorial.tabular.ipynb. Converted 45_collab.ipynb. Converted 46_tutorial.collab.ipynb. Converted 50_tutorial.datablock.ipynb. Converted 60_medical.imaging.ipynb. Converted 61_tutorial.medical_imaging.ipynb. 
Converted 65_medical.text.ipynb. Converted 70_callback.wandb.ipynb. Converted 71_callback.tensorboard.ipynb. Converted 72_callback.neptune.ipynb. Converted 73_callback.captum.ipynb. Converted 74_callback.cutmix.ipynb. Converted 97_test_utils.ipynb. Converted 99_pytorch_doc.ipynb. Converted dev-setup.ipynb. Converted index.ipynb. Converted quick_start.ipynb. Converted tutorial.ipynb.