Many well-known deep learning problems are image classification problems, where the input is an image and the target is an integer class label. But what if we are not classifying? By default, Torchbearer expects the data generator to yield a (data, target) tuple; if your problem doesn't follow this structure, you need a way to change how Torchbearer loads each batch. The easiest way to do this is through the with_loader trial method, which takes a function of state that is expected to populate the torchbearer.X and torchbearer.Y_TRUE keys in state. This example walks through a simple implementation of a custom loader using this method.

Note: The easiest way to use this tutorial is as a Colab notebook, which allows you to dive in with no setup. We recommend enabling a free GPU via

Runtime   →   Change runtime type   →   Hardware Accelerator: GPU

## Install Torchbearer

First we install torchbearer if needed.

In [1]:
try:
    import torchbearer
except ImportError:
    !pip install -q torchbearer
    import torchbearer

# If problems arise, try
# !pip install git+https://github.com/pytorchbearer/torchbearer
# import torchbearer

print(torchbearer.__version__)

  Building wheel for torchbearer (setup.py) ... done
0.4.0.dev


## Creating a Dataset

We need a simple problem that doesn't quite follow the usual (data, target) tuple structure. For this example we create one by zipping together two instances of CIFAR, so that our data generator yields ((img1, target1), (img2, target2)).

In the code below we create a simple wrapper class that takes two generators and loads a batch from each of them. Ideally we would simply call zip(gen1, gen2), but for that to work the data would have to be loaded completely into memory, which we are trying to avoid. Instead, the ZipLoader we create implements the methods required of an iterator, which makes it slightly more complicated.

The important part of this code is that we create a trainloader and a testloader which yield data of the format previously mentioned.

In [2]:
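import torch
import torchvision
from torchvision import transforms

BATCH_SIZE = 128

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Assumption: CIFAR-10 loaded through torchvision; './data' is an illustrative path
transform = transforms.Compose([transforms.ToTensor(), normalize])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Hack to join two versions of CIFAR together
# (the method bodies are a minimal sketch of the wrapper described above)
class ZipLoader(object):
    def __init__(self, loader1, loader2):
        self.loader1 = loader1
        self.loader2 = loader2

    def __iter__(self):
        # Restart both loaders so the ZipLoader can be iterated once per epoch
        self.iter1 = iter(self.loader1)
        self.iter2 = iter(self.loader2)
        return self

    def __len__(self):
        # Number of paired batches available
        return min(len(self.loader1), len(self.loader2))

    def __next__(self):
        return self.next()

    def next(self):
        # One batch from each loader: ((img1, target1), (img2, target2))
        return next(self.iter1), next(self.iter2)

traingen = torch.utils.data.DataLoader(trainset, pin_memory=True, batch_size=BATCH_SIZE, shuffle=True, num_workers=7)
traingen2 = torch.utils.data.DataLoader(trainset, pin_memory=True, batch_size=BATCH_SIZE, shuffle=True, num_workers=7)

testgen = torch.utils.data.DataLoader(testset, pin_memory=True, batch_size=BATCH_SIZE, shuffle=False, num_workers=7)
testgen2 = torch.utils.data.DataLoader(testset, pin_memory=True, batch_size=BATCH_SIZE, shuffle=False, num_workers=7)

trainloader = ZipLoader(traingen, traingen2)
testloader = ZipLoader(testgen, testgen2)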

Files already downloaded and verified
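The iterator protocol that ZipLoader relies on can be tried out without loading any data. The sketch below is a toy stand-in, not the tutorial's class: the name ZipIterables and the string "images" are hypothetical, and plain lists take the place of the two data loaders. It shows the three methods involved (`__iter__`, `__len__`, `__next__`) and that restarting both sources in `__iter__` lets the object be iterated more than once.

```python
class ZipIterables:
    """Toy stand-in for ZipLoader: pairs up items from two re-iterable sources."""

    def __init__(self, first, second):
        self.first = first
        self.second = second

    def __iter__(self):
        # Start both sources again so the object can be reused every epoch
        self._it1 = iter(self.first)
        self._it2 = iter(self.second)
        return self

    def __len__(self):
        # Only as many pairs as the shorter source can supply
        return min(len(self.first), len(self.second))

    def __next__(self):
        # StopIteration raised by either source ends the iteration
        return next(self._it1), next(self._it2)


# Hypothetical "batches": (data, target) tuples held in ordinary lists
pairs = ZipIterables([("img_a", 0), ("img_b", 1)],
                     [("img_c", 2), ("img_d", 3), ("img_e", 4)])

first_pass = list(pairs)   # [(("img_a", 0), ("img_c", 2)), (("img_b", 1), ("img_d", 3))]
second_pass = list(pairs)  # the same pairs again, because __iter__ restarts both sources
```

Each element has the ((data1, target1), (data2, target2)) shape described above, and iteration stops as soon as the shorter source is exhausted.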


Now, as mentioned, the with_loader method takes a function of state that is required to fill torchbearer.X and torchbearer.Y_TRUE. In the code below we first show what the standard Torchbearer loader looks like. It calls next on the iterator and moves the result to the right device and data type with deep_to, a function which recursively calls to on its argument. This is how Torchbearer handles device management of data, so it is required in our custom loader as well.

The custom loader makes a very similar call, but instead of unpacking straight into torchbearer.X and torchbearer.Y_TRUE, it extracts the individual parts first. We want our metrics to still work, so we convert the task into a single classification problem. To do this, we concatenate the two images and combine the labels in the following way:

• We take the first label and multiply it by 10 so that the first digit of our 0-99 class labels describes the first class
• We then add the second label to this number so that the second digit describes the second class

So, for example if we had one image from class 4 and the second image from class 8, the label would be 48. Note that these classes are 0 based, so there is no class 10.
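The label arithmetic above can be checked with plain integers. This is a minimal sketch; the helper names encode_pair and decode_pair are hypothetical and are not part of Torchbearer.

```python
NUM_CLASSES = 10  # classes per image, so the combined labels span 0-99


def encode_pair(label1, label2):
    """Combine two 0-9 labels into one 0-99 label (first label is the tens digit)."""
    return NUM_CLASSES * label1 + label2


def decode_pair(label):
    """Split a combined label back into (first label, second label)."""
    return divmod(label, NUM_CLASSES)


combined = encode_pair(4, 8)       # 48
recovered = decode_pair(combined)  # (4, 8)
small = encode_pair(0, 3)          # 3, i.e. the tens digit is an implicit 0
```

Because the encoding is reversible, nothing about the two original labels is lost by treating the task as a 100-class problem.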

In [0]:
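from torchbearer import deep_to

# Standard loader: unpack the batch straight into X and Y_TRUE
# (the function names here are illustrative)
def standard_loader(state):
    state[torchbearer.X], state[torchbearer.Y_TRUE] = deep_to(next(state[torchbearer.ITERATOR]),
                                                              state[torchbearer.DEVICE],
                                                              state[torchbearer.DATA_TYPE])

# Custom loader: split the zipped batch, then build one image and one label
def custom_loader(state):
    (img1, label1), (img2, label2) = deep_to(next(state[torchbearer.ITERATOR]), state[torchbearer.DEVICE], state[torchbearer.DATA_TYPE])
    image = torch.cat((img1, img2), 1)  # concatenate along the channel dimension
    label = 10*label1 + label2
    state[torchbearer.X], state[torchbearer.Y_TRUE] = image, label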
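To picture what "recursively calls to" means for a nested batch such as ((img1, label1), (img2, label2)), here is a small stand-alone sketch. It is not Torchbearer's deep_to implementation: recursive_to and FakeTensor are hypothetical names, and FakeTensor only mimics the to method so that the example runs without PyTorch.

```python
def recursive_to(batch, *args, **kwargs):
    """Call .to(*args, **kwargs) on everything in a nested batch that supports it."""
    if isinstance(batch, (list, tuple)):
        # Rebuild the same container type with every element converted
        return type(batch)(recursive_to(item, *args, **kwargs) for item in batch)
    if isinstance(batch, dict):
        return {key: recursive_to(value, *args, **kwargs) for key, value in batch.items()}
    if hasattr(batch, "to"):
        return batch.to(*args, **kwargs)
    return batch  # plain Python values are passed through unchanged


class FakeTensor:
    """Hypothetical stand-in for a tensor: records only a name and a device."""

    def __init__(self, name, device="cpu"):
        self.name = name
        self.device = device

    def to(self, device):
        # Like a tensor, return a new object rather than modifying in place
        return FakeTensor(self.name, device)


batch = ((FakeTensor("img1"), FakeTensor("label1")),
         (FakeTensor("img2"), FakeTensor("label2")))
moved = recursive_to(batch, "cuda")
```

The nesting of the batch is preserved, which is why the custom loader can still unpack the result into its four parts after the device transfer.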



## Model

We now need a model to test with this problem. Recall that in the loader function we concatenated the two images and combined the two labels, changing the problem from two 10-class problems into a single 100-class problem. As such, our model has 6 input channels (3 for each RGB image) and 100 outputs, one for each class, as in a usual classifier.

In [0]:
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.convs = nn.Sequential(
            nn.Conv2d(6, 16, stride=2, kernel_size=3),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, stride=2, kernel_size=3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, stride=2, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        self.classifier = nn.Linear(576, 100)

    def forward(self, x):
        x = self.convs(x)
        x = x.view(-1, 576)
        return self.classifier(x)

model = SimpleModel()
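The 576 input features of the linear layer follow from the convolution arithmetic. The sketch below assumes 32×32 CIFAR images and the layer settings shown above (kernel size 3, stride 2, no padding); conv_out_size is a hypothetical helper applying the usual output-size formula for an unpadded, undilated convolution.

```python
def conv_out_size(size, kernel_size, stride, padding=0):
    """Spatial output size of a convolution along one dimension (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 32   # CIFAR images are 32x32 pixels
sizes = []
for _ in range(3):  # the three Conv2d layers in SimpleModel
    size = conv_out_size(size, kernel_size=3, stride=2)
    sizes.append(size)

out_channels = 64                          # channels after the last convolution
flat_features = out_channels * size * size
```

The spatial size shrinks from 32 to 15, then 7, then 3, so the flattened feature vector has 64 × 3 × 3 = 576 entries, matching nn.Linear(576, 100).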


## Trial

We now create a simple trial to run this model. We call with_loader with our custom loading function and ask for an accuracy and loss metric.

You'll notice that our accuracy is not particularly great. Treating this as a 100 class problem instead of two 10 class problems means that unless we get both guesses correct, the accuracy metric thinks we got it wrong. This leads to a large drop in accuracy. A stronger model would be required to do well on this problem.

In [5]:
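import torch.optim as optim

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Assumption: Adam with lr=0.001; the optimizer settings are not stated in the text
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss = nn.CrossEntropyLoss()

import torchbearer
from torchbearer import Trial

trial = Trial(model, optimizer, loss, metrics=['acc', 'loss'], callbacks=[]).to(device)
# Attach the zipped loaders and the custom batch-loading function
trial.with_generators(train_generator=trainloader, test_generator=testloader)
trial.with_loader(custom_loader)
# Assumption: 10 epochs; the number of epochs is not stated in the text
trial.run(epochs=10, verbose=1)

trial.evaluate(data_key=torchbearer.TEST_DATA)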

{'test_acc': 0.3774999976158142, 'test_loss': 2.6687088012695312}
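To see how much of the drop comes from requiring both guesses to be right, the combined labels can be split back into their two digits and scored separately. This is a minimal sketch in plain Python; pair_accuracies is a hypothetical helper, and the predictions and targets are made-up toy values, not outputs of the model above.

```python
def pair_accuracies(predictions, targets, num_classes=10):
    """Return (joint, first-image, second-image) accuracy for combined 0-99 labels."""
    joint = first = second = 0
    for pred, target in zip(predictions, targets):
        pred1, pred2 = divmod(pred, num_classes)
        true1, true2 = divmod(target, num_classes)
        joint += pred == target    # both digits must match
        first += pred1 == true1    # tens digit: class of the first image
        second += pred2 == true2   # units digit: class of the second image
    n = len(targets)
    return joint / n, first / n, second / n


# Toy values chosen by hand for illustration
predictions = [48, 41, 38, 7]
targets = [48, 48, 48, 17]
joint_acc, first_acc, second_acc = pair_accuracies(predictions, targets)
```

In this toy case the joint accuracy is 0.25 while the per-image accuracies are 0.5 and 0.75: a prediction only counts towards the joint figure when both digits are correct, which is why the 100-class accuracy sits below either per-image accuracy.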