Manifold Mixup Callback in Torchbearer

This notebook will cover how to use the manifold mixup callback and what model design considerations there are to make full use of it.

Manifold mixup is a recent extension of the Mixup regulariser, which is covered in the regularisers notebook. The basic premise of Mixup is that you can linearly combine two images and their targets ("mixing them up") and achieve a strong regularising effect on the model. Manifold mixup takes this further by arguing that we need not limit ourselves to mixing up inputs: we can also mix up the features output by individual layers.
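The input Mixup idea described above can be sketched in a few lines of plain PyTorch. This is a minimal illustration and not torchbearer's implementation: `mixup_batch` and `mixup_criterion` are hypothetical helper names, and partners are chosen by shuffling the batch, which is one common way of pairing examples. Mixing the one-hot targets with weight lambda is equivalent to mixing the two cross-entropy terms, because cross-entropy is linear in the target.

```python
import torch
import torch.nn.functional as F

def mixup_batch(x, y, lam):
    # Pair every example with a random partner from the same batch
    perm = torch.randperm(x.size(0))
    x_mixed = lam * x + (1 - lam) * x[perm]   # linear combination of the inputs
    return x_mixed, y, y[perm]                # keep both target sets for the loss

def mixup_criterion(logits, y_a, y_b, lam):
    # Cross-entropy against the mixed targets = mix of the two cross-entropies
    return lam * F.cross_entropy(logits, y_a) + (1 - lam) * F.cross_entropy(logits, y_b)

torch.manual_seed(0)
images = torch.randn(4, 3, 32, 32)            # a fake CIFAR-sized batch
labels = torch.tensor([0, 1, 2, 3])
lam = torch.distributions.Beta(1.0, 1.0).sample().item()
x_mixed, y_a, y_b = mixup_batch(images, labels, lam)
logits = torch.randn(4, 10)                   # stand-in for model(x_mixed)
print(x_mixed.shape, mixup_criterion(logits, y_a, y_b, lam).item())
```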

Note: The easiest way to use this tutorial is as a colab notebook, which allows you to dive in with no setup. We recommend you enable a free GPU with

Runtime → Change runtime type → Hardware Accelerator: GPU

Install Torchbearer

First we install torchbearer if needed.

In [2]:
try:
    import torchbearer
except ImportError:
    !pip install -q torchbearer
    import torchbearer

print(torchbearer.__version__)
0.5.1.dev

Data

For this example we shall use the CIFAR10 dataset since it is easily available through torchvision.

In [3]:
import torch
from torchvision import datasets, transforms
from torchbearer.cv_utils import DatasetValidationSplitter

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
BATCH_SIZE = 128
dataset = datasets.CIFAR10(root='./data/cifar', train=True, download=True, transform=transform)
testset = datasets.CIFAR10(root='./data/cifar', train=False, download=True, transform=transform)

splitter = DatasetValidationSplitter(len(dataset), 0.1)
trainset = splitter.get_train_dataset(dataset)
valset = splitter.get_val_dataset(dataset)

traingen = torch.utils.data.DataLoader(trainset, pin_memory=True, batch_size=BATCH_SIZE, shuffle=True, num_workers=10)
valgen = torch.utils.data.DataLoader(valset, pin_memory=True, batch_size=BATCH_SIZE, shuffle=True, num_workers=10)
testgen = torch.utils.data.DataLoader(testset, pin_memory=True, batch_size=BATCH_SIZE, shuffle=False, num_workers=10)
Files already downloaded and verified
Files already downloaded and verified

Model

We take the same model as the quickstart example. We will discuss later how well suited this model is for the manifold mixup callback.

In [4]:
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.convs = nn.Sequential(
            nn.Conv2d(3, 16, stride=2, kernel_size=3),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, stride=2, kernel_size=3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, stride=2, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        self.classifier = nn.Linear(576, 10)

    def forward(self, x):
        x = self.convs(x)
        x = x.view(-1, 576)
        return self.classifier(x)


model = SimpleModel()
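The hard-coded 576 in `x.view(-1, 576)` follows from the convolution output-size formula. The sketch below (plain PyTorch; `conv_out` is a hypothetical helper) works it through for 32x32 CIFAR10 images and confirms it with real layers. Batch norm and ReLU are left out because they do not change the shape.

```python
import torch
import torch.nn as nn

def conv_out(size, kernel_size=3, stride=2, padding=0):
    # Output-size formula for nn.Conv2d with dilation 1
    return (size + 2 * padding - kernel_size) // stride + 1

size = 32                       # CIFAR10 images are 32x32
for _ in range(3):              # three stride-2, 3x3 convolutions: 32 -> 15 -> 7 -> 3
    size = conv_out(size)
print(size, 64 * size * size)   # 3 576

convs = nn.Sequential(
    nn.Conv2d(3, 16, stride=2, kernel_size=3),
    nn.Conv2d(16, 32, stride=2, kernel_size=3),
    nn.Conv2d(32, 64, stride=2, kernel_size=3),
)
with torch.no_grad():
    out = convs(torch.zeros(1, 3, 32, 32))
print(out.shape)                # torch.Size([1, 64, 3, 3])
```

If the input resolution changes, this number changes too; `x.view(x.size(0), -1)` is a common way to avoid hard-coding it.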

Manifold Mixup Callback

The ManifoldMixup callback uses a fluent API similar to that of the Trial class, and allows layers to be selected by their names in the module tree, by their depth, or by filtering out unwanted layer types. As with the Mixup callback, we can either fix lambda or sample it afresh each time from a beta distribution. For this example we will use the default parameters, which sample lambda uniformly between 0 and 1.
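Sampling lambda follows the usual Mixup recipe of drawing from a symmetric Beta(alpha, alpha) distribution, and uniform sampling corresponds to alpha = 1. The sketch below uses plain `torch.distributions` rather than the callback, and `sample_lambdas` and `fraction_near_edges` are hypothetical helpers. It shows that alpha = 1 spreads lambda evenly over [0, 1], whereas a small alpha concentrates it near 0 or 1, so most samples are only lightly mixed.

```python
import torch
from torch.distributions import Beta

def sample_lambdas(alpha, n=10000):
    # Hypothetical helper: draw mixup coefficients from Beta(alpha, alpha)
    concentration = torch.tensor(float(alpha))
    return Beta(concentration, concentration).sample((n,))

def fraction_near_edges(lam):
    # Share of samples that barely mix (lambda close to 0 or to 1)
    return ((lam < 0.1) | (lam > 0.9)).float().mean().item()

torch.manual_seed(0)
uniform = sample_lambdas(1.0)  # Beta(1, 1) is the uniform distribution on [0, 1]
peaked = sample_lambdas(0.2)   # a small alpha pushes lambda towards 0 or 1

print(f"alpha=1.0: mean={uniform.mean().item():.2f}, near edges={fraction_near_edges(uniform):.2f}")
print(f"alpha=0.2: mean={peaked.mean().item():.2f}, near edges={fraction_near_edges(peaked):.2f}")
```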

Looking back at the model definition, we can see that most of the operations are stored as submodules of a sequential block. These submodules sit at depth 1 in the module tree (depth 0 holds the top-level modules: the sequential block and the classifier). We will therefore limit the mixup to depth 1 modules and filter out the ReLU and batch norm layers, leaving only the convolutions.

We can quickly check which layers will be selected by calling get_selected_layers with the model.

In [5]:
from torchbearer.callbacks.manifold_mixup import ManifoldMixup

mm = ManifoldMixup().at_depth(1).with_layer_type_filter([nn.BatchNorm2d, nn.ReLU])
mm.get_selected_layers(model)
['convs_0', 'convs_3', 'convs_6']
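To make depth- and type-based selection concrete, here is a rough re-implementation in plain PyTorch. It is a sketch, not torchbearer's code: `select_layers` is a hypothetical function, and joining registered child names with `_` is an assumption chosen to mirror the names printed by get_selected_layers. It walks `named_children()` recursively, tracking the current depth.

```python
import torch.nn as nn

class SimpleModel(nn.Module):
    # Same architecture as the SimpleModel defined earlier in this notebook
    def __init__(self):
        super().__init__()
        self.convs = nn.Sequential(
            nn.Conv2d(3, 16, stride=2, kernel_size=3), nn.BatchNorm2d(16), nn.ReLU(),
            nn.Conv2d(16, 32, stride=2, kernel_size=3), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 64, stride=2, kernel_size=3), nn.BatchNorm2d(64), nn.ReLU(),
        )
        self.classifier = nn.Linear(576, 10)

    def forward(self, x):
        return self.classifier(self.convs(x).view(-1, 576))

def select_layers(module, depth=None, type_filter=(), prefix="", current=0):
    # Hypothetical: depth=None selects every depth; type_filter lists types to exclude
    selected = []
    for name, child in module.named_children():
        full_name = f"{prefix}_{name}" if prefix else name
        at_right_depth = depth is None or current == depth
        if at_right_depth and not isinstance(child, tuple(type_filter)):
            selected.append(full_name)
        if depth is None or current < depth:  # only descend while deeper layers may match
            selected += select_layers(child, depth, type_filter, full_name, current + 1)
    return selected

sketch_model = SimpleModel()
print(select_layers(sketch_model, depth=1, type_filter=[nn.BatchNorm2d, nn.ReLU]))
# ['convs_0', 'convs_3', 'convs_6']
```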

Running a Trial

Let's run a trial on CIFAR10 for a similar duration to the quickstart.

In [6]:
import torch.optim as optim
from torchbearer.callbacks.mixup import Mixup

device = 'cuda' if torch.cuda.is_available() else 'cpu'
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
loss = Mixup.mixup_loss

from torchbearer import Trial
trial = Trial(model, optimizer, loss, metrics=['acc', 'loss'], callbacks=[mm]).to(device)
trial.with_generators(train_generator=traingen, val_generator=valgen, test_generator=testgen)
history = trial.run(epochs=5, verbose=1)
print(trial.evaluate(data_key=torchbearer.TEST_DATA))
{'test_mixup_acc': 0.5652999877929688, 'test_loss': 1.5160434246063232}

When using the mixup loss, categorical accuracy is reported under test_mixup_acc. The quickstart baseline reached a test accuracy of around 66%, so after a similar number of epochs the manifold mixup model scores noticeably lower (about 57%), reflecting the strong regularising effect.

Building Models for Manifold Mixup

The manifold mixup callback works by recursively searching the modules and submodules of a model to locate the desired layers, and then wrapping their forward passes so that the output is mixed up whenever that layer is randomly chosen. This means that any operations not performed by modules (such as those using the functional interface) cannot be tracked or mixed up by the callback.
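The idea of mixing a layer's output, and why functional operations are invisible to it, can be illustrated with PyTorch's `register_forward_hook`. This is a different mechanism from torchbearer's forward-pass wrapping and is only a sketch: `FunctionalReLUModel` and `make_mixup_hook` are hypothetical names, and lambda is fixed at 0.7 for reproducibility. A forward hook that returns a value replaces the module's output, so the mixed features flow on to the rest of the network.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FunctionalReLUModel(nn.Module):
    # Hypothetical model: the ReLU is applied functionally, so it is not a module
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))

def make_mixup_hook(lam, perm):
    # Returning a tensor from a forward hook replaces the module's output
    def hook(module, inputs, output):
        return lam * output + (1 - lam) * output[perm]
    return hook

toy_model = FunctionalReLUModel()
toy_names = [name for name, _ in toy_model.named_modules() if name]
print(toy_names)  # ['fc1', 'fc2']: there is no ReLU module to find or hook

torch.manual_seed(0)
toy_x = torch.randn(4, 8)
perm = torch.randperm(4)
handle = toy_model.fc1.register_forward_hook(make_mixup_hook(0.7, perm))
mixed_out = toy_model(toy_x)  # fc1's features are mixed before the ReLU and fc2
handle.remove()               # restore the ordinary forward pass
```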

Looking back at the model we defined earlier, we can see that all our operations are performed by modules, so we should be able to locate every layer. Let's quickly check this and see what these layers are called.

In [12]:
mm = ManifoldMixup().at_depth(None)
mm.get_selected_layers(model)
Out[12]:
['convs',
 'convs_0',
 'convs_1',
 'convs_2',
 'convs_3',
 'convs_4',
 'convs_5',
 'convs_6',
 'convs_7',
 'convs_8',
 'classifier']

We can see that whilst we pick up the right number of layers (9 from the sequential plus the classifier), we also find one extra layer named 'convs'. Since the sequential block is a module in itself, it can be wrapped as a whole as well as through its submodules. In practice this isn't particularly useful, since the output of the sequential block is exactly the output of its last layer, the final ReLU (convs_8).
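That the sequential block adds nothing beyond its last layer can be checked with plain PyTorch forward hooks (not the torchbearer callback). The sketch below rebuilds the same nine layers, records the outputs of the whole block and of its last child, and compares them; `captured` and `capture` are hypothetical names.

```python
import torch
import torch.nn as nn

convs = nn.Sequential(
    nn.Conv2d(3, 16, stride=2, kernel_size=3), nn.BatchNorm2d(16), nn.ReLU(),
    nn.Conv2d(16, 32, stride=2, kernel_size=3), nn.BatchNorm2d(32), nn.ReLU(),
    nn.Conv2d(32, 64, stride=2, kernel_size=3), nn.BatchNorm2d(64), nn.ReLU(),
)
captured = {}

def capture(name):
    # Forward hook that stores a module's output without modifying it
    def hook(module, inputs, output):
        captured[name] = output
    return hook

convs.register_forward_hook(capture("convs"))
convs[8].register_forward_hook(capture("convs_8"))

with torch.no_grad():
    convs(torch.randn(2, 3, 32, 32))

print(type(convs[8]).__name__, torch.equal(captured["convs"], captured["convs_8"]))  # ReLU True
```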

We can also see that the names we get out are not particularly informative. These names are the registered names of the PyTorch submodules. When using a sequential block, we don't name each layer ourselves (as we would by writing self.conv1 = nn.Conv2d(...)), so each one just gets a generic number based on its position in the sequential.

We can avoid this by defining the blocks individually and avoiding the sequential as we will now demonstrate.

In [10]:
class ConvModule(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=2)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
    
class BetterModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_mod1 = ConvModule(3, 16, 3)
        self.conv_mod2 = ConvModule(16, 32, 3)
        self.conv_mod3 = ConvModule(32, 64, 3)

        self.classifier = nn.Linear(576, 10)
    
    def forward(self, x):
        x = self.conv_mod1(x)
        x = self.conv_mod2(x)
        x = self.conv_mod3(x)
        
        x = x.view(-1, 576)
        return self.classifier(x)
    

If we now look at the layer names of this model, we get a much better idea of what they represent.

In [11]:
mm.get_selected_layers(BetterModel())
Out[11]:
['conv_mod1',
 'conv_mod1_conv',
 'conv_mod1_bn',
 'conv_mod1_relu',
 'conv_mod2',
 'conv_mod2_conv',
 'conv_mod2_bn',
 'conv_mod2_relu',
 'conv_mod3',
 'conv_mod3_conv',
 'conv_mod3_bn',
 'conv_mod3_relu',
 'classifier']
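If you would rather keep nn.Sequential, PyTorch also accepts an OrderedDict of named layers, which replaces the positional names with your own. We have not verified what names ManifoldMixup reports for such a block; going by the parent_child pattern above, something like 'convs_conv' seems likely, but treat that as an assumption. The sketch only shows the registered names PyTorch assigns.

```python
from collections import OrderedDict
import torch.nn as nn

positional = nn.Sequential(nn.Conv2d(3, 16, 3, stride=2), nn.BatchNorm2d(16), nn.ReLU())
named = nn.Sequential(OrderedDict([
    ("conv", nn.Conv2d(3, 16, 3, stride=2)),
    ("bn", nn.BatchNorm2d(16)),
    ("relu", nn.ReLU()),
]))

positional_names = [n for n, _ in positional.named_children()]
named_names = [n for n, _ in named.named_children()]
print(positional_names)  # ['0', '1', '2']
print(named_names)       # ['conv', 'bn', 'relu']
```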