# default_exp distributed
#export
from local.basics import *
from local.callback.progress import ProgressCallback
from torch.nn.parallel import DistributedDataParallel, DataParallel
from torch.utils.data.distributed import DistributedSampler
from local.test import *
Callbacks and helper functions to train in parallel or use distributed training
Patch DataParallel so it works with RNNs: the reset call is forwarded to the wrapped module.
#export
@patch
def reset(self: DataParallel):
if hasattr(self.module, 'reset'): self.module.reset()
#export
class ParallelTrainer(Callback):
run_after,run_before = TrainEvalCallback,Recorder
def __init__(self, device_ids): self.device_ids = device_ids
def begin_fit(self): self.learn.model = DataParallel(self.learn.model, device_ids=self.device_ids)
def after_fit(self): self.learn.model = self.learn.model.module
#export
@patch
def to_parallel(self: Learner, device_ids=None):
self.add_cb(ParallelTrainer(device_ids))
return self
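For example, assuming a Learner called learn has already been created on a machine with several GPUs, data-parallel training is a one-liner (a usage sketch, not executed here):
#Usage sketch (assumes an existing `learn` and at least two visible GPUs; not run here)
learn = learn.to_parallel(device_ids=[0,1])   #device_ids=None (the default) uses all visible GPUs
learn.fit(1)                                  #the model is wrapped in DataParallel during fit, then unwrapped by after_fit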
Patch DistributedDataParallel the same way so it also works with RNNs.
#export
@patch
def reset(self: DistributedDataParallel):
if hasattr(self.module, 'reset'): self.module.reset()
#export
def setup_distrib(gpu=None):
if gpu is None: return gpu
gpu = int(gpu)
torch.cuda.set_device(int(gpu))
if num_distrib() > 1:
torch.distributed.init_process_group(backend='nccl', init_method='env://')
return gpu
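In a per-process training script this is typically called once at the start; num_distrib and rank_distrib read the WORLD_SIZE and RANK environment variables set by the launcher (e.g. python -m torch.distributed.launch). A minimal sketch, where the gpu value is assumed to come from the script's arguments:
#Sketch (illustrative): `gpu` is passed by the launcher or the script's command line
gpu = setup_distrib(gpu)   #sets the CUDA device, inits the process group if WORLD_SIZE > 1, returns the gpu as an int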
We need to change the dataloaders so that each process only gets one part of the batch (otherwise there is no point in using distributed training).
#export
@delegates()
class DistributedDL(TfmdDL):
def __init__(self, dataset, rank, world_size, **kwargs):
super().__init__(dataset, **kwargs)
if self.n%world_size != 0: self.n += world_size-self.n%world_size
self.total_n,self.n = self.n,self.n//world_size
store_attr(self, 'rank,world_size')
def get_idxs(self):
idxs = Inf.count if self.indexed else Inf.nones
return idxs if self.n is None else list(itertools.islice(idxs, self.total_n))
def shuffle_fn(self, idxs):
"Deterministically shuffle on each training process based on epoch."
g = torch.Generator()
g.manual_seed(self.epoch)
return L(idxs)[torch.randperm(self.total_n, generator=g)]
def sample(self):
idxs = self.get_idxs()
if self.shuffle: idxs = self.shuffle_fn(idxs)
# add extra samples to make it evenly divisible
idxs += idxs[:(self.total_n - len(idxs))]
# subsample
idxs = idxs[self.rank:self.total_n:self.world_size]
return (b for i,b in enumerate(idxs) if i//(self.bs or 1)%self.nw==self.offs)
def create_item(self, s):
if s is not None and s >= len(self.dataset): s = s%len(self.dataset)
return super().create_item(s)
def set_epoch(self, epoch): self.epoch = epoch
@classmethod
def from_dl(cls, dl, rank, world_size, **kwargs):
cur_kwargs = dict(num_workers=dl.fake_l.num_workers, pin_memory=dl.pin_memory, timeout=dl.timeout,
bs=dl.bs, shuffle=dl.shuffle, drop_last=dl.drop_last, indexed=dl.indexed)
cur_kwargs.update({n: getattr(dl, n) for n in cls._methods if n not in "sample shuffle_fn create_item".split()})
return cls(dl.dataset, rank, world_size, **merge(cur_kwargs, kwargs))
dl = TfmdDL(list(range(50)), bs=16, num_workers=2)
for i in range(4):
dl1 = DistributedDL.from_dl(dl, i, 4)
test_eq(list(dl1)[0], torch.arange(i, 52, 4)%50)
dl = TfmdDL(list(range(50)), bs=16, num_workers=2, shuffle=True)
res = []
for i in range(4):
dl1 = DistributedDL.from_dl(dl, i, 4)
dl1.set_epoch(0)
res += list(dl1)[0].tolist()
#With a seeded shuffle, every item should be accessed exactly once, except 0 and 1 which are repeated to pad the final cycle
test_eq(sorted(res), [0,0,1,1] + list(range(2, 50)))
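The shuffle is seeded with the epoch, so all processes apply the same global permutation before taking their own slice, and two loaders built for the same rank and set to the same epoch yield identical batches. A small check reusing the shuffled dl defined above:
dl1,dl2 = DistributedDL.from_dl(dl, 0, 4),DistributedDL.from_dl(dl, 0, 4)
dl1.set_epoch(0); dl2.set_epoch(0)
test_eq(list(dl1)[0], list(dl2)[0])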
#export
class DistributedTrainer(Callback):
run_after,run_before = TrainEvalCallback,Recorder
def __init__(self, cuda_id=0): self.cuda_id = cuda_id
def begin_fit(self):
self.learn.model = DistributedDataParallel(self.model, device_ids=[self.cuda_id], output_device=self.cuda_id)
self.old_dls = [dl for dl in self.dbunch.dls]
self.learn.dbunch.dls = [DistributedDL.from_dl(dl, rank_distrib(), num_distrib()) for dl in self.dbunch.dls]
if rank_distrib() > 0: self.learn.logger=noop
def begin_epoch(self):
for dl in self.dbunch.dls: dl.set_epoch(self.epoch)
def after_fit(self):
self.learn.model = self.learn.model.module
self.learn.dbunch.dls = self.old_dls
#export
@patch
def to_distributed(self: Learner, cuda_id):
self.add_cb(DistributedTrainer(cuda_id))
if rank_distrib() > 0: self.remove_cb(self.progress)
return self
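Putting it together, a per-process training script could look like the sketch below: one copy is started per GPU by a launcher such as python -m torch.distributed.launch --nproc_per_node=&lt;n_gpus&gt;. The module name local.distributed follows the default_exp at the top of this notebook, and the ... lines stand in for whatever data and model construction the script needs.
#train_script.py -- illustrative sketch, not run here
from local.basics import *
from local.distributed import *

def main(gpu=None):
    gpu = setup_distrib(gpu)        #set the CUDA device and init the process group
    dbunch = ...                    #each process builds its own DataBunch; DistributedDL then splits the batches
    learn = Learner(dbunch, ...)    #build the Learner the same way on every process
    if num_distrib() > 1: learn.to_distributed(gpu)
    learn.fit(1)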
#hide
from local.notebook.export import notebook2script
notebook2script(all_fs=True)