# default_exp layers
# default_cls_lvl 3
#export
from local.core.all import *
from local.torch_imports import *
from local.torch_core import *
from local.test import *
from torch.nn.utils import weight_norm, spectral_norm
from local.notebook.showdoc import *
Custom fastai layers and basic functions to grab them.
#export
class Identity(Module):
    "Pass-through layer: returns its input unchanged"
    def forward(self, x):
        return x
# export
class Lambda(Module):
    "Wrap a plain callable `func` so it can be used as a layer"
    def __init__(self, func):
        self.func = func

    def forward(self, x):
        return self.func(x)

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}({self.func})'
Warning: In the tests below, we use lambda functions for convenience, but you shouldn't do this when building real modules, as it would create models that won't pickle (so you won't be able to save/export them).
# `Lambda` must return exactly what the wrapped function returns.
tst = Lambda(lambda x:x+2)
x = torch.randn(10,20)
test_eq(tst(x), x+2)
# export
class PartialLambda(Lambda):
    "Layer applying `func` with the keyword arguments in `kwargs` pre-bound"
    def __init__(self, func, **kwargs):
        bound = partial(func, **kwargs)
        super().__init__(bound)
        # Text shown by `__repr__`: the original function name and the bound kwargs
        self.repr = f'{func.__name__}, {kwargs}'

    def forward(self, x):
        return self.func(x)

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}({self.repr})'
# Toy two-argument function; binding `b=5` through `PartialLambda` must give `x+5`.
def test_func(a,b=2): return a+b
tst = PartialLambda(test_func, b=5)
test_eq(tst(x), x+5)
# export
class View(Module):
    "Layer that views its input with the fixed shape `size`"
    def __init__(self, *size):
        self.size = size

    def forward(self, x):
        target_shape = self.size
        return x.view(target_shape)
# The (10,20) tensor `x` is viewed as (10,5,4).
tst = View(10,5,4)
test_eq(tst(x).shape, [10,5,4])
# export
class ResizeBatch(Module):
    "View `x` as `size` while leaving the first (batch) dimension untouched"
    def __init__(self, *size):
        self.size = size

    def forward(self, x):
        batch = x.size(0)
        return x.view((batch,) + self.size)
# Only the trailing dimensions are given; the batch size of 10 is kept.
tst = ResizeBatch(5,4)
test_eq(tst(x).shape, [10,5,4])
# export
class Flatten(Module):
    "Flatten `x` to (batch, -1), or to a rank-1 tensor when `full` is set"
    def __init__(self, full=False):
        self.full = full

    def forward(self, x):
        if self.full:
            return x.view(-1)
        return x.view(x.size(0), -1)
# Default keeps the batch dimension (10, 5*4); `full=True` collapses everything (10*5*4).
tst = Flatten()
x = torch.randn(10,5,4)
test_eq(tst(x).shape, [10,20])
tst = Flatten(full=True)
test_eq(tst(x).shape, [200])
# export
class Debugger(nn.Module):
    "Pass-through module that calls `set_trace()` on every forward, to inspect `x` inside a model."
    def forward(self, x):
        set_trace()  # execution pauses here before `x` is handed on unchanged
        return x
# export
def sigmoid_range(x, low, high):
    "Squash `x` with a sigmoid rescaled to lie in `(low, high)`"
    span = high - low
    return torch.sigmoid(x) * span + low
# Large negative / zero / large positive inputs should land on (approximately) low, midpoint and high.
test = tensor([-10.,0.,10.])
assert torch.allclose(sigmoid_range(test, -1, 2), tensor([-1.,0.5, 2.]), atol=1e-4, rtol=1e-4)
assert torch.allclose(sigmoid_range(test, -5, -1), tensor([-5.,-3.,-1.]), atol=1e-4, rtol=1e-4)
assert torch.allclose(sigmoid_range(test, 2, 4), tensor([2., 3., 4.]), atol=1e-4, rtol=1e-4)
# export
class SigmoidRange(Module):
    "Layer version of `sigmoid_range` with fixed bounds `(low, high)`"
    def __init__(self, low, high):
        self.low = low
        self.high = high

    def forward(self, x):
        return sigmoid_range(x, self.low, self.high)
# The module must agree with the functional form on the same probe tensor `test`.
tst = SigmoidRange(-1, 2)
assert torch.allclose(tst(test), tensor([-1.,0.5, 2.]), atol=1e-4, rtol=1e-4)
# export
class AdaptiveConcatPool2d(nn.Module):
    "Adaptive max pooling and adaptive average pooling of the same input, concatenated along dim 1"
    def __init__(self, size=None):
        super().__init__()
        # A falsy `size` (e.g. None) falls back to an output size of 1
        self.size = size if size else 1
        self.ap = nn.AdaptiveAvgPool2d(self.size)
        self.mp = nn.AdaptiveMaxPool2d(self.size)

    def forward(self, x):
        # Max-pooled features come first, average-pooled features second
        pooled = [self.mp(x), self.ap(x)]
        return torch.cat(pooled, 1)
If the input is `bs x nf x h x h`, the output will be `bs x 2*nf x 1 x 1` if no size is passed, or `bs x 2*nf x size x size` otherwise.
# First half of the channels must be the spatial max, second half the spatial mean.
tst = AdaptiveConcatPool2d()
x = torch.randn(10,5,4,4)
test_eq(tst(x).shape, [10,10,1,1])
max1 = torch.max(x, dim=2, keepdim=True)[0]
maxp = torch.max(max1, dim=3, keepdim=True)[0]
test_eq(tst(x)[:,:5], maxp)
test_eq(tst(x)[:,5:], x.mean(dim=[2,3], keepdim=True))
tst = AdaptiveConcatPool2d(2)
test_eq(tst(x).shape, [10,10,2,2])
# export
# Pooling kinds: each of 'Avg', 'Max', 'Cat' is handed to `mk_class` as a keyword mapped to its own string.
mk_class('PoolType', **{o:o for o in 'Avg Max Cat'.split()})
# export
_all_ = ['PoolType']
#export
def pool_layer(pool_type):
    "Return the pooling class for `pool_type`: 'Avg', 'Max', or anything else for `AdaptiveConcatPool2d`"
    if pool_type == 'Avg':
        return nn.AdaptiveAvgPool2d
    if pool_type == 'Max':
        return nn.AdaptiveMaxPool2d
    return AdaptiveConcatPool2d
# export
class PoolFlatten(nn.Sequential):
    "Pooling layer selected by `pool_type` (output size 1) followed by `Flatten`."
    def __init__(self, pool_type=PoolType.Avg):
        pool_cls = pool_layer(pool_type)
        super().__init__(pool_cls(1), Flatten())
# With the default average pooling the result is the per-channel spatial mean, flattened to (bs, nf).
tst = PoolFlatten()
test_eq(tst(x).shape, [10,5])
test_eq(tst(x), x.mean(dim=[2,3]))
# export
# Normalization kinds accepted as `norm_type` by the layer factories in this file.
NormType = Enum('NormType', 'Batch BatchZero Weight Spectral')
#export
def BatchNorm(nf, norm_type=NormType.Batch, ndim=2, **kwargs):
    """BatchNorm layer with `nf` features and `ndim` initialized depending on `norm_type`.

    `kwargs` are forwarded to the `nn.BatchNorm{ndim}d` constructor. When the layer is affine,
    its bias is set to 1e-3 and its weight to 0 for `NormType.BatchZero`, 1 otherwise.
    """
    assert 1 <= ndim <= 3
    bn = getattr(nn, f"BatchNorm{ndim}d")(nf, **kwargs)
    # With `affine=False` the layer has no weight/bias (both are None), so there is nothing to initialize
    if bn.affine:
        bn.bias.data.fill_(1e-3)
        bn.weight.data.fill_(0. if norm_type==NormType.BatchZero else 1.)
    return bn
`kwargs` are passed to `nn.BatchNorm` and can be `eps`, `momentum`, `affine` and `track_running_stats`.
# Weight is 1 by default and 0 for `BatchZero`; `ndim` selects the 1d/2d/3d class.
tst = BatchNorm(15)
assert isinstance(tst, nn.BatchNorm2d)
test_eq(tst.weight, torch.ones(15))
tst = BatchNorm(15, norm_type=NormType.BatchZero)
test_eq(tst.weight, torch.zeros(15))
tst = BatchNorm(15, ndim=1)
assert isinstance(tst, nn.BatchNorm1d)
tst = BatchNorm(15, ndim=3)
assert isinstance(tst, nn.BatchNorm3d)
# export
class BatchNorm1dFlat(nn.BatchNorm1d):
    "`nn.BatchNorm1d` that accepts extra leading dimensions by collapsing them before normalizing"
    def forward(self, x):
        if x.dim() != 2:
            # Collapse every dimension but the last, normalize, then restore the original shape
            *lead, n_feat = x.shape
            flat = x.contiguous().view(-1, n_feat)
            return super().forward(flat).view(*lead, n_feat)
        return super().forward(x)
# Statistics are taken over the two leading dims; check the running stats update and the normalized output.
tst = BatchNorm1dFlat(15)
x = torch.randn(32, 64, 15)
y = tst(x)
mean = x.mean(dim=[0,1])
test_close(tst.running_mean, 0*0.9 + mean*0.1)
var = (x-mean).pow(2).mean(dim=[0,1])
test_close(tst.running_var, 1*0.9 + var*0.1, eps=1e-4)
test_close(y, (x-mean)/torch.sqrt(var+1e-5) * tst.weight + tst.bias, eps=1e-4)
# export
class LinBnDrop(nn.Sequential):
    "Group an optional `BatchNorm1d`, an optional `Dropout` and a `Linear` layer (optionally followed by `act`)"
    def __init__(self, n_in, n_out, bn=True, p=0., act=None, lin_first=False):
        # Batchnorm sees `n_out` features when it follows the linear layer, `n_in` otherwise
        norm_drop = []
        if bn:
            norm_drop.append(BatchNorm(n_out if lin_first else n_in, ndim=1))
        if p != 0:
            norm_drop.append(nn.Dropout(p))
        # The linear layer only gets a bias when batchnorm is not used
        linear = [nn.Linear(n_in, n_out, bias=not bn)]
        if act is not None:
            linear.append(act)
        ordered = linear + norm_drop if lin_first else norm_drop + linear
        super().__init__(*ordered)
The `BatchNorm` layer is skipped if `bn=False`, as is the dropout if `p=0.`. Optionally, you can add an activation after the linear layer with `act`.
# Default: batchnorm then linear.
tst = LinBnDrop(10, 20)
mods = list(tst.children())
test_eq(len(mods), 2)
assert isinstance(mods[0], nn.BatchNorm1d)
assert isinstance(mods[1], nn.Linear)
# With dropout: batchnorm, dropout, linear.
tst = LinBnDrop(10, 20, p=0.1)
mods = list(tst.children())
test_eq(len(mods), 3)
assert isinstance(mods[0], nn.BatchNorm1d)
assert isinstance(mods[1], nn.Dropout)
assert isinstance(mods[2], nn.Linear)
# `lin_first=True`: linear and its activation come before batchnorm.
tst = LinBnDrop(10, 20, act=nn.ReLU(), lin_first=True)
mods = list(tst.children())
test_eq(len(mods), 3)
assert isinstance(mods[0], nn.Linear)
assert isinstance(mods[1], nn.ReLU)
assert isinstance(mods[2], nn.BatchNorm1d)
# Without batchnorm or dropout only the linear layer remains.
tst = LinBnDrop(10, 20, bn=False)
mods = list(tst.children())
test_eq(len(mods), 1)
assert isinstance(mods[0], nn.Linear)
#export
def init_default(m, func=nn.init.kaiming_normal_):
    "Apply `func` to `m.weight` (when both exist), zero `m.bias` if it is present, and return `m`."
    if func and hasattr(m, 'weight'):
        func(m.weight)
    bias = getattr(m, 'bias', None)
    if bias is not None:
        # In-place fill on a parameter must not be recorded by autograd
        with torch.no_grad():
            bias.fill_(0.)
    return m
#export
def _conv_func(ndim=2, transpose=False):
    "Return the `nn` convolution class for `ndim` dimensions, transposed when `transpose` is set."
    assert 1 <= ndim <=3
    kind = "Transpose" if transpose else ""
    return getattr(nn, f'Conv{kind}{ndim}d')
#hide
# Every (ndim, transpose) combination must map to the matching torch conv class.
test_eq(_conv_func(ndim=1),torch.nn.modules.conv.Conv1d)
test_eq(_conv_func(ndim=2),torch.nn.modules.conv.Conv2d)
test_eq(_conv_func(ndim=3),torch.nn.modules.conv.Conv3d)
test_eq(_conv_func(ndim=1, transpose=True),torch.nn.modules.conv.ConvTranspose1d)
test_eq(_conv_func(ndim=2, transpose=True),torch.nn.modules.conv.ConvTranspose2d)
test_eq(_conv_func(ndim=3, transpose=True),torch.nn.modules.conv.ConvTranspose3d)
# export
# Default activation class; it is the default `act_cls` of `ConvLayer`, `PixelShuffle_ICNR` and `ResBlock` below.
defaults.activation=nn.ReLU
# export
class ConvLayer(nn.Sequential):
    "Convolution from `ni` to `nf` channels, plus an optional activation (`act_cls`) and batchnorm (per `norm_type`)."
    def __init__(self, ni, nf, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=NormType.Batch, bn_1st=True,
                 act_cls=defaults.activation, transpose=False, init=nn.init.kaiming_normal_, xtra=None, **kwargs):
        if padding is None:
            padding = 0 if transpose else (ks-1)//2
        use_bn = norm_type in (NormType.Batch, NormType.BatchZero)
        # No conv bias by default when a batchnorm layer is added
        if bias is None:
            bias = not use_bn
        conv_cls = _conv_func(ndim, transpose=transpose)
        conv = conv_cls(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding, **kwargs)
        conv = init_default(conv, init)
        if norm_type==NormType.Weight:
            conv = weight_norm(conv)
        elif norm_type==NormType.Spectral:
            conv = spectral_norm(conv)
        # Activation and batchnorm are collected in that order, then flipped when `bn_1st`
        tail = []
        if act_cls is not None:
            tail.append(act_cls())
        if use_bn:
            tail.append(BatchNorm(nf, norm_type=norm_type, ndim=ndim))
        if bn_1st:
            tail.reverse()
        seq = [conv] + tail
        if xtra:
            seq.append(xtra)
        super().__init__(*seq)
The convolution uses `ks` (kernel size), `stride`, `padding` and `bias`. `padding` will default to the appropriate value (`(ks-1)//2` if it's not a transposed conv) and `bias` will default to `True` if the `norm_type` is `Spectral` or `Weight`, `False` if it's `Batch` or `BatchZero`. Note that if you don't want any normalization, you should pass `norm_type=None`.
This defines a conv layer with `ndim` (1, 2 or 3) that will be a ConvTranspose if `transpose=True`. `act_cls` is the class of the activation function to use (instantiated inside). Pass `act_cls=None` if you don't want an activation function. If you quickly want to change your default activation, you can change the value of `defaults.activation`.
`init` is used to initialize the weights (the biases are initialized to 0) and `xtra` is an optional layer to add at the end.
# Default layer: conv, then (with `bn_1st=True`) batchnorm at index 1, then activation.
tst = ConvLayer(16, 32)
mods = list(tst.children())
test_eq(len(mods), 3)
test_eq(mods[1].weight, torch.ones(32))
test_eq(mods[0].padding, (1,1))
x = torch.randn(64, 16, 8, 8)#.cuda()
# tst = tst.cuda()
#Padding is selected to make the shape the same if stride=1
test_eq(tst(x).shape, [64,32,8,8])
#Padding is selected to make the shape half if stride=2
tst = ConvLayer(16, 32, stride=2)
test_eq(tst(x).shape, [64,32,4,4])
#But you can always pass your own padding if you want
tst = ConvLayer(16, 32, padding=0)
test_eq(tst(x).shape, [64,32,6,6])
#No bias by default for Batch NormType
# (`mods` still holds the children of the first `ConvLayer(16, 32)` built above)
assert mods[0].bias is None
#But can be overriden with `bias=True`
tst = ConvLayer(16, 32, bias=True)
test_eq(list(tst.children())[0].bias, torch.zeros(32))
#For no norm, or spectral/weight, bias is True by default
for t in [None, NormType.Spectral, NormType.Weight]:
    tst = ConvLayer(16, 32, norm_type=t)
    test_eq(list(tst.children())[0].bias, torch.zeros(32))
#Various n_dim/tranpose
tst = ConvLayer(16, 32, ndim=3)
assert isinstance(list(tst.children())[0], nn.Conv3d)
tst = ConvLayer(16, 32, ndim=1, transpose=True)
assert isinstance(list(tst.children())[0], nn.ConvTranspose1d)
#No activation/leaky
tst = ConvLayer(16, 32, ndim=3, act_cls=None)
mods = list(tst.children())
test_eq(len(mods), 2)
tst = ConvLayer(16, 32, ndim=3, act_cls=partial(nn.LeakyReLU, negative_slope=0.1))
mods = list(tst.children())
test_eq(len(mods), 3)
assert isinstance(mods[2], nn.LeakyReLU)
nn.MaxPool2d
torch.nn.modules.pooling.MaxPool2d
#export
def AdaptiveAvgPool(sz=1, ndim=2):
    "Instantiate `nn.AdaptiveAvgPool{ndim}d` with output size `sz`"
    assert 1 <= ndim <= 3
    pool_cls = getattr(nn, f"AdaptiveAvgPool{ndim}d")
    return pool_cls(sz)
#export
def MaxPool(ks=2, stride=None, padding=0, ndim=2):
    "Instantiate `nn.MaxPool{ndim}d` with kernel size `ks`, `stride` and `padding`"
    assert 1 <= ndim <= 3
    pool_cls = getattr(nn, f"MaxPool{ndim}d")
    return pool_cls(ks, stride=stride, padding=padding)
#export
def AvgPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):
    "Instantiate `nn.AvgPool{ndim}d` with kernel size `ks`, `stride`, `padding` and `ceil_mode`"
    assert 1 <= ndim <= 3
    pool_cls = getattr(nn, f"AvgPool{ndim}d")
    return pool_cls(ks, stride=stride, padding=padding, ceil_mode=ceil_mode)
The following class is the base class to wrap a loss function. It provides several added functionalities:
- it flattens the tensors before computing the loss (with a potential transpose to put `axis` at the end)
- it has a potential `activation` method that tells the library if there is an activation fused in the loss (useful for inference and methods such as `Learner.get_preds` or `Learner.predict`)
- it has a potential `decodes` method that is used on predictions in inference (for instance, an argmax in classification)
F.binary_cross_entropy_with_logits(torch.randn(4,5), torch.randint(0, 2, (4,5)).float(), reduction='none')
tensor([[0.4297, 0.7398, 0.6388, 0.5635, 0.3743], [0.5245, 1.0219, 0.9097, 0.9432, 0.8671], [0.5316, 0.2025, 0.7467, 2.2852, 1.8450], [0.5773, 0.8334, 1.0466, 0.9615, 0.1678]])
# export
@funcs_kwargs
class BaseLoss():
    "Wrap an instance of `loss_cls`, moving `axis` last and flattening input and target before calling it."
    activation = noops
    decodes = noops
    _methods = "activation decodes".split()

    def __init__(self, loss_cls, *args, axis=-1, flatten=True, floatify=False, is_2d=True, **kwargs):
        store_attr(self, "axis,flatten,floatify,is_2d")
        # `args`/`kwargs` are used to instantiate the wrapped loss
        self.func = loss_cls(*args,**kwargs)
        functools.update_wrapper(self, self.func)

    def __repr__(self):
        return f"FlattenedLoss of {self.func}"

    @property
    def reduction(self):
        return self.func.reduction

    @reduction.setter
    def reduction(self, v):
        self.func.reduction = v

    def __call__(self, inp, targ, **kwargs):
        # Put `axis` last on both tensors
        inp = inp.transpose(self.axis,-1).contiguous()
        targ = targ.transpose(self.axis,-1).contiguous()
        # Target dtype fix-ups: optional cast to float (half targets are left alone), small ints to long
        if self.floatify and targ.dtype!=torch.float16:
            targ = targ.float()
        if targ.dtype in [torch.int8, torch.int16, torch.int32]:
            targ = targ.long()
        if self.flatten:
            # `is_2d` keeps the last dimension of the input; otherwise it is flattened completely
            inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
            targ = targ.view(-1)
        return self.func.__call__(inp, targ, **kwargs)
The `args` and `kwargs` will be passed to `loss_cls` during the initialization to instantiate a loss function. `axis` is put at the end for losses like softmax that are often performed on the last axis. If `floatify=True`, the targets will be converted to float (useful for losses that only accept float targets like `BCEWithLogitsLoss`), and `is_2d` determines if we flatten while keeping the last dimension of the input (the one `axis` was moved to) or completely flatten the input. We want the first for losses like Cross Entropy, and the second for pretty much anything else.
# export
@delegates(keep=True)
class CrossEntropyLossFlat(BaseLoss):
    "`nn.CrossEntropyLoss` applied after `BaseLoss` has flattened input and target."
    y_int = True

    def __init__(self, *args, axis=-1, **kwargs):
        super().__init__(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)

    def decodes(self, x):
        "Index of the largest prediction along `axis`"
        return x.argmax(dim=self.axis)

    def activation(self, x):
        "Softmax along `axis`"
        return F.softmax(x, dim=self.axis)
tst = CrossEntropyLossFlat()
output = torch.randn(32, 5, 10)
target = torch.randint(0, 10, (32,5))
#nn.CrossEntropy would fail with those two tensors, but not our flattened version.
_ = tst(output, target)
# NOTE(review): the lambda declares a parameter `x`; if `test_fail` calls it without arguments,
# the failure observed is the missing argument rather than the loss itself - confirm.
test_fail(lambda x: nn.CrossEntropyLoss()(output,target))
#Associated activation is softmax
test_eq(tst.activation(output), F.softmax(output, dim=-1))
#This loss function has a decodes which is argmax
test_eq(tst.decodes(output), output.argmax(dim=-1))
#In a segmentation task, we want to take the softmax over the channel dimension
tst = CrossEntropyLossFlat(axis=1)
output = torch.randn(32, 5, 128, 128)
target = torch.randint(0, 5, (32, 128, 128))
_ = tst(output, target)
test_eq(tst.activation(output), F.softmax(output, dim=1))
test_eq(tst.decodes(output), output.argmax(dim=1))
# export
@delegates(keep=True)
class BCEWithLogitsLossFlat(BaseLoss):
    "Same as `nn.BCEWithLogitsLoss`, but flattens input and target."
    def __init__(self, *args, axis=-1, floatify=True, thresh=0.5, **kwargs):
        # `is_2d=False`: the input is flattened completely; targets are cast to float by default
        super().__init__(nn.BCEWithLogitsLoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
        self.thresh = thresh

    def decodes(self, x):
        "Boolean mask of the entries of `x` above `thresh`"
        # NOTE(review): `x` is compared to `thresh` as given (no sigmoid is applied here) - confirm callers pass activated values
        return x>self.thresh

    def activation(self, x):
        "Sigmoid, the activation fused in the wrapped loss"
        return torch.sigmoid(x)
tst = BCEWithLogitsLossFlat()
output = torch.randn(32, 5, 10)
target = torch.randn(32, 5, 10)
#nn.BCEWithLogitsLoss would fail with those two tensors, but not our flattened version.
_ = tst(output, target)
# NOTE(review): the lambdas passed to `test_fail` below declare a parameter `x`; if `test_fail` calls them
# without arguments, the failure observed is the missing argument rather than the loss - confirm.
test_fail(lambda x: nn.BCEWithLogitsLoss()(output,target))
output = torch.randn(32, 5)
target = torch.randint(0,2,(32, 5))
#nn.BCEWithLogitsLoss would fail with int targets but not our flattened version.
_ = tst(output, target)
test_fail(lambda x: nn.BCEWithLogitsLoss()(output,target))
#Associated activation is sigmoid
test_eq(tst.activation(output), torch.sigmoid(output))
# export
def BCELossFlat(*args, axis=-1, floatify=True, **kwargs):
    "Build a `BaseLoss` around `nn.BCELoss` with `is_2d=False` (input and target fully flattened)."
    loss = BaseLoss(nn.BCELoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
    return loss
# Integer targets of the same shape as the (sigmoid-ed) predictions must be accepted.
tst = BCELossFlat()
output = torch.sigmoid(torch.randn(32, 5, 10))
target = torch.randint(0,2,(32, 5, 10))
_ = tst(output, target)
# NOTE(review): the lambda declares a parameter `x`; confirm `test_fail` passes one, otherwise this check is vacuous.
test_fail(lambda x: nn.BCELoss()(output,target))
# export
def MSELossFlat(*args, axis=-1, floatify=True, **kwargs):
    "Build a `BaseLoss` around `nn.MSELoss` with `is_2d=False` (input and target fully flattened)."
    loss = BaseLoss(nn.MSELoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
    return loss
# Integer targets of the same shape as the predictions must be accepted.
tst = MSELossFlat()
output = torch.sigmoid(torch.randn(32, 5, 10))
target = torch.randint(0,2,(32, 5, 10))
_ = tst(output, target)
# NOTE(review): the lambda declares a parameter `x`; confirm `test_fail` passes one, otherwise this check is vacuous.
test_fail(lambda x: nn.MSELoss()(output,target))
#hide
#cuda
#Test losses work in half precision
# Both tensors are moved with `.cuda()`, so this cell needs a CUDA device.
output = torch.sigmoid(torch.randn(32, 5, 10)).half().cuda()
target = torch.randint(0,2,(32, 5, 10)).half().cuda()
for tst in [BCELossFlat(), MSELossFlat()]: _ = tst(output, target)
#export
class LabelSmoothingCrossEntropy(Module):
    "Cross entropy blended, with weight `eps`, with the negative log-probabilities averaged over all classes"
    y_int = True

    def __init__(self, eps:float=0.1, reduction='mean'):
        self.eps = eps
        self.reduction = reduction

    def forward(self, output, target):
        n_classes = output.size()[-1]
        log_preds = F.log_softmax(output, dim=-1)
        # Negative sum of the log-probabilities, reduced the same way as the nll term below
        if self.reduction == 'sum':
            smooth = -log_preds.sum()
        else:
            smooth = -log_preds.sum(dim=-1)
            if self.reduction == 'mean':
                smooth = smooth.mean()
        nll = F.nll_loss(log_preds, target, reduction=self.reduction)
        return smooth*self.eps/n_classes + (1-self.eps) * nll

    def activation(self, out):
        "Softmax over the last dimension"
        return F.softmax(out, dim=-1)

    def decodes(self, out):
        "Index of the largest prediction over the last dimension"
        return out.argmax(dim=-1)
On top of the formula we define:
- a `reduction` attribute, that will be used when we call `Learner.get_preds`
- an `activation` function that represents the activation fused in the loss (since we use cross entropy behind the scenes). It will be applied to the output of the model when calling `Learner.get_preds` or `Learner.predict`
- a `decodes` function that converts the output of the model to a format similar to the target (here indices). This is used in `Learner.predict` and `Learner.show_results` to decode the predictions
# export
def trunc_normal_(x, mean=0., std=1.):
    "Fill `x` in place with an approximate truncated normal of the given `mean` and `std`, and return it"
    # From https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/12
    x.normal_()
    x.fmod_(2)     # fold the standard normal samples into (-2, 2)
    x.mul_(std)
    x.add_(mean)
    return x
# export
class Embedding(nn.Embedding):
    "`nn.Embedding` of `ni` rows and `nf` columns whose weights are re-initialized with `trunc_normal_` (std 0.01)"
    def __init__(self, ni, nf):
        super().__init__(ni, nf)
        weights = self.weight.data
        trunc_normal_(weights, std=0.01)
Truncated normal initialization bounds the distribution to avoid large values. For a given standard deviation `std`, the bounds are roughly `-2*std`, `2*std`.
# With std=0.01 every weight must stay inside (-0.02, 0.02), with mean near 0 and std near 0.01.
tst = Embedding(10, 30)
assert tst.weight.min() > -0.02
assert tst.weight.max() < 0.02
test_close(tst.weight.mean(), 0, 1e-2)
test_close(tst.weight.std(), 0.01, 0.1)
# export
class SelfAttention(nn.Module):
    "Self attention layer for `n_channels`."
    def __init__(self, n_channels):
        super().__init__()
        reduced = n_channels//8
        # query/key project to `n_channels//8` channels, value keeps `n_channels`
        self.query = self._conv(n_channels, reduced)
        self.key = self._conv(n_channels, reduced)
        self.value = self._conv(n_channels, n_channels)
        # Starts at 0, so the layer initially returns its input unchanged
        self.gamma = nn.Parameter(tensor([0.]))

    def _conv(self, n_in, n_out):
        "1x1 1d `ConvLayer` with spectral norm, no activation and no bias"
        return ConvLayer(n_in, n_out, ks=1, ndim=1, norm_type=NormType.Spectral, act_cls=None, bias=False)

    def forward(self, x):
        # Names f, g, h, beta, o follow the notation from the paper.
        shape = x.size()
        flat = x.view(*shape[:2], -1)   # keep the first two dims, flatten the rest
        f, g, h = self.query(flat), self.key(flat), self.value(flat)
        scores = torch.bmm(f.transpose(1,2), g)
        beta = F.softmax(scores, dim=1)
        o = self.gamma * torch.bmm(h, beta) + flat
        return o.view(*shape).contiguous()
Self-attention layer as introduced in Self-Attention Generative Adversarial Networks.
Initially, no change is done to the input. This is controlled by a trainable parameter named `gamma`, as we return `x + gamma * out`.
# `gamma` starts at 0, so a fresh layer must return its input unchanged.
tst = SelfAttention(16)
x = torch.randn(32, 16, 8, 8)
test_eq(tst(x),x)
Then during training `gamma` will probably change since it's a trainable parameter. Let's see what happens when it gets a nonzero value.
# Force a nonzero `gamma`; the output keeps the input shape.
tst.gamma.data.fill_(1.)
y = tst(x)
test_eq(y.shape, [32,16,8,8])
The attention mechanism requires three matrix multiplications (here represented by 1x1 convs). The multiplications are done on the channel level (the second dimension in our tensor) and we flatten the feature map (which is 8x8 here). As in the paper, we note `f`, `g` and `h` the results of those multiplications.
# Recompute f, g, h by hand from the conv weights, as plain matrix products over the channel dimension.
q,k,v = tst.query[0].weight.data,tst.key[0].weight.data,tst.value[0].weight.data
test_eq([q.shape, k.shape, v.shape], [[2, 16, 1], [2, 16, 1], [16, 16, 1]])
f,g,h = map(lambda m: x.view(32, 16, 64).transpose(1,2) @ m.squeeze().t(), [q,k,v])
test_eq([f.shape, g.shape, h.shape], [[32,64,2], [32,64,2], [32,64,16]])
The key part of the attention layer is to compute attention weights for each location in the feature map (here 8x8 = 64). Those are positive numbers that sum to 1 and tell the model to pay attention to this or that part of the picture. We take the product of `f` and the transpose of `g` (to get something of size bs by 64 by 64), then apply a softmax on the first dimension (to get the positive numbers that sum up to 1). The result can then be multiplied with `h` transposed to get an output of size bs by channels by 64, which can then be viewed as an output the same size as the original input.
The final result is then `x + gamma * out`, as we saw before.
# Hand-computed attention must match the layer output `y` obtained with gamma=1.
beta = F.softmax(torch.bmm(f, g.transpose(1,2)), dim=1)
test_eq(beta.shape, [32, 64, 64])
out = torch.bmm(h.transpose(1,2), beta)
test_eq(out.shape, [32, 16, 64])
test_close(y, x + out.view(32, 16, 8, 8), eps=1e-4)
# export
class PooledSelfAttention2d(nn.Module):
    "Pooled self attention layer for 2d."
    def __init__(self, n_channels):
        super().__init__()
        self.n_channels = n_channels
        # query/key project to `n_channels//8`, value to `n_channels//2`; `out` maps back to `n_channels`
        self.query = self._conv(n_channels, n_channels//8)
        self.key = self._conv(n_channels, n_channels//8)
        self.value = self._conv(n_channels, n_channels//2)
        self.out = self._conv(n_channels//2, n_channels)
        # Starts at 0, so the layer initially returns its input unchanged
        self.gamma = nn.Parameter(tensor([0.]))

    def _conv(self, n_in, n_out):
        "1x1 `ConvLayer` with spectral norm, no activation and no bias"
        return ConvLayer(n_in, n_out, ks=1, norm_type=NormType.Spectral, act_cls=None, bias=False)

    def forward(self, x):
        height, width = x.shape[2], x.shape[3]
        n_ftrs = height*width
        f = self.query(x).view(-1, self.n_channels//8, n_ftrs)
        # key and value are 2x2 max-pooled, leaving a quarter of the positions
        g = F.max_pool2d(self.key(x), [2,2]).view(-1, self.n_channels//8, n_ftrs//4)
        h = F.max_pool2d(self.value(x), [2,2]).view(-1, self.n_channels//2, n_ftrs//4)
        scores = torch.bmm(f.transpose(1, 2), g)
        beta = F.softmax(scores, -1)
        attended = torch.bmm(h, beta.transpose(1,2)).view(-1, self.n_channels//2, height, width)
        o = self.out(attended)
        return self.gamma * o + x
Self-attention layer used in the Big GAN paper.
It uses the same attention as in `SelfAttention` but adds a max pooling of stride 2 before computing the matrices `g` and `h`: the attention is ported on one of the 2x2 max-pooled windows, not the whole feature map. There is also a final matrix product added at the end to the output, before returning `gamma * out + x`.
#export
def _conv1d_spect(ni:int, no:int, ks:int=1, stride:int=1, padding:int=0, bias:bool=False):
    "Build a kaiming-initialized `nn.Conv1d` (bias zeroed, if any) and wrap it in spectral normalization."
    layer = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
    nn.init.kaiming_normal_(layer.weight)
    if bias:
        layer.bias.data.zero_()
    return spectral_norm(layer)
#export
class SimpleSelfAttention(Module):
    "Self-attention variant built from a single spectrally-normalized 1d convolution over `n_in` channels."
    def __init__(self, n_in:int, ks=1, sym=False):
        self.sym = sym
        self.n_in = n_in
        self.conv = _conv1d_spect(n_in, n_in, ks, padding=ks//2, bias=False)
        # Starts at 0, so the layer initially returns its input unchanged
        self.gamma = nn.Parameter(tensor([0.]))

    def forward(self, x):
        if self.sym:
            # Replace the conv weight by its symmetric part (w + w^T)/2; the views below need a kernel size of 1
            # NOTE(review): the conv is wrapped by spectral_norm - confirm this assignment is still in effect
            # when `self.conv` is called below.
            w = self.conv.weight.view(self.n_in, self.n_in)
            w = (w + w.t())/2
            self.conv.weight = w.view(self.n_in, self.n_in, 1)
        shape = x.size()
        flat = x.view(*shape[:2], -1)   # keep the first two dims, flatten the rest
        convx = self.conv(flat)
        gram = torch.bmm(flat, flat.permute(0,2,1).contiguous())
        o = torch.bmm(gram, convx)
        o = self.gamma * o + flat
        return o.view(*shape).contiguous()
PixelShuffle was introduced in this article to avoid checkerboard artifacts when upsampling images. If we want an output with `ch_out` filters, we use a convolution with `ch_out * (r**2)` filters, where `r` is the upsampling factor. Then we reorganize those filters like in the picture below:
# export
def icnr_init(x, scale=2, init=nn.init.kaiming_normal_):
    "Return ICNR-initialized weights shaped like `x`: `init` fills 1/scale**2 of the first dim, repeated `scale**2` times"
    ni, nf, h, w = x.shape
    n_base = int(ni/(scale**2))
    # Initialize the reduced kernel, then tile it so that groups of `scale**2` consecutive filters are identical
    base = init(x.new_zeros([n_base, nf, h, w]))
    k = base.transpose(0, 1).contiguous().view(n_base, nf, -1)
    k = k.repeat(1, 1, scale**2)
    return k.contiguous().view([nf, ni, h, w]).transpose(0, 1)
ICNR init was introduced in this article. It suggests initializing the convolution that will be used in PixelShuffle so that each of the `r**2` channels gets the same weight (so that in the picture above, the 9 colors in a 3 by 3 window are initially the same).
Note: This is done on the first dimension because PyTorch stores the weights of a convolutional layer in this format: `ch_out x ch_in x ks x ks`.
# With scale=2, every group of 4 consecutive filters must share the same weights.
tst = torch.randn(16*4, 32, 1, 1)
tst = icnr_init(tst)
for i in range(0,16*4,4):
    test_eq(tst[i],tst[i+1])
    test_eq(tst[i],tst[i+2])
    test_eq(tst[i],tst[i+3])
# export
class PixelShuffle_ICNR(nn.Sequential):
    """Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`.

    The 1x1 convolution feeding the pixel shuffle is initialized with `icnr_init`.
    `blur` appends a replication pad and a stride-1 2x2 average pool.
    """
    def __init__(self, ni, nf=None, scale=2, blur=False, norm_type=NormType.Weight, act_cls=defaults.activation):
        nf = ifnone(nf, ni)
        layers = [ConvLayer(ni, nf*(scale**2), ks=1, norm_type=norm_type, act_cls=act_cls),
                  nn.PixelShuffle(scale)]
        conv = layers[0][0]
        if norm_type == NormType.Weight:
            # weight_norm recomputes `weight` from `weight_g`/`weight_v` before each forward, so writing to
            # `weight` would be lost: put the ICNR init in `weight_v` and set `weight_g` to its per-filter norm
            conv.weight_v.data.copy_(icnr_init(conv.weight_v.data))
            conv.weight_g.data.copy_(((conv.weight_v.data**2).sum(dim=[1,2,3])**0.5)[:,None,None,None])
        else:
            conv.weight.data.copy_(icnr_init(conv.weight.data))
        if blur: layers += [nn.ReplicationPad2d((1,0,1,0)), nn.AvgPool2d(2, stride=1)]
        # Single base-class initialization, once the full layer list is known
        super().__init__(*layers)
The convolutional layer is initialized with `icnr_init` and passed `act_cls` and `norm_type` (the default of weight normalization seemed to be what's best for super-resolution problems, in our experiments).
The `blur` option comes from Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts, where the authors add a little bit of blur to completely get rid of checkerboard artifacts.
psfl = PixelShuffle_ICNR(16, norm_type=None) #Deactivate weight norm as it changes the weight
x = torch.randn(64, 16, 8, 8)
y = psfl(x)
# Upsampling by 2 doubles both spatial dimensions.
test_eq(y.shape, [64, 16, 16, 16])
#ICNR init makes every 2x2 window (stride 2) have the same elements
for i in range(0,16,2):
    for j in range(0,16,2):
        test_eq(y[:,:,i,j],y[:,:,i+1,j])
        test_eq(y[:,:,i,j],y[:,:,i ,j+1])
        test_eq(y[:,:,i,j],y[:,:,i+1,j+1])
# export
class SequentialEx(Module):
    "Sequential container backed by a `nn.ModuleList`, whose layers can read the block input as `.orig`"
    def __init__(self, *layers):
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        out = x
        for layer in self.layers:
            # Expose the original input on the tensor handed to the layer
            out.orig = x
            nxt = layer(out)
            # We have to remove the reference to avoid hanging refs and therefore memory leaks
            out.orig = None
            out = nxt
        return out

    def __getitem__(self, i):
        return self.layers[i]

    def append(self, l):
        return self.layers.append(l)

    def extend(self, l):
        return self.layers.extend(l)

    def insert(self, i, l):
        return self.layers.insert(i, l)
This is useful to write layers that require to remember the input (like a resnet block) in a sequential way.
# export
class MergeLayer(Module):
    "Combine `x` with `x.orig`: addition by default, concatenation along dim 1 when `dense=True`."
    def __init__(self, dense:bool=False):
        self.dense = dense

    def forward(self, x):
        if self.dense:
            return torch.cat([x, x.orig], dim=1)
        return x + x.orig
# Two conv layers plus a `MergeLayer` form a residual block: output = input + convs(input).
res_block = SequentialEx(ConvLayer(16, 16), ConvLayer(16,16))
res_block.append(MergeLayer()) # just to test append - normally it would be in init params
x = torch.randn(32, 16, 8, 8)
y = res_block(x)
test_eq(y.shape, [32, 16, 8, 8])
test_eq(y, x + res_block[1](res_block[0](x)))
Equivalent to keras.layers.Concatenate, it will concat the outputs of a ModuleList over a given dimension (default the filter dimension)
#export
class Cat(nn.ModuleList):
    "Apply every layer to the same input and concatenate the results along `dim`"
    def __init__(self, layers, dim=1):
        super().__init__(layers)
        self.dim = dim

    def forward(self, x):
        outputs = [layer(x) for layer in self]
        return torch.cat(outputs, dim=self.dim)
# Three 4-filter conv layers concatenated on the filter dimension give 12 filters.
layers = [ConvLayer(2,4), ConvLayer(2,4), ConvLayer(2,4)]
x = torch.rand(1,2,8,8)
cat = Cat(layers)
test_eq(cat(x).shape, [1,12,8,8])
test_eq(cat(x), torch.cat([l(x) for l in layers], dim=1))
# export
class SimpleCNN(nn.Sequential):
    "Stack of `ConvLayer`s going through the channel counts in `filters`, ended by a `PoolFlatten`."
    def __init__(self, filters, kernel_szs=None, strides=None, bn=True):
        n_layers = len(filters)-1
        kernel_szs = ifnone(kernel_szs, [3]*n_layers)
        strides = ifnone(strides, [2]*n_layers)
        convs = []
        for i in range(n_layers):
            # Batchnorm on every conv but the last one, and only when `bn` is set
            norm = NormType.Batch if bn and i<n_layers-1 else None
            convs.append(ConvLayer(filters[i], filters[i+1], kernel_szs[i], stride=strides[i], norm_type=norm))
        convs.append(PoolFlatten())
        super().__init__(*convs)
The model is a succession of convolutional layers from `(filters[0],filters[1])` to `(filters[n-2],filters[n-1])` (if `n` is the length of the `filters` list) followed by a `PoolFlatten`. `kernel_szs` and `strides` default to a list of 3s and a list of 2s. If `bn=True`, the convolutional layers (except the last one) include a batchnorm layer on top of the convolution and ReLU; otherwise they are conv-relu.
# Two conv layers (8->16, 16->32) plus the final `PoolFlatten`.
tst = SimpleCNN([8,16,32])
mods = list(tst.children())
test_eq(len(mods), 3)
test_eq([[m[0].in_channels, m[0].out_channels] for m in mods[:2]], [[8,16], [16,32]])
Test kernel sizes
# Custom kernel sizes are applied layer by layer.
tst = SimpleCNN([8,16,32], kernel_szs=[1,3])
mods = list(tst.children())
test_eq([m[0].kernel_size for m in mods[:2]], [(1,1), (3,3)])
Test strides
# Custom strides are applied layer by layer.
tst = SimpleCNN([8,16,32], strides=[1,2])
mods = list(tst.children())
test_eq([m[0].stride for m in mods[:2]], [(1,1),(2,2)])
# export
class ResBlock(nn.Module):
    "Residual block: two convs when `expansion==1`, a three-conv bottleneck otherwise; channels are `ni*expansion` in, `nh*expansion` out"
    @delegates(ConvLayer.__init__)
    def __init__(self, expansion, ni, nh, stride=1, sa=False, sym=False,
                 norm_type=NormType.Batch, act_cls=defaults.activation, ndim=2, **kwargs):
        super().__init__()
        # The last conv of the main path uses a zero-initialized batchnorm when plain batchnorm is requested
        norm2 = NormType.BatchZero if norm_type==NormType.Batch else norm_type
        nf = nh*expansion
        ni = ni*expansion
        if expansion == 1:
            layers = [
                ConvLayer(ni, nh, 3, stride=stride, norm_type=norm_type, act_cls=act_cls, ndim=ndim, **kwargs),
                ConvLayer(nh, nf, 3, norm_type=norm2, act_cls=None, ndim=ndim, **kwargs),
            ]
        else:
            layers = [
                ConvLayer(ni, nh, 1, norm_type=norm_type, act_cls=act_cls, ndim=ndim, **kwargs),
                ConvLayer(nh, nh, 3, stride=stride, norm_type=norm_type, act_cls=act_cls, ndim=ndim, **kwargs),
                ConvLayer(nh, nf, 1, norm_type=norm2, act_cls=None, ndim=ndim, **kwargs),
            ]
        self.convs = nn.Sequential(*layers)
        self.sa = SimpleSelfAttention(nf,ks=1,sym=sym) if sa else noop
        # Shortcut path: 1x1 conv only when the channel count changes, average pooling only when strided
        self.idconv = noop if ni==nf else ConvLayer(ni, nf, 1, act_cls=None, ndim=ndim, **kwargs)
        self.pool = noop if stride==1 else AvgPool(2, ndim=ndim, ceil_mode=True)
        if act_cls is defaults.activation:
            self.act = defaults.activation(inplace=True)
        else:
            self.act = act_cls()

    def forward(self, x):
        main = self.sa(self.convs(x))
        shortcut = self.idconv(self.pool(x))
        return self.act(main + shortcut)
This is a resnet block (normal or bottleneck depending on `expansion`, 1 for the normal block and 4 for the traditional bottleneck) that implements the tweaks from Bag of Tricks for Image Classification with Convolutional Neural Networks. In particular, the last batchnorm layer (if that is the selected `norm_type`) is initialized with a weight (or gamma) of zero to facilitate the flow from the beginning to the end of the network.
The `kwargs` are passed to `ConvLayer` along with `norm_type`.
#export
from torch.jit import script
# Scripted forward of swish: x * sigmoid(x)
@script
def _swish_jit_fwd(x): return x.mul(torch.sigmoid(x))

# Scripted gradient of swish with respect to x, multiplied by the incoming gradient
@script
def _swish_jit_bwd(x, grad_output):
    x_sigmoid = torch.sigmoid(x)
    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))

class _SwishJitAutoFn(torch.autograd.Function):
    "Autograd function for swish with a hand-written backward; only the input is saved for backward"
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return _swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        # `saved_tensors` is the supported accessor; `saved_variables` is deprecated
        x = ctx.saved_tensors[0]
        return _swish_jit_bwd(x, grad_output)
#export
def swish(x, inplace=False):
    "Apply `_SwishJitAutoFn` to `x`; `inplace` is accepted but not used."
    return _SwishJitAutoFn.apply(x)
#export
class Swish(Module):
    "Layer applying `_SwishJitAutoFn` to its input"
    def forward(self, x):
        return _SwishJitAutoFn.apply(x)
#export
# Scripted forward of mish: x * tanh(softplus(x))
@script
def _mish_jit_fwd(x): return x.mul(torch.tanh(F.softplus(x)))

# Scripted gradient of mish with respect to x, multiplied by the incoming gradient
@script
def _mish_jit_bwd(x, grad_output):
    x_sigmoid = torch.sigmoid(x)
    x_tanh_sp = F.softplus(x).tanh()
    return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))

class MishJitAutoFn(torch.autograd.Function):
    "Autograd function for mish with a hand-written backward; only the input is saved for backward"
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return _mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        # `saved_tensors` is the supported accessor; `saved_variables` is deprecated
        x = ctx.saved_tensors[0]
        return _mish_jit_bwd(x, grad_output)
#export
def mish(x): return MishJitAutoFn.apply(x)
#export
class MishJit(Module):
    "Layer form of Mish: runs its input through `MishJitAutoFn`."
    def forward(self, x):
        return MishJitAutoFn.apply(x)
It's easy to get the list of all parameters of a given model. When you instead want all submodules (like linear/conv layers) without forgetting lone parameters, the following class wraps those parameters in fake modules.
# export
class ParameterModule(Module):
    "Hold a lone parameter `p` as attribute `val`; `forward` returns its input unchanged."
    def __init__(self, p):
        self.val = p

    def forward(self, x):
        return x
# export
def children_and_parameters(m):
    "Return the children of `m`, followed by one `ParameterModule` per parameter of `m` not found in any child."
    children = list(m.children())
    # Parameters are matched by identity; a set keeps each membership test O(1)
    owned = {id(p) for c in children for p in c.parameters()}
    children += [ParameterModule(p) for p in m.parameters() if id(p) not in owned]
    return children
# export
class TstModule(Module):
    "Test fixture with one lone parameter `a` and one child module `lin`."
    def __init__(self):
        self.a = nn.Parameter(torch.randn(1))
        self.lin = nn.Linear(5,10)
# `TstModule` has one child module (`lin`) and one lone parameter (`a`)
tst = TstModule()
children = children_and_parameters(tst)
# Expect the real child first, then the lone parameter wrapped in a `ParameterModule`
test_eq(len(children), 2)
test_eq(children[0], tst.lin)
assert isinstance(children[1], ParameterModule)
test_eq(children[1].val, tst.a)
#export
def _has_children(m:nn.Module):
    "Tell whether `m` has at least one child module."
    # `any` stops at the first child, so the iterator is never fully consumed
    return any(True for _ in m.children())
# Make the check available as a `has_children` property on every `nn.Module`
nn.Module.has_children = property(_has_children)
# `A` defines no submodules, while `TstModule` holds `lin`
class A(Module): pass
assert not A().has_children
assert TstModule().has_children
# export
def flatten_model(m):
    "Recursively expand `m` into a flat list of its childless submodules and wrapped lone parameters."
    # A module without children is a leaf and stands for itself
    if not m.has_children:
        return [m]
    flat = []
    for child in children_and_parameters(m):
        flat = flat + flatten_model(child)
    return flat
# Each `TstModule` contributes two entries; the second of each pair is the wrapped lone parameter
tst = nn.Sequential(TstModule(), TstModule())
children = flatten_model(tst)
test_eq(len(children), 4)
assert isinstance(children[1], ParameterModule)
assert isinstance(children[3], ParameterModule)
#export
class NoneReduce():
    "Context manager that temporarily evaluates `loss_func` with `reduction='none'`."
    def __init__(self, loss_func):
        self.loss_func = loss_func
        self.old_red = None

    def __enter__(self):
        # No `reduction` attribute: return a partial that passes `reduction='none'`
        if not hasattr(self.loss_func, 'reduction'):
            return partial(self.loss_func, reduction='none')
        # Otherwise switch the attribute in place, remembering the old value for `__exit__`
        self.old_red = self.loss_func.reduction
        self.loss_func.reduction = 'none'
        return self.loss_func

    def __exit__(self, type, value, traceback):
        # Nothing was modified when `__enter__` returned a partial
        if self.old_red is None:
            return
        self.loss_func.reduction = self.old_red
x,y = torch.randn(5),torch.randn(5)
# Loss object: the loss keeps one value per element, and `reduction` is back to 'mean' afterwards
loss_fn = nn.MSELoss()
with NoneReduce(loss_fn) as loss_func:
loss = loss_func(x,y)
test_eq(loss.shape, [5])
test_eq(loss_fn.reduction, 'mean')
# Plain function: the loss is again per-element, and `loss_fn` still refers to `F.mse_loss`
loss_fn = F.mse_loss
with NoneReduce(loss_fn) as loss_func:
loss = loss_func(x,y)
test_eq(loss.shape, [5])
test_eq(loss_fn, F.mse_loss)
#export
def in_channels(m):
    "Return the number of input channels (`weight.shape[1]`) of the first weight layer in `m`."
    for l in flatten_model(m):
        w = getattr(l, 'weight', None)
        # Skip layers whose weight is `None` or has no second dimension to read
        # (e.g. a 1-D norm-layer weight) instead of crashing on them.
        if w is not None and w.dim() >= 2:
            return w.shape[1]
    raise Exception('No weight layer')
# The first conv's input channels are reported, layers without weights (pooling) are skipped,
# and a model with no weight layer at all raises.
test_eq(in_channels(nn.Sequential(nn.Conv2d(5,4,3), nn.Conv2d(4,3,3))), 5)
test_eq(in_channels(nn.Sequential(nn.AvgPool2d(4), nn.Conv2d(4,3,3))), 4)
test_fail(lambda : in_channels(nn.Sequential(nn.AvgPool2d(4))))
#hide
# Run the notebook-to-module export; `all_fs=True` presumably selects every notebook — confirm against `notebook2script`
from local.notebook.export import *
notebook2script(all_fs=True)
Converted 00_test.ipynb. Converted 01_core_foundation.ipynb. Converted 01a_core_utils.ipynb. Converted 01b_core_dispatch.ipynb. Converted 01c_core_transform.ipynb. Converted 02_core_script.ipynb. Converted 03_torchcore.ipynb. Converted 03a_layers.ipynb. Converted 04_data_load.ipynb. Converted 05_data_core.ipynb. Converted 06_data_transforms.ipynb. Converted 07_data_block.ipynb. Converted 08_vision_core.ipynb. Converted 09_vision_augment.ipynb. Converted 09a_vision_data.ipynb. Converted 09b_vision_utils.ipynb. Converted 10_pets_tutorial.ipynb. Converted 11_vision_models_xresnet.ipynb. Converted 12_optimizer.ipynb. Converted 13_learner.ipynb. Converted 13a_metrics.ipynb. Converted 14_callback_schedule.ipynb. Converted 14a_callback_data.ipynb. Converted 15_callback_hook.ipynb. Converted 15a_vision_models_unet.ipynb. Converted 16_callback_progress.ipynb. Converted 17_callback_tracker.ipynb. Converted 18_callback_fp16.ipynb. Converted 19_callback_mixup.ipynb. Converted 20_interpret.ipynb. Converted 20a_distributed.ipynb. Converted 21_vision_learner.ipynb. Converted 22_tutorial_imagenette.ipynb. Converted 23_tutorial_transfer_learning.ipynb. Converted 30_text_core.ipynb. Converted 31_text_data.ipynb. Converted 32_text_models_awdlstm.ipynb. Converted 33_text_models_core.ipynb. Converted 34_callback_rnn.ipynb. Converted 35_tutorial_wikitext.ipynb. Converted 36_text_models_qrnn.ipynb. Converted 37_text_learner.ipynb. Converted 38_tutorial_ulmfit.ipynb. Converted 40_tabular_core.ipynb. Converted 41_tabular_model.ipynb. Converted 42_tabular_rapids.ipynb. Converted 50_data_block_examples.ipynb. Converted 60_medical_imaging.ipynb. Converted 65_medical_text.ipynb. Converted 70_callback_wandb.ipynb. Converted 71_callback_tensorboard.ipynb. Converted 90_notebook_core.ipynb. Converted 91_notebook_export.ipynb. Converted 92_notebook_showdoc.ipynb. Converted 93_notebook_export2html.ipynb. Converted 94_notebook_test.ipynb. Converted 95_index.ipynb. Converted 96_data_external.ipynb. 
Converted 97_utils_test.ipynb. Converted notebook2jekyll.ipynb. Converted xse_resnext.ipynb.