#hide
#skip
! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab
# default_exp losses
# default_cls_lvl 3
#export
from fastai.imports import *
from fastai.torch_imports import *
from fastai.torch_core import *
from fastai.layers import *
#hide
from nbdev.showdoc import *
Custom fastai loss functions
# export
class BaseLoss():
    "Same as `loss_cls`, but flattens input and target."
    # Defaults: no fused activation and no decoding; subclasses override these.
    activation=decodes=noops
    def __init__(self, loss_cls, *args, axis=-1, flatten=True, floatify=False, is_2d=True, **kwargs):
        # axis:     dimension moved to the last position before the loss is computed
        # flatten:  whether input and target are flattened before calling the loss
        # floatify: cast targets to float (float16 targets are left untouched)
        # is_2d:    when flattening, keep the input's last dimension (True) or flatten it fully (False)
        store_attr("axis,flatten,floatify,is_2d")
        # Remaining args/kwargs instantiate the wrapped loss.
        self.func = loss_cls(*args,**kwargs)
        # Copy the wrapped loss's metadata (docstring, module, attribute dict) onto this wrapper.
        functools.update_wrapper(self, self.func)

    def __repr__(self): return f"FlattenedLoss of {self.func}"

    @property
    def reduction(self):
        "Reduction of the wrapped loss."
        return self.func.reduction

    @reduction.setter
    def reduction(self, v): self.func.reduction = v

    def _contiguous(self,x):
        "Move `self.axis` to the end of tensor `x` (non-tensors pass through unchanged)."
        return TensorBase(x.transpose(self.axis,-1).contiguous()) if isinstance(x,torch.Tensor) else x

    def __call__(self, inp, targ, **kwargs):
        "Transpose, cast and flatten `inp`/`targ`, then call the wrapped loss."
        inp,targ = map(self._contiguous, (inp,targ))
        if self.floatify and targ.dtype!=torch.float16: targ = targ.float()
        # Small integer targets are widened to int64.
        if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
        if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
        return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)

    def to(self, device):
        "Move the wrapped loss (when it is an `nn.Module`) to `device` and return `self`."
        if isinstance(self.func, nn.Module): self.func.to(device)
        # Return self like `nn.Module.to`, so `loss = loss.to(device)` keeps a valid loss.
        return self
Wrapping a general loss function inside of `BaseLoss` provides extra functionalities to your loss functions:
- it flattens the tensors before computing the loss, since that is more convenient (with a potential transpose to put `axis` at the end)
- a potential `activation` method that tells the library if there is an activation fused in the loss (useful for inference and methods such as `Learner.get_preds` or `Learner.predict`)
- a potential `decodes` method that is used on predictions in inference (for instance, an argmax in classification)

The `args` and `kwargs` will be passed to `loss_cls` during initialization to instantiate a loss function. `axis` is put at the end for losses like softmax that are often performed on the last axis. If `floatify=True`, the targets will be converted to floats (useful for losses that only accept float targets like `BCEWithLogitsLoss`); float16 targets are left as they are. `is_2d` determines whether we flatten while keeping the last dimension (the class axis, after the transpose) or completely flatten the input. We want the first for losses like Cross Entropy, and the second for pretty much anything else.
# export
@delegates()
class CrossEntropyLossFlat(BaseLoss):
    "Same as `nn.CrossEntropyLoss`, but flattens input and target."
    y_int = True  # presumably flags integer (class-index) targets to other fastai code — confirm
    @use_kwargs_dict(keep=True, weight=None, ignore_index=-100, reduction='mean')
    def __init__(self, *args, axis=-1, **kwargs):
        # `axis` is the class axis; BaseLoss moves it last before calling the loss.
        super().__init__(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)

    def activation(self, x):
        "Softmax over the class axis."
        return F.softmax(x, dim=self.axis)

    def decodes(self, x):
        "Predicted class indices: argmax over the class axis."
        return x.argmax(dim=self.axis)
tst = CrossEntropyLossFlat()
output = torch.randn(32, 5, 10)
target = torch.randint(0, 10, (32,5))
#nn.CrossEntropyLoss would fail with those two tensors (mismatched target shape), but not our flattened version.
_ = tst(output, target)
# fastcore's `test_fail` calls `f()` with no arguments: the lambda must take none, otherwise
# the resulting TypeError would make this check pass whatever the loss does.
test_fail(lambda: nn.CrossEntropyLoss()(output,target))
#Associated activation is softmax
test_eq(tst.activation(output), F.softmax(output, dim=-1))
#This loss function has a decodes which is argmax
test_eq(tst.decodes(output), output.argmax(dim=-1))
#In a segmentation task, we want to take the softmax over the channel dimension
tst = CrossEntropyLossFlat(axis=1)
output = torch.randn(32, 5, 128, 128)
target = torch.randint(0, 5, (32, 128, 128))
_ = tst(output, target)
test_eq(tst.activation(output), F.softmax(output, dim=1))
test_eq(tst.decodes(output), output.argmax(dim=1))
#hide
#cuda
# Weighted loss: `tst.to(device)` moves the wrapped nn.CrossEntropyLoss (which was built
# with the `weight` kwarg) to the same device as the inputs before evaluating.
tst = CrossEntropyLossFlat(weight=torch.ones(10))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tst.to(device)
output = torch.randn(32, 10, device=device)
target = torch.randint(0, 10, (32,), device=device)
_ = tst(output, target)
Focal Loss is the same as cross entropy, except that easy-to-classify observations are down-weighted in the loss calculation. The strength of the down-weighting grows with the `gamma` parameter. Put another way, the larger `gamma`, the less the easy-to-classify observations contribute to the loss.
# export
class FocalLossFlat(CrossEntropyLossFlat):
    """
    Same as CrossEntropyLossFlat but with focal parameter, `gamma`. Focal loss is introduced by Lin et al.
    https://arxiv.org/pdf/1708.02002.pdf. Note the class weighting factor in the paper, alpha, can be
    implemented through pytorch `weight` argument in nn.CrossEntropyLoss.
    """
    y_int = True
    @use_kwargs_dict(keep=True, weight=None, ignore_index=-100, reduction='mean')
    def __init__(self, *args, gamma=2, axis=-1, **kwargs):
        self.gamma = gamma
        # The requested reduction is applied in `__call__`, after the focal modulation,
        # so the wrapped cross entropy must always return per-element losses.
        self.reduce = kwargs.pop('reduction', 'mean')
        super().__init__(*args, reduction='none', axis=axis, **kwargs)

    # Override BaseLoss.reduction: the wrapped loss is pinned to 'none', so reading or
    # setting `reduction` (e.g. to temporarily get unreduced losses) must act on `self.reduce`.
    @property
    def reduction(self): return self.reduce

    @reduction.setter
    def reduction(self, v): self.reduce = v

    def __call__(self, inp, targ, **kwargs):
        "Per-element cross entropy scaled by `(1-pt)**gamma`, then reduced per `reduction`."
        ce_loss = super().__call__(inp, targ, **kwargs)
        pt = torch.exp(-ce_loss)  # probability assigned to the target class
        fl_loss = (1-pt)**self.gamma * ce_loss
        if self.reduce == 'mean': return fl_loss.mean()
        if self.reduce == 'sum': return fl_loss.sum()
        return fl_loss
#Compare focal loss with gamma = 0 to cross entropy
# (with gamma=0 the factor (1-pt)**gamma is 1, so focal loss reduces to cross entropy)
fl = FocalLossFlat(gamma=0)
ce = CrossEntropyLossFlat()
output = torch.randn(32, 5, 10)
target = torch.randint(0, 10, (32,5))
test_close(fl(output, target), ce(output, target))
#Test focal loss with gamma > 0 is different than cross entropy
fl = FocalLossFlat(gamma=2)
test_ne(fl(output, target), ce(output, target))
#In a segmentation task, we want to take the softmax over the channel dimension
fl = FocalLossFlat(gamma=0, axis=1)
ce = CrossEntropyLossFlat(axis=1)
output = torch.randn(32, 5, 128, 128)
target = torch.randint(0, 5, (32, 128, 128))
test_close(fl(output, target), ce(output, target), eps=1e-4)
# activation/decodes are inherited from CrossEntropyLossFlat and follow `axis`
test_eq(fl.activation(output), F.softmax(output, dim=1))
test_eq(fl.decodes(output), output.argmax(dim=1))
# export
@delegates()
class BCEWithLogitsLossFlat(BaseLoss):
    "Same as `nn.BCEWithLogitsLoss`, but flattens input and target."
    @use_kwargs_dict(keep=True, weight=None, reduction='mean', pos_weight=None)
    def __init__(self, *args, axis=-1, floatify=True, thresh=0.5, **kwargs):
        if kwargs.get('pos_weight', None) is not None:
            # Per the error message, flattening with `pos_weight` causes a shape mismatch,
            # so an explicit flatten=True is rejected and flattening is otherwise turned off.
            if kwargs.get('flatten', None) is True:
                raise ValueError("`flatten` must be False when using `pos_weight` to avoid a RuntimeError due to shape mismatch")
            kwargs['flatten'] = False
        super().__init__(nn.BCEWithLogitsLoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
        self.thresh = thresh  # threshold used by `decodes`

    def activation(self, x):
        "Sigmoid, the activation fused into `BCEWithLogitsLoss`."
        return torch.sigmoid(x)

    def decodes(self, x):
        "Boolean predictions: `x` above `thresh`."
        return x > self.thresh
tst = BCEWithLogitsLossFlat()
output = torch.randn(32, 5, 10)
target = torch.randn(32, 5, 10)
# nn.BCEWithLogitsLoss accepts same-shaped float tensors too, so the flattened version must agree with it.
test_close(tst(output, target), nn.BCEWithLogitsLoss()(output, target))
output = torch.randn(32, 5)
target = torch.randint(0,2,(32, 5))
#nn.BCEWithLogitsLoss would fail with int targets but not our flattened version.
_ = tst(output, target)
# fastcore's `test_fail` calls `f()` with no arguments, so the lambda must take none.
test_fail(lambda: nn.BCEWithLogitsLoss()(output,target))
tst = BCEWithLogitsLossFlat(pos_weight=torch.ones(10))
output = torch.randn(32, 5, 10)
target = torch.randn(32, 5, 10)
# With `pos_weight`, flattening is disabled, so the result matches the plain loss with the same `pos_weight`...
test_close(tst(output, target), nn.BCEWithLogitsLoss(pos_weight=torch.ones(10))(output, target))
# ...and explicitly asking to flatten alongside `pos_weight` is rejected.
test_fail(lambda: BCEWithLogitsLossFlat(pos_weight=torch.ones(10), flatten=True))
#Associated activation is sigmoid
test_eq(tst.activation(output), torch.sigmoid(output))
# export
@use_kwargs_dict(weight=None, reduction='mean')
def BCELossFlat(*args, axis=-1, floatify=True, **kwargs):
    "Same as `nn.BCELoss`, but flattens input and target."
    # Targets are cast to float by default, and the input is flattened completely (is_2d=False).
    wrapper_opts = dict(axis=axis, floatify=floatify, is_2d=False)
    return BaseLoss(nn.BCELoss, *args, **wrapper_opts, **kwargs)
tst = BCELossFlat()
output = torch.sigmoid(torch.randn(32, 5, 10))
target = torch.randint(0,2,(32, 5, 10))
# Our version floatifies the int targets...
_ = tst(output, target)
# ...whereas nn.BCELoss rejects Long targets. fastcore's `test_fail` calls `f()` with
# no arguments, so the lambda must take none for this check to be meaningful.
test_fail(lambda: nn.BCELoss()(output,target))
# export
@use_kwargs_dict(reduction='mean')
def MSELossFlat(*args, axis=-1, floatify=True, **kwargs):
    "Same as `nn.MSELoss`, but flattens input and target."
    # Float targets by default; the input is flattened completely (is_2d=False).
    flat_opts = dict(axis=axis, floatify=floatify, is_2d=False)
    return BaseLoss(nn.MSELoss, *args, **flat_opts, **kwargs)
tst = MSELossFlat()
output = torch.sigmoid(torch.randn(32, 5, 10))
target = torch.randint(0,2,(32, 5, 10))
# MSELossFlat floatifies the int targets, so it must match nn.MSELoss on float targets.
# (The former `test_fail(lambda x: ...)` passed vacuously: `test_fail` calls `f()` with no arguments.)
test_close(tst(output, target), nn.MSELoss()(output, target.float()))
#hide
#cuda
#Test losses work in half precision
# (floatify leaves float16 targets as they are, so the targets stay half here)
if torch.cuda.is_available():
    output = torch.sigmoid(torch.randn(32, 5, 10)).half().cuda()
    target = torch.randint(0,2,(32, 5, 10)).half().cuda()
    for tst in [BCELossFlat(), MSELossFlat()]: _ = tst(output, target)
# export
@use_kwargs_dict(reduction='mean')
def L1LossFlat(*args, axis=-1, floatify=True, **kwargs):
    "Same as `nn.L1Loss`, but flattens input and target."
    # Float targets by default; the input is flattened completely (is_2d=False).
    base_opts = dict(axis=axis, floatify=floatify, is_2d=False)
    return BaseLoss(nn.L1Loss, *args, **base_opts, **kwargs)
#export
class LabelSmoothingCrossEntropy(Module):
    "Cross entropy with label smoothing: `eps/c` times the summed negative log-probabilities plus `1-eps` times the NLL loss."
    y_int = True
    def __init__(self, eps:float=0.1, weight=None, reduction='mean'):
        store_attr()

    def forward(self, output, target):
        # Dim 1 is the class axis (as for `F.nll_loss`); `c` is the number of classes.
        c = output.size()[1]
        log_preds = F.log_softmax(output, dim=1)
        if self.reduction=='sum': loss = -log_preds.sum()
        else:
            loss = -log_preds.sum(dim=1) #We divide by that size at the return line so sum and not mean
            if self.reduction=='mean': loss = loss.mean()
        return loss*self.eps/c + (1-self.eps) * F.nll_loss(log_preds, target.long(), weight=self.weight, reduction=self.reduction)

    # Use the same class axis as `forward` (dim 1); dim -1 only coincided with it for 2D outputs.
    def activation(self, out): return F.softmax(out, dim=1)
    def decodes(self, out): return out.argmax(dim=1)
# The loss uses dim 1 as the class axis: flattening to (160, 10) and transposing to
# (32, 10, 5) present the same predictions, so both must give the same loss.
lmce = LabelSmoothingCrossEntropy()
output = torch.randn(32, 5, 10)
target = torch.randint(0, 10, (32,5))
test_close(lmce(output.flatten(0,1), target.flatten()), lmce(output.transpose(-1,-2), target))
On top of the label-smoothing formula implemented in `forward` (`eps/c` times the summed negative log-probabilities, plus `1-eps` times the NLL loss), we define:
- a `reduction` attribute, which will be used when we call `Learner.get_preds`
- a `weight` attribute that is passed to `F.nll_loss`
- an `activation` function that represents the activation fused in the loss (since we use cross entropy behind the scenes). It will be applied to the output of the model when calling `Learner.get_preds` or `Learner.predict`
- a `decodes` function that converts the output of the model to a format similar to the target (here, indices). This is used in `Learner.predict` and `Learner.show_results` to decode the predictions
#export
@delegates()
class LabelSmoothingCrossEntropyFlat(BaseLoss):
    "Same as `LabelSmoothingCrossEntropy`, but flattens input and target."
    y_int = True
    @use_kwargs_dict(keep=True, eps=0.1, reduction='mean')
    def __init__(self, *args, axis=-1, **kwargs): super().__init__(LabelSmoothingCrossEntropy, *args, axis=axis, **kwargs)
    # Only `__call__` moves `axis` to the end; these methods get the untransposed tensor,
    # so (like CrossEntropyLossFlat) they must work on `self.axis` rather than a fixed -1.
    def activation(self, out): return F.softmax(out, dim=self.axis)
    def decodes(self, out): return out.argmax(dim=self.axis)
We present a general `Dice` loss for segmentation tasks. It is commonly used together with `CrossEntropyLoss` or `FocalLoss` in Kaggle competitions. It is very similar to the `DiceMulti` metric, but to make it differentiable we replace the `argmax` activation with a `softmax` and compare the result with a one-hot encoded target mask. This function also adds a `smooth` parameter to improve numerical stability in the intersection-over-union division.
#export
class DiceLoss:
    "Dice loss for segmentation"
    def __init__(self, axis=1, smooth=1):
        # axis:   class axis of the predictions
        # smooth: added to numerator and denominator of the Dice ratio
        store_attr()

    def __call__(self, pred, targ):
        # One-hot encode the class-index target along the SAME axis as the predictions'
        # classes, so that the flattened `pred` and `targ` line up element by element.
        targ = self._one_hot(targ, pred.shape[self.axis], self.axis)
        pred, targ = flatten_check(self.activation(pred), targ)
        inter = (pred*targ).sum()
        union = (pred+targ).sum()
        return 1 - (2. * inter + self.smooth)/(union + self.smooth)

    @staticmethod
    def _one_hot(x, classes, axis=1):
        "Creates one binary mask per class"
        return torch.stack([torch.where(x==c, 1, 0) for c in range(classes)], dim=axis)

    def activation(self, x): return F.softmax(x, dim=self.axis)
    def decodes(self, x): return x.argmax(dim=self.axis)
# `_one_hot` stacks one 0/1 mask per class along dim 1
dl = DiceLoss()
_x = tensor( [[[1, 0, 2],
               [2, 2, 1]]])
_one_hot_x = tensor([[[[0, 1, 0],
                       [0, 0, 0]],
                      [[1, 0, 0],
                       [0, 0, 1]],
                      [[0, 0, 1],
                       [1, 1, 0]]]])
test_eq(dl._one_hot(_x, 3), _one_hot_x)
# `decodes` (argmax over the class axis) recovers the target from these logits
dl = DiceLoss()
model_output = tensor([[[[2., 1.],
                         [1., 5.]],
                        [[1, 2.],
                         [3., 1.]],
                        [[3., 0],
                         [4., 3.]]]])
target = tensor([[[2, 1],
                  [2, 0]]])
# NOTE(review): `dl_out` is computed but never checked
dl_out = dl(model_output, target)
test_eq(dl.decodes(model_output), target)
# With smooth=0 the loss is exactly 1 - 2*inter/union
dl = DiceLoss(smooth=0.)
#identical masks
model_output = tensor([[[.1], [.1], [100.]]])
target = tensor([[2]])
test_close(dl(model_output, target), 0)
#50% intersection
model_output = tensor([[[.1, 100.], [.1, .1], [100., .1]]])
target = tensor([[2, 1]])
test_close(dl(model_output, target), .5)
#hide
#cuda
#Test DiceLoss works in half precision (reuses `dl` with smooth=0 from above)
if torch.cuda.is_available():
    output = torch.randn(32, 4, 5, 10).half().cuda()
    target = torch.randint(0,2,(32, 5, 10)).half().cuda()
    _ = dl(output, target)
You could easily combine this loss with `FocalLoss` by defining a `CombinedLoss`, to balance between global (Dice) and local (Focal) features on the target mask.
class CombinedLoss:
    "Dice and Focal combined"
    def __init__(self, axis=1, smooth=1., alpha=1.):
        store_attr()
        # Both component losses share the same class axis.
        self.focal_loss = FocalLossFlat(axis=axis)
        self.dice_loss = DiceLoss(axis, smooth)

    def __call__(self, pred, targ):
        # Focal term plus the Dice term weighted by `alpha`.
        focal_term = self.focal_loss(pred, targ)
        dice_term = self.dice_loss(pred, targ)
        return focal_term + self.alpha * dice_term

    def activation(self, x):
        "Softmax over the class axis."
        return F.softmax(x, dim=self.axis)

    def decodes(self, x):
        "Predicted class indices: argmax over the class axis."
        return x.argmax(dim=self.axis)
# Smoke test: the combined loss runs on a 4-class output (classes on dim 1, the default axis)
# with class-index targets.
cl = CombinedLoss()
output = torch.randn(32, 4, 5, 10)
target = torch.randint(0,2,(32, 5, 10))
_ = cl(output, target)
#hide
# nbdev: convert the notebooks' exported cells into library modules.
from nbdev.export import *
notebook2script()
Converted 00_torch_core.ipynb. Converted 01_layers.ipynb. Converted 01a_losses.ipynb. Converted 02_data.load.ipynb. Converted 03_data.core.ipynb. Converted 04_data.external.ipynb. Converted 05_data.transforms.ipynb. Converted 06_data.block.ipynb. Converted 07_vision.core.ipynb. Converted 08_vision.data.ipynb. Converted 09_vision.augment.ipynb. Converted 09b_vision.utils.ipynb. Converted 09c_vision.widgets.ipynb. Converted 10_tutorial.pets.ipynb. Converted 10b_tutorial.albumentations.ipynb. Converted 11_vision.models.xresnet.ipynb. Converted 12_optimizer.ipynb. Converted 13_callback.core.ipynb. Converted 13a_learner.ipynb. Converted 13b_metrics.ipynb. Converted 14_callback.schedule.ipynb. Converted 14a_callback.data.ipynb. Converted 15_callback.hook.ipynb. Converted 15a_vision.models.unet.ipynb. Converted 16_callback.progress.ipynb. Converted 17_callback.tracker.ipynb. Converted 18_callback.fp16.ipynb. Converted 18a_callback.training.ipynb. Converted 18b_callback.preds.ipynb. Converted 19_callback.mixup.ipynb. Converted 20_interpret.ipynb. Converted 20a_distributed.ipynb. Converted 21_vision.learner.ipynb. Converted 22_tutorial.imagenette.ipynb. Converted 23_tutorial.vision.ipynb. Converted 24_tutorial.image_sequence.ipynb. Converted 24_tutorial.siamese.ipynb. Converted 24_vision.gan.ipynb. Converted 30_text.core.ipynb. Converted 31_text.data.ipynb. Converted 32_text.models.awdlstm.ipynb. Converted 33_text.models.core.ipynb. Converted 34_callback.rnn.ipynb. Converted 35_tutorial.wikitext.ipynb. Converted 37_text.learner.ipynb. Converted 38_tutorial.text.ipynb. Converted 39_tutorial.transformers.ipynb. Converted 40_tabular.core.ipynb. Converted 41_tabular.data.ipynb. Converted 42_tabular.model.ipynb. Converted 43_tabular.learner.ipynb. Converted 44_tutorial.tabular.ipynb. Converted 45_collab.ipynb. Converted 46_tutorial.collab.ipynb. Converted 50_tutorial.datablock.ipynb. Converted 60_medical.imaging.ipynb. Converted 61_tutorial.medical_imaging.ipynb. 
Converted 65_medical.text.ipynb. Converted 70_callback.wandb.ipynb. Converted 71_callback.tensorboard.ipynb. Converted 72_callback.neptune.ipynb. Converted 73_callback.captum.ipynb. Converted 74_callback.azureml.ipynb. Converted 97_test_utils.ipynb. Converted 99_pytorch_doc.ipynb. Converted dev-setup.ipynb. Converted app_examples.ipynb. Converted camvid.ipynb. Converted migrating_catalyst.ipynb. Converted migrating_ignite.ipynb. Converted migrating_lightning.ipynb. Converted migrating_pytorch.ipynb. Converted ulmfit.ipynb. Converted index.ipynb. Converted quick_start.ipynb. Converted tutorial.ipynb.