#|hide
#| eval: false
! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab
#|default_exp data.block
#|export
from __future__ import annotations
from fastai.torch_basics import *
from fastai.data.core import *
from fastai.data.load import *
from fastai.data.external import *
from fastai.data.transforms import *
#|hide
from nbdev.showdoc import *
High-level API to quickly get your data into a `DataLoaders`
#|export
class TransformBlock():
    "A basic wrapper that links default transforms for the data block API"
    def __init__(self,
        type_tfms:list=None, # One or more `Transform`s
        item_tfms:list=None, # `ItemTransform`s, applied on an item
        batch_tfms:list=None, # `Transform`s or `RandTransform`s, applied by batch
        dl_type:TfmdDL=None, # Task specific `TfmdDL`, defaults to `TfmdDL`
        dls_kwargs:dict=None, # Additional arguments to be passed to `DataLoaders`
    ):
        # Type and batch transforms are stored as `L` collections.
        self.type_tfms, self.batch_tfms = L(type_tfms), L(batch_tfms)
        # `ToTensor` is always added to the item transforms, as the left operand of any user-supplied ones.
        self.item_tfms = ToTensor + L(item_tfms)
        self.dl_type = dl_type
        # Fall back to a fresh dict so every instance owns its own kwargs.
        self.dls_kwargs = dls_kwargs if dls_kwargs is not None else {}
#|export
def CategoryBlock(
    vocab:MutableSequence|pd.Series=None, # List of unique class names
    sort:bool=True, # Sort the classes alphabetically
    add_na:bool=False, # Add `#na#` to `vocab`
):
    "`TransformBlock` for single-label categorical targets"
    # The block's only type transform is a `Categorize` configured from the arguments.
    categorize = Categorize(vocab=vocab, sort=sort, add_na=add_na)
    return TransformBlock(type_tfms=categorize)
#|export
def MultiCategoryBlock(
    encoded:bool=False, # Whether the data comes in one-hot encoded
    vocab:MutableSequence|pd.Series=None, # List of unique class names
    add_na:bool=False, # Add `#na#` to `vocab`
):
    "`TransformBlock` for multi-label categorical targets"
    if encoded:
        # Already one-hot encoded data: `add_na` is not used on this path.
        type_tfms = EncodedMultiCategorize(vocab=vocab)
    else:
        # Raw labels: categorize first, then one-hot encode.
        type_tfms = [MultiCategorize(vocab=vocab, add_na=add_na), OneHotEncode]
    return TransformBlock(type_tfms=type_tfms)
#|export
def RegressionBlock(
    n_out:int=None, # Number of output values
):
    "`TransformBlock` for float targets"
    # `n_out` is forwarded to `RegressionSetup` as its `c` argument.
    setup = RegressionSetup(c=n_out)
    return TransformBlock(type_tfms=setup)
#|export
from inspect import isfunction,ismethod
#|export
def _merge_grouper(o):
if isinstance(o, LambdaType): return id(o)
elif isinstance(o, type): return o
elif (isfunction(o) or ismethod(o)): return o.__qualname__
return o.__class__
#|export
def _merge_tfms(*tfms):
    "Group the `tfms` in a single list, removing duplicates (from the same class) and instantiating"
    # Bucket every transform by the key `_merge_grouper` assigns to it.
    by_key = groupby(concat(*tfms), _merge_grouper)
    # For each key, only the last transform collected is kept.
    survivors = L(members[-1] for _, members in by_key.items())
    return survivors.map(instantiate)
def _zip(x):
    "Wrap `x` in an `L` and return the result of its `zip` method"
    return L(x).zip()
#For example, so not exported
from fastai.vision.core import *
from fastai.vision.data import *
#|hide
# Merge a list mixing two classes and one `Categorize` instance with a second `Categorize` instance.
tfms = _merge_tfms([Categorize, MultiCategorize, Categorize(['dog', 'cat'])], Categorize(['a', 'b']))
# If there are several instantiated versions of the same class, only the last one is kept.
test_eq(len(tfms), 2)
test_eq(tfms[1].__class__, MultiCategorize)
test_eq(tfms[0].__class__, Categorize)
test_eq(tfms[0].vocab, ['a', 'b'])
tfms = _merge_tfms([PILImage.create, PILImage.show])
# Check that two different methods are kept as separate entries.
test_eq(len(tfms), 2)
tfms = _merge_tfms([show_image, set_trace])
# Check that two different functions are kept as separate entries.
test_eq(len(tfms), 2)
_f = lambda x: 0
# Two distinct lambdas stay separate; the same lambda passed twice is merged into one entry.
test_eq(len(_merge_tfms([_f,lambda x: 1])), 2)
test_eq(len(_merge_tfms([_f,_f])), 1)
#|export
@docs
@funcs_kwargs
class DataBlock():
    "Generic container to quickly build `Datasets` and `DataLoaders`."
    # Class-level defaults: the four hooks are unset unless supplied.
    get_x=get_items=splitter=get_y = None
    # Default to two plain `TransformBlock`s (one input, one target) loaded with `TfmdDL`.
    blocks,dl_type = (TransformBlock,TransformBlock),TfmdDL
    # Names listed for `@funcs_kwargs`; presumably this is what lets them be passed as
    # keyword arguments to `DataBlock(...)` -- confirm against fastcore.
    _methods = 'get_items splitter get_y get_x'.split()
    # Hint appended to the `get_x`/`get_y` count-mismatch errors raised in `__init__`.
    _msg = "If you wanted to compose several transforms in your getter don't forget to wrap them in a `Pipeline`."
    def __init__(self,
        blocks:list=None, # One or more `TransformBlock`s
        dl_type:TfmdDL=None, # Task specific `TfmdDL`, defaults to `block`'s dl_type or`TfmdDL`
        getters:list=None, # Getter functions applied to results of `get_items`
        n_inp:int=None, # Number of inputs
        item_tfms:list=None, # `ItemTransform`s, applied on an item
        batch_tfms:list=None, # `Transform`s or `RandTransform`s, applied by batch
        **kwargs,
    ):
        blocks = L(self.blocks if blocks is None else blocks)
        # Blocks may be given as callables (called with no arguments) or as ready-made objects.
        blocks = L(b() if callable(b) else b for b in blocks)
        self.type_tfms = blocks.attrgot('type_tfms', L())
        # Item and batch transforms from all blocks are merged with duplicates removed.
        self.default_item_tfms = _merge_tfms(*blocks.attrgot('item_tfms', L()))
        self.default_batch_tfms = _merge_tfms(*blocks.attrgot('batch_tfms', L()))
        # dl_type precedence: class default < last block that defines one < explicit `dl_type` argument.
        for b in blocks:
            if getattr(b, 'dl_type', None) is not None: self.dl_type = b.dl_type
        if dl_type is not None: self.dl_type = dl_type
        # Re-wrap `dataloaders` on this instance against the chosen dl_type's `__init__`
        # (presumably so its signature exposes that loader's keyword arguments).
        self.dataloaders = delegates(self.dl_type.__init__)(self.dataloaders)
        self.dls_kwargs = merge(*blocks.attrgot('dls_kwargs', {}))
        # By default every block but the last is an input, with at least one input.
        self.n_inp = ifnone(n_inp, max(1, len(blocks)-1))
        # One getter per block, `noop` when none are given.
        # NOTE(review): a `getters` list supplied by the caller appears to be stored as-is, so the
        # slice assignments below would modify the caller's list -- confirm `ifnone` does not copy.
        self.getters = ifnone(getters, [noop]*len(self.type_tfms))
        if self.get_x:
            # `get_x` replaces the getters of the input slots and must supply exactly one per input.
            if len(L(self.get_x)) != self.n_inp:
                raise ValueError(f'get_x contains {len(L(self.get_x))} functions, but must contain {self.n_inp} (one for each input)\n{self._msg}')
            self.getters[:self.n_inp] = L(self.get_x)
        if self.get_y:
            # `get_y` replaces the getters of the remaining (target) slots, one per target.
            n_targs = len(self.getters) - self.n_inp
            if len(L(self.get_y)) != n_targs:
                raise ValueError(f'get_y contains {len(L(self.get_y))} functions, but must contain {n_targs} (one for each target)\n{self._msg}')
            self.getters[self.n_inp:] = L(self.get_y)
        # Any keyword argument still present at this point is rejected.
        if kwargs: raise TypeError(f'invalid keyword arguments: {", ".join(kwargs.keys())}')
        self.new(item_tfms, batch_tfms)
    # For each block, put its getter (or the functions of a `Pipeline` getter) in front of its type transforms.
    def _combine_type_tfms(self): return L([self.getters, self.type_tfms]).map_zip(
        lambda g,tt: (g.fs if isinstance(g, Pipeline) else L(g)) + tt)
    def new(self,
        item_tfms:list=None, # `ItemTransform`s, applied on an item
        batch_tfms:list=None, # `Transform`s or `RandTransform`s, applied by batch
    ):
        # NOTE(review): this sets the transforms on `self` and returns `self`; no copy is made here,
        # although the `_docs` entry describes it as creating a new `DataBlock`.
        self.item_tfms = _merge_tfms(self.default_item_tfms, item_tfms)
        self.batch_tfms = _merge_tfms(self.default_batch_tfms, batch_tfms)
        return self
    @classmethod
    def from_columns(cls,
        blocks:list =None, # One or more `TransformBlock`s
        getters:list =None, # Getter functions applied to results of `get_items`
        get_items:callable=None, # A function to get items
        **kwargs,
    ):
        # Default getters pick position `i` of each item: one per block, or two when `blocks` is not given.
        if getters is None: getters = L(ItemGetter(i) for i in range(2 if blocks is None else len(L(blocks))))
        # `_zip` is always applied last, after any user-supplied `get_items`.
        get_items = _zip if get_items is None else compose(get_items, _zip)
        return cls(blocks=blocks, getters=getters, get_items=get_items, **kwargs)
    def datasets(self,
        source, # The data source
        verbose:bool=False, # Show verbose messages
    ) -> Datasets:
        self.source = source                     ; pv(f"Collecting items from {source}", verbose)
        # Without `get_items` the source itself is used as the items.
        items = (self.get_items or noop)(source) ; pv(f"Found {len(items)} items", verbose)
        # Without `splitter` a default-constructed `RandomSplitter` is used.
        splits = (self.splitter or RandomSplitter())(items)
        pv(f"{len(splits)} datasets of sizes {','.join([str(len(s)) for s in splits])}", verbose)
        return Datasets(items, tfms=self._combine_type_tfms(), splits=splits, dl_type=self.dl_type, n_inp=self.n_inp, verbose=verbose)
    def dataloaders(self,
        source, # The data source
        path:str='.', # Data source and default `Learner` path
        verbose:bool=False, # Show verbose messages
        **kwargs
    ) -> DataLoaders:
        dsets = self.datasets(source, verbose=verbose)
        # Call-site kwargs override the blocks' `dls_kwargs`; `verbose` is always taken from the argument.
        kwargs = {**self.dls_kwargs, **kwargs, 'verbose': verbose}
        return dsets.dataloaders(path=path, after_item=self.item_tfms, after_batch=self.batch_tfms, **kwargs)
    # Docstrings for the public methods, keyed by method name (read by `@docs`, presumably).
    _docs = dict(new="Create a new `DataBlock` with other `item_tfms` and `batch_tfms`",
                 datasets="Create a `Datasets` object from `source`",
                 dataloaders="Create a `DataLoaders` object from `source`")
To build a `DataBlock` you need to give the library four things: the types of your input/labels, and at least two functions: `get_items` and `splitter`. You may also need to include `get_x` and `get_y` or a more generic list of `getters` that are applied to the results of `get_items`.
`splitter` is a callable which, when called with `items`, returns a tuple of iterables representing the indices of the training and validation data.
Once those are provided, you automatically get a `Datasets` or a `DataLoaders`:
# Render the documentation of `DataBlock.datasets`.
show_doc(DataBlock.datasets)
#| echo: false
# `dataloaders` is documented from an instance: `__init__` re-wraps it per instance with
# `delegates`, which is presumably what makes the loader's keyword arguments appear below.
dblock = DataBlock()
show_doc(dblock.dataloaders, name="DataBlock.dataloaders")
DataBlock.dataloaders
[source]
DataBlock.dataloaders
(source
,path
:str
='.'
,verbose
:bool
=False
,bs
=64
,shuffle
=False
,num_workers
=None
,do_setup
=True
,pin_memory
=False
,timeout
=0
,batch_size
=None
,drop_last
=False
,indexed
=None
,n
=None
,device
=None
,persistent_workers
=False
,wif
=None
,before_iter
=None
,after_item
=None
,before_batch
=None
,after_batch
=None
,after_iter
=None
,create_batches
=None
,create_item
=None
,create_batch
=None
,retain
=None
,get_idxs
=None
,sample
=None
,shuffle_fn
=None
,do_batch
=None
)
Create a DataLoaders
object from source
Type | Default | Details | |
---|---|---|---|
source |
The data source | ||
path |
str |
`` | Data source and default Learner path |
verbose |
bool |
False |
Show verbose messages |
Valid Keyword Arguments | |||
bs |
int |
64 |
Argument passed to TfmdDL.__init__ |
shuffle |
bool |
False |
Argument passed to TfmdDL.__init__ |
num_workers |
NoneType |
`` | Argument passed to TfmdDL.__init__ |
do_setup |
bool |
True |
Argument passed to TfmdDL.__init__ |
pin_memory |
bool |
False |
Argument passed to TfmdDL.__init__ |
timeout |
int |
0 |
Argument passed to TfmdDL.__init__ |
batch_size |
NoneType |
`` | Argument passed to TfmdDL.__init__ |
drop_last |
bool |
False |
Argument passed to TfmdDL.__init__ |
indexed |
NoneType |
`` | Argument passed to TfmdDL.__init__ |
n |
NoneType |
`` | Argument passed to TfmdDL.__init__ |
device |
NoneType |
`` | Argument passed to TfmdDL.__init__ |
persistent_workers |
bool |
False |
Argument passed to TfmdDL.__init__ |
wif |
NoneType |
`` | Argument passed to TfmdDL.__init__ |
before_iter |
NoneType |
`` | Argument passed to TfmdDL.__init__ |
after_item |
NoneType |
`` | Argument passed to TfmdDL.__init__ |
before_batch |
NoneType |
`` | Argument passed to TfmdDL.__init__ |
after_batch |
NoneType |
`` | Argument passed to TfmdDL.__init__ |
after_iter |
NoneType |
`` | Argument passed to TfmdDL.__init__ |
create_batches |
NoneType |
`` | Argument passed to TfmdDL.__init__ |
create_item |
NoneType |
`` | Argument passed to TfmdDL.__init__ |
create_batch |
NoneType |
`` | Argument passed to TfmdDL.__init__ |
retain |
NoneType |
`` | Argument passed to TfmdDL.__init__ |
get_idxs |
NoneType |
`` | Argument passed to TfmdDL.__init__ |
sample |
NoneType |
`` | Argument passed to TfmdDL.__init__ |
shuffle_fn |
NoneType |
`` | Argument passed to TfmdDL.__init__ |
do_batch |
NoneType |
`` | Argument passed to TfmdDL.__init__ |
Returns | DataLoaders |
You can create a `DataBlock` by passing functions:
# One black-and-white image block as input and one category block as target;
# the label of each item is produced by `parent_label`.
mnist = DataBlock(blocks = (ImageBlock(cls=PILImageBW),CategoryBlock),
                  get_items = get_image_files,
                  splitter = GrandparentSplitter(),
                  get_y = parent_label)
Each type comes with default transforms that will be applied at three levels: when building each element of a sample from a raw item, on each item, and on each batch.
They are called respectively type transforms, item transforms, batch transforms. In the case of MNIST, the type transforms are the method to create a `PILImageBW` (for the input) and the `Categorize` transform (for the target), the item transform is `ToTensor` and the batch transform is `IntToFloatTensor`. You can add any other transforms by passing `item_tfms` or `batch_tfms` to `DataBlock` or `DataBlock.new`.
# Type transforms are kept per block; item and batch transforms are merged across blocks.
test_eq(mnist.type_tfms[0], [PILImageBW.create])
test_eq(mnist.type_tfms[1].map(type), [Categorize])
test_eq(mnist.default_item_tfms.map(type), [ToTensor])
test_eq(mnist.default_batch_tfms.map(type), [IntToFloatTensor])
# Build the datasets on MNIST_TINY and inspect the first training sample.
dsets = mnist.datasets(untar_data(URLs.MNIST_TINY))
test_eq(dsets.vocab, ['3', '7'])
x,y = dsets.train[0]
test_eq(x.size,(28,28))
show_at(dsets.train, 0, cmap='Greys', figsize=(2,2));
# Keyword arguments that `DataBlock` does not know must make construction fail.
test_fail(lambda: DataBlock(wrong_kwarg=42, wrong_kwarg2='foo'))
We can pass any number of blocks to `DataBlock`; we can then define which are the input and target blocks by changing `n_inp`. For example, defining `n_inp=2` will consider the first two blocks passed as inputs and the others as targets.
# With three blocks and no explicit `n_inp`, the first two are inputs and the last is the target.
mnist = DataBlock((ImageBlock, ImageBlock, CategoryBlock), get_items=get_image_files, splitter=GrandparentSplitter(),
                  get_y=parent_label)
dsets = mnist.datasets(untar_data(URLs.MNIST_TINY))
test_eq(mnist.n_inp, 2)
test_eq(len(dsets.train[0]), 3)
# Two `get_y` functions for a single target must be rejected. The message now uses the same
# wording ("one for each target") as the `ValueError` raised by `DataBlock.__init__`.
test_fail(lambda: DataBlock((ImageBlock, ImageBlock, CategoryBlock), get_items=get_image_files, splitter=GrandparentSplitter(),
                            get_y=[parent_label, noop],
                            n_inp=2), msg='get_y contains 2 functions, but must contain 1 (one for each target)')
# With `n_inp=1` the last two blocks are targets, so two `get_y` entries are expected;
# a `Pipeline` counts as a single getter.
mnist = DataBlock((ImageBlock, ImageBlock, CategoryBlock), get_items=get_image_files, splitter=GrandparentSplitter(),
                  n_inp=1,
                  get_y=[noop, Pipeline([noop, parent_label])])
dsets = mnist.datasets(untar_data(URLs.MNIST_TINY))
test_eq(len(dsets.train[0]), 3)
#|export
def _short_repr(x):
if isinstance(x, tuple): return f'({", ".join([_short_repr(y) for y in x])})'
if isinstance(x, list): return f'[{", ".join([_short_repr(y) for y in x])}]'
if not isinstance(x, Tensor): return str(x)
if x.numel() <= 20 and x.ndim <=1: return str(x)
return f'{x.__class__.__name__} of size {"x".join([str(d) for d in x.shape])}'
#|hide
# A 2-d tensor is summarised by class name and size.
test_eq(_short_repr(TensorImage(torch.randn(40,56))), 'TensorImage of size 40x56')
# A short 1-d tensor is printed in full.
test_eq(_short_repr(TensorCategory([1,2,3])), 'TensorCategory([1, 2, 3])')
# Tuples are rendered element by element.
test_eq(_short_repr((TensorImage(torch.randn(40,56)), TensorImage(torch.randn(32,20)))),
        '(TensorImage of size 40x56, TensorImage of size 32x20)')
#|export
def _apply_pipeline(p, x):
    "Apply the functions in `p.fs` to `x` one at a time, printing each intermediate result, and return the final value"
    print(f" {p}\n starting from\n {_short_repr(x)}")
    for tfm in p.fs:
        name = tfm.name
        try:
            x = tfm(x)
            # `noop` steps are applied but not reported.
            if name == "noop": continue
            print(f" applying {name} gives\n {_short_repr(x)}")
        except Exception as err:
            # Report which step broke before letting the error propagate.
            print(f" applying {name} failed.")
            raise err
    return x
#|export
from fastai.data.load import _collate_types
def _find_fail_collate(s):
    "Return a message explaining why the samples in `s` cannot be collated, or `None` if no cause is found"
    # NOTE(review): with a single sample, `L(*s)` receives that sample as its only argument --
    # confirm this still yields a list of samples.
    s = L(*s)
    # First check: every member of the first sample must be of a collatable type.
    for x in s[0]:
        if not isinstance(x, _collate_types): return f"{type(x).__name__} is not collatable"
    # Second check: collate each position across all samples on its own to find the one that fails.
    for i in range_of(s[0]):
        # `Exception` rather than a bare `except`, so KeyboardInterrupt/SystemExit still propagate.
        try: _ = default_collate(s.itemgot(i))
        except Exception:
            shapes = [getattr(o[i], 'shape', None) for o in s]
            return f"Could not collate the {i}-th members of your tuples because got the following shapes\n{','.join([str(shape) for shape in shapes])}"
    # No cause identified; the caller prints a generic hint in this case.
    return None
#|export
@patch
def summary(self:DataBlock,
    source, # The data source
    bs:int=4, # The batch size
    show_batch:bool=False, # Call `show_batch` after the summary
    **kwargs, # Additional keyword arguments to `show_batch`
):
    "Steps through the transform pipeline for one batch, and optionally calls `show_batch(**kwargs)` on the transient `Dataloaders`."
    # A pipeline is only traced when it holds something other than `noop`.
    def _has_tfms(pipe): return any(f.name != 'noop' for f in pipe.fs)

    # Step 1: type transforms, traced on the first training item.
    print("Setting-up type transforms pipelines")
    dsets = self.datasets(source, verbose=True)
    print("\nBuilding one sample")
    for tl in dsets.train.tls: _apply_pipeline(tl.tfms, get_first(dsets.train.items))
    print(f"\nFinal sample: {dsets.train[0]}\n\n")

    # Step 2: item transforms; only the first sample is traced, the others are just transformed.
    dls = self.dataloaders(source, bs=bs, verbose=True)
    print("\nBuilding one batch")
    if not _has_tfms(dls.train.after_item):
        print("No item_tfms to apply")
        samples = [dls.train.after_item(dsets.train[i]) for i in range(bs)]
    else:
        print("Applying item_tfms to the first sample:")
        samples = [_apply_pipeline(dls.train.after_item, dsets.train[0])]
        print(f"\nAdding the next {bs-1} samples")
        samples.extend(dls.train.after_item(dsets.train[i]) for i in range(1, bs))

    # Step 3: transforms that run on the list of samples before collation.
    if _has_tfms(dls.train.before_batch):
        print("\nApplying before_batch to the list of samples")
        samples = _apply_pipeline(dls.train.before_batch, samples)
    else: print("\nNo before_batch transform to apply")

    # Step 4: collation; on failure, print the most likely cause and re-raise.
    print("\nCollating items in a batch")
    try:
        batch = dls.train.create_batch(samples)
        batch = retain_types(batch, samples[0] if is_listy(samples) else samples)
    except Exception as e:
        print("Error! It's not possible to collate your items in a batch")
        why = _find_fail_collate(samples)
        print("Make sure all parts of your samples are tensors of the same size" if why is None else why)
        raise e

    # Step 5: batch transforms, applied after moving the batch to the loaders' device.
    if _has_tfms(dls.train.after_batch):
        print("\nApplying batch_tfms to the batch built")
        batch = to_device(batch, dls.device)
        batch = _apply_pipeline(dls.train.after_batch, batch)
    else: print("\nNo batch_tfms to apply")

    if show_batch: dls.show_batch(**kwargs)
# Render the documentation of `DataBlock.summary`.
show_doc(DataBlock.summary)
DataBlock.summary
[source]
DataBlock.summary
(source
,bs
:int
=4
,show_batch
:bool
=False
, **kwargs
)
Steps through the transform pipeline for one batch, and optionally calls show_batch(**kwargs)
on the transient Dataloaders
.
Type | Default | Details | |
---|---|---|---|
source |
The data source | ||
bs |
int |
4 |
The batch size |
show_batch |
bool |
False |
Call show_batch after the summary |
kwargs |
No Content |
Besides stepping through the transformation, `summary()` provides a shortcut to `dls.show_batch(...)` to see the data. E.g.
pets.summary(path/"images", bs=8, show_batch=True, unique=True,...)
is a shortcut to:
pets.summary(path/"images", bs=8)
dls = pets.dataloaders(path/"images", bs=8)
dls.show_batch(unique=True,...) # See different tfms effect on the same image.
#|hide
from nbdev import nbdev_export
# Run the nbdev export step (presumably writes the `#|export` cells to the library modules).
nbdev_export()
Converted 00_torch_core.ipynb. Converted 01_layers.ipynb. Converted 01a_losses.ipynb. Converted 02_data.load.ipynb. Converted 03_data.core.ipynb. Converted 04_data.external.ipynb. Converted 05_data.transforms.ipynb. Converted 06_data.block.ipynb. Converted 07_vision.core.ipynb. Converted 08_vision.data.ipynb. Converted 09_vision.augment.ipynb. Converted 09b_vision.utils.ipynb. Converted 09c_vision.widgets.ipynb. Converted 10_tutorial.pets.ipynb. Converted 10b_tutorial.albumentations.ipynb. Converted 11_vision.models.xresnet.ipynb. Converted 12_optimizer.ipynb. Converted 13_callback.core.ipynb. Converted 13a_learner.ipynb. Converted 13b_metrics.ipynb. Converted 14_callback.schedule.ipynb. Converted 14a_callback.data.ipynb. Converted 15_callback.hook.ipynb. Converted 15a_vision.models.unet.ipynb. Converted 16_callback.progress.ipynb. Converted 17_callback.tracker.ipynb. Converted 18_callback.fp16.ipynb. Converted 18a_callback.training.ipynb. Converted 18b_callback.preds.ipynb. Converted 19_callback.mixup.ipynb. Converted 20_interpret.ipynb. Converted 20a_distributed.ipynb. Converted 21_vision.learner.ipynb. Converted 22_tutorial.imagenette.ipynb. Converted 23_tutorial.vision.ipynb. Converted 24_tutorial.image_sequence.ipynb. Converted 24_tutorial.siamese.ipynb. Converted 24_vision.gan.ipynb. Converted 30_text.core.ipynb. Converted 31_text.data.ipynb. Converted 32_text.models.awdlstm.ipynb. Converted 33_text.models.core.ipynb. Converted 34_callback.rnn.ipynb. Converted 35_tutorial.wikitext.ipynb. Converted 37_text.learner.ipynb. Converted 38_tutorial.text.ipynb. Converted 39_tutorial.transformers.ipynb. Converted 40_tabular.core.ipynb. Converted 41_tabular.data.ipynb. Converted 42_tabular.model.ipynb. Converted 43_tabular.learner.ipynb. Converted 44_tutorial.tabular.ipynb. Converted 45_collab.ipynb. Converted 46_tutorial.collab.ipynb. Converted 50_tutorial.datablock.ipynb. Converted 60_medical.imaging.ipynb. Converted 61_tutorial.medical_imaging.ipynb. 
Converted 65_medical.text.ipynb. Converted 70_callback.wandb.ipynb. Converted 71_callback.tensorboard.ipynb. Converted 72_callback.neptune.ipynb. Converted 73_callback.captum.ipynb. Converted 74_callback.azureml.ipynb. Converted 97_test_utils.ipynb. Converted 99_pytorch_doc.ipynb. Converted dev-setup.ipynb. Converted app_examples.ipynb. Converted camvid.ipynb. Converted migrating_catalyst.ipynb. Converted migrating_ignite.ipynb. Converted migrating_lightning.ipynb. Converted migrating_pytorch.ipynb. Converted migrating_pytorch_verbose.ipynb. Converted ulmfit.ipynb. Converted index.ipynb. Converted index_original.ipynb. Converted quick_start.ipynb. Converted tutorial.ipynb.