#|hide
#| eval: false
! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab
#|export
from __future__ import annotations
from fastai.torch_basics import *
from fastai.callback.hook import *
#|hide
from nbdev.showdoc import *
#|default_exp vision.models.unet
U-Net model using PixelShuffle ICNR upsampling that can be built on top of any pretrained architecture.
#|export
def _get_sz_change_idxs(sizes):
"Get the indexes of the layers where the size of the activation changes."
feature_szs = [size[-1] for size in sizes]
sz_chg_idxs = list(np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])
return sz_chg_idxs
#|hide
# Indexes are reported where the last dimension changes between consecutive shapes.
test_eq(_get_sz_change_idxs([[3,64,64], [16,64,64], [32,32,32], [16,32,32], [32,32,32], [16,16]]), [1,4])
test_eq(_get_sz_change_idxs([[3,64,64], [16,32,32], [32,32,32], [16,32,32], [32,16,16], [16,16]]), [0,3])
test_eq(_get_sz_change_idxs([[3,64,64]]), [])
test_eq(_get_sz_change_idxs([[3,64,64], [16,32,32]]), [0])
# Only the last dimension is compared: a change in earlier dimensions is not reported.
test_eq(_get_sz_change_idxs([[3,64,64], [16,32,64]]), [])
# Edge case: no shapes at all.
test_eq(_get_sz_change_idxs([]), [])
#|export
class UnetBlock(Module):
    "A quasi-UNet block, using `PixelShuffle_ICNR` upsampling."
    @delegates(ConvLayer.__init__)
    def __init__(self, up_in_c, x_in_c, hook, final_div=True, blur=False, act_cls=defaults.activation,
                 self_attention=False, init=nn.init.kaiming_normal_, norm_type=None, **kwargs):
        # `up_in_c`: channels of the incoming features passed to `forward`.
        # `x_in_c`: channels of the activation stored in `hook` (the skip connection).
        # `final_div`: keep all concatenated channels if True, otherwise halve them.
        self.hook = hook
        # Upsample the incoming features to `up_in_c//2` channels.
        self.shuf = PixelShuffle_ICNR(up_in_c, up_in_c//2, blur=blur, act_cls=act_cls, norm_type=norm_type)
        self.bn = BatchNorm(x_in_c)
        ni = up_in_c//2 + x_in_c  # channels after concatenating upsampled input and skip activation
        nf = ni if final_div else ni//2
        self.conv1 = ConvLayer(ni, nf, act_cls=act_cls, norm_type=norm_type, **kwargs)
        self.conv2 = ConvLayer(nf, nf, act_cls=act_cls, norm_type=norm_type,
                               xtra=SelfAttention(nf) if self_attention else None, **kwargs)
        self.relu = act_cls()
        apply_init(nn.Sequential(self.conv1, self.conv2), init)

    def forward(self, up_in):
        "Upsample `up_in`, concatenate it with the hooked activation, then apply two convolutions."
        s = self.hook.stored
        up_out = self.shuf(up_in)
        skip_sz = s.shape[-2:]
        # The upsampled map may not exactly match the skip activation's last two dims
        # (presumably when sizes don't divide evenly); resize it so `torch.cat` works.
        if up_out.shape[-2:] != skip_sz:
            up_out = F.interpolate(up_out, skip_sz, mode='nearest')
        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv2(self.conv1(cat_x))
#|export
class ResizeToOrig(Module):
    "Resize the input to the last two dimensions of `x.orig` with `F.interpolate` (no-op if they already match)."
    def __init__(self, mode='nearest'): self.mode = mode  # interpolation mode passed to `F.interpolate`
    def forward(self, x):
        # NOTE(review): `x.orig` must be set by the enclosing container (presumably fastai's
        # `SequentialEx`, which `DynamicUnet` subclasses) — confirm before using elsewhere.
        orig_sz = x.orig.shape[-2:]
        if orig_sz != x.shape[-2:]:
            x = F.interpolate(x, orig_sz, mode=self.mode)
        return x
#|export
class DynamicUnet(SequentialEx):
    "Create a U-Net from a given architecture."
    def __init__(self, encoder, n_out, img_size, blur=False, blur_final=True, self_attention=False,
                 y_range=None, last_cross=True, bottle=False, act_cls=defaults.activation,
                 init=nn.init.kaiming_normal_, norm_type=None, **kwargs):
        # Normalize to a tuple so the size comparison below also works when `img_size`
        # is a list: a list never compares equal to a tuple.
        imsize = tuple(img_size)
        sizes = model_sizes(encoder, size=imsize)
        # Encoder layers after which the activation size changes, deepest first; their
        # outputs are hooked and fed to the decoder blocks as skip connections.
        sz_chg_idxs = list(reversed(_get_sz_change_idxs(sizes)))
        self.sfs = hook_outputs([encoder[i] for i in sz_chg_idxs], detach=False)
        # A dummy forward pass (after the hooks exist) lets each decoder block be run once
        # while building, to read off channel counts.
        x = dummy_eval(encoder, imsize).detach()
        ni = sizes[-1][1]
        middle_conv = nn.Sequential(ConvLayer(ni, ni*2, act_cls=act_cls, norm_type=norm_type, **kwargs),
                                    ConvLayer(ni*2, ni, act_cls=act_cls, norm_type=norm_type, **kwargs)).eval()
        x = middle_conv(x)
        layers = [encoder, BatchNorm(ni), nn.ReLU(), middle_conv]
        for i,idx in enumerate(sz_chg_idxs):
            not_final = i!=len(sz_chg_idxs)-1
            up_in_c, x_in_c = int(x.shape[1]), int(sizes[idx][1])
            do_blur = blur and (not_final or blur_final)
            # Self-attention, when requested, goes only in the third-from-last decoder block.
            sa = self_attention and (i==len(sz_chg_idxs)-3)
            unet_block = UnetBlock(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=do_blur, self_attention=sa,
                                   act_cls=act_cls, init=init, norm_type=norm_type, **kwargs).eval()
            layers.append(unet_block)
            x = unet_block(x)
        ni = x.shape[1]
        # Extra upsampling when the first encoder activation is smaller than the input.
        if imsize != tuple(sizes[0][-2:]): layers.append(PixelShuffle_ICNR(ni, act_cls=act_cls, norm_type=norm_type))
        layers.append(ResizeToOrig())
        if last_cross:
            # Concatenate the original input back in before a final residual block.
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(ResBlock(1, ni, ni//2 if bottle else ni, act_cls=act_cls, norm_type=norm_type, **kwargs))
        layers += [ConvLayer(ni, n_out, ks=1, act_cls=None, norm_type=norm_type, **kwargs)]
        # Initializes `middle_conv` (layers[3]) and `layers[-2]`.
        # NOTE(review): with `last_cross=False`, `layers[-2]` is `ResizeToOrig`, which has no
        # weights, so only `middle_conv` is initialized here — confirm this is intended.
        apply_init(nn.Sequential(layers[3], layers[-2]), init)
        if y_range is not None: layers.append(SigmoidRange(*y_range))
        layers.append(ToTensorBase())
        super().__init__(*layers)

    def __del__(self):
        # Remove the encoder hooks, if construction got far enough to register them.
        if hasattr(self, "sfs"): self.sfs.remove()
from fastai.vision.models import resnet34
m = resnet34()
# Keep all children except the last two (for a ResNet: the average pool and the classifier).
m = nn.Sequential(*list(m.children())[:-2])
tst = DynamicUnet(m, 5, (128,128), norm_type=None)
x = cast(torch.randn(2, 3, 128, 128), TensorImage)
y = tst(x)
test_eq(y.shape, [2, 5, 128, 128])
# An input whose size differs from `img_size` is resized back to its own size by `ResizeToOrig`.
tst = DynamicUnet(m, 5, (128,128), norm_type=None)
x = torch.randn(2, 3, 127, 128)
y = tst(x)
test_eq(y.shape, [2, 5, 127, 128])
#|hide
# Export the `#|export` cells of this notebook to the library (explicit import instead of `*`).
from nbdev import nbdev_export
nbdev_export()
Converted 00_torch_core.ipynb. Converted 01_layers.ipynb. Converted 01a_losses.ipynb. Converted 02_data.load.ipynb. Converted 03_data.core.ipynb. Converted 04_data.external.ipynb. Converted 05_data.transforms.ipynb. Converted 06_data.block.ipynb. Converted 07_vision.core.ipynb. Converted 08_vision.data.ipynb. Converted 09_vision.augment.ipynb. Converted 09b_vision.utils.ipynb. Converted 09c_vision.widgets.ipynb. Converted 10_tutorial.pets.ipynb. Converted 10b_tutorial.albumentations.ipynb. Converted 11_vision.models.xresnet.ipynb. Converted 12_optimizer.ipynb. Converted 13_callback.core.ipynb. Converted 13a_learner.ipynb. Converted 13b_metrics.ipynb. Converted 14_callback.schedule.ipynb. Converted 14a_callback.data.ipynb. Converted 15_callback.hook.ipynb. Converted 15a_vision.models.unet.ipynb. Converted 16_callback.progress.ipynb. Converted 17_callback.tracker.ipynb. Converted 18_callback.fp16.ipynb. Converted 18a_callback.training.ipynb. Converted 18b_callback.preds.ipynb. Converted 19_callback.mixup.ipynb. Converted 20_interpret.ipynb. Converted 20a_distributed.ipynb. Converted 21_vision.learner.ipynb. Converted 22_tutorial.imagenette.ipynb. Converted 23_tutorial.vision.ipynb. Converted 24_tutorial.image_sequence.ipynb. Converted 24_tutorial.siamese.ipynb. Converted 24_vision.gan.ipynb. Converted 30_text.core.ipynb. Converted 31_text.data.ipynb. Converted 32_text.models.awdlstm.ipynb. Converted 33_text.models.core.ipynb. Converted 34_callback.rnn.ipynb. Converted 35_tutorial.wikitext.ipynb. Converted 36_text.models.qrnn.ipynb. Converted 37_text.learner.ipynb. Converted 38_tutorial.text.ipynb. Converted 39_tutorial.transformers.ipynb. Converted 40_tabular.core.ipynb. Converted 41_tabular.data.ipynb. Converted 42_tabular.model.ipynb. Converted 43_tabular.learner.ipynb. Converted 44_tutorial.tabular.ipynb. Converted 45_collab.ipynb. Converted 46_tutorial.collab.ipynb. Converted 50_tutorial.datablock.ipynb. Converted 60_medical.imaging.ipynb. 
Converted 61_tutorial.medical_imaging.ipynb. Converted 65_medical.text.ipynb. Converted 70_callback.wandb.ipynb. Converted 71_callback.tensorboard.ipynb. Converted 72_callback.neptune.ipynb. Converted 73_callback.captum.ipynb. Converted 74_callback.azureml.ipynb. Converted 97_test_utils.ipynb. Converted 99_pytorch_doc.ipynb. Converted dev-setup.ipynb. Converted index.ipynb. Converted quick_start.ipynb. Converted tutorial.ipynb.