Training a GAN

We shall try to implement something more complicated using torchbearer: a Generative Adversarial Network (GAN). This tutorial is a modified version of the GAN from PyTorch_GAN, eriklindernoren's brilliant collection of GAN implementations on GitHub.

Note: The easiest way to use this tutorial is as a Colab notebook, which allows you to dive in with no setup. We recommend you enable a free GPU with:

Runtime → Change runtime type → Hardware Accelerator: GPU

Install Torchbearer

First we install torchbearer if needed.

In [1]:
try:
    import torchbearer
except ImportError:
    # Notebook shell command: install torchbearer quietly, then retry the import
    !pip install -q torchbearer
    import torchbearer

print(torchbearer.__version__)
0.3.2

Data and Constants

Let's now define all the constants and state keys for the example.

In [2]:
import torch
from torchbearer import state_key

# Define constants
train_steps = 50000
batch_size = 128
lr = 0.002
latent_dim = 100
sample_interval = 400
img_shape = (1, 28, 28)
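adversarial_loss = torch.nn.BCELoss()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Discriminator targets: 1 = real image, 0 = generated image
valid = torch.ones(batch_size, 1, device=device)
fake = torch.zeros(batch_size, 1, device=device)
# Fixed noise used to draw the 5x5 sample grid during training
batch = torch.randn(25, latent_dim).to(device)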

# Register state keys (optional)
GEN_IMGS = state_key('gen_imgs')
DISC_GEN = state_key('disc_gen')
DISC_GEN_DET = state_key('disc_gen_det')
DISC_REAL = state_key('disc_real')
G_LOSS = state_key('g_loss')
D_LOSS = state_key('d_loss')

DISC_OPT = state_key('disc_opt')
GEN_OPT = state_key('gen_opt')
DISC_MODEL = state_key('disc_model')
DISC_IMGS = state_key('disc_imgs')
DISC_CRIT = state_key('disc_crit')
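`adversarial_loss` is binary cross-entropy, and `valid` and `fake` are the targets it is compared against. The check below is a sketch with made-up discriminator outputs `p` (not values from this tutorial); it shows what each target reduces to: against ones the loss is the mean of -log(p), and against zeros it is the mean of -log(1 - p).

```python
import torch

bce = torch.nn.BCELoss()
# Made-up discriminator outputs: probability that each image is real.
p = torch.tensor([[0.9], [0.6], [0.2]])
ones = torch.ones(3, 1)    # plays the role of `valid`
zeros = torch.zeros(3, 1)  # plays the role of `fake`

# Target 1: BCE reduces to the mean of -log(p).
loss_valid = bce(p, ones)
# Target 0: BCE reduces to the mean of -log(1 - p).
loss_fake = bce(p, zeros)

print(torch.allclose(loss_valid, -torch.log(p).mean()))     # True
print(torch.allclose(loss_fake, -torch.log(1 - p).mean()))  # True
```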

Next we create the dataset and dataloader, for which we use MNIST.

In [3]:
from torchvision import datasets, transforms

# ToTensor scales pixels to [0, 1], matching the range of the generator's Sigmoid output
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.MNIST('./data/mnist', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
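`valid` and `fake` are allocated once with exactly `batch_size` rows, so every batch that reaches the loss must contain that many images. `drop_last=True` guarantees this by discarding the final, smaller batch: MNIST's 60,000 training images give 468 full batches of 128, and the remaining 96 images are dropped each pass. The sketch below demonstrates the behaviour on a small synthetic dataset (the size 300 is arbitrary) instead of downloading MNIST.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for MNIST: 300 blank 1x28x28 "images" (arbitrary size).
data = TensorDataset(torch.zeros(300, 1, 28, 28))

def batch_sizes(drop_last):
    loader = DataLoader(data, batch_size=128, shuffle=True, drop_last=drop_last)
    return [x.shape[0] for (x,) in loader]

print(batch_sizes(drop_last=False))  # [128, 128, 44]
print(batch_sizes(drop_last=True))   # [128, 128]
```

Without `drop_last`, the last batch of 44 images would not match the 128-row label tensors.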

Model

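We use the generator and discriminator from PyTorch_GAN, with their forward methods changed to also take the torchbearer state. The generator uses its input only for the batch size: it samples fresh latent vectors z on every call.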

In [4]:
import torch.nn as nn
import numpy as np

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # As in PyTorch_GAN: the second positional argument of BatchNorm1d is eps
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            nn.Linear(512, int(np.prod(img_shape))),
            nn.Sigmoid()
        )

    def forward(self, real_imgs, state):
        # Only the batch size of real_imgs is used; z is sampled afresh on every call
        z = torch.randn(real_imgs.shape[0], latent_dim, device=state[torchbearer.DEVICE])
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img, state):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity
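Because both forward methods take `state`, the classes above are awkward to call on their own. A quick way to sanity-check the architecture is to rebuild the same layer stacks as bare `nn.Sequential` modules (a sketch that duplicates the definitions rather than importing them) and confirm the tensor shapes and parameter counts.

```python
import numpy as np
import torch
import torch.nn as nn

latent_dim, img_shape = 100, (1, 28, 28)

def block(in_feat, out_feat, normalize=True):
    layers = [nn.Linear(in_feat, out_feat)]
    if normalize:
        layers.append(nn.BatchNorm1d(out_feat, 0.8))  # same call as above (eps=0.8)
    layers.append(nn.LeakyReLU(0.2, inplace=True))
    return layers

# Same layer stacks as Generator.model and Discriminator.model.
gen_net = nn.Sequential(
    *block(latent_dim, 128, normalize=False),
    *block(128, 256),
    *block(256, 512),
    nn.Linear(512, int(np.prod(img_shape))),
    nn.Sigmoid(),
)
disc_net = nn.Sequential(
    nn.Linear(int(np.prod(img_shape)), 512),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Linear(512, 256),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Linear(256, 1),
    nn.Sigmoid(),
)

z = torch.randn(4, latent_dim)
imgs = gen_net(z).view(4, *img_shape)  # 4 images of shape (1, 28, 28), values in (0, 1)
scores = disc_net(imgs.view(4, -1))    # one "realness" probability per image
print(imgs.shape, scores.shape)

def n_params(net):
    return sum(p.numel() for p in net.parameters())

print(n_params(gen_net), n_params(disc_net))  # 581264 533505
```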

Now let's create the models and the optimisers.

In [5]:
generator = Generator()
discriminator = Discriminator()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

Loss

GANs usually require two different losses, one for the generator and one for the discriminator. We define these as functions of state so that we can access the discriminator model for the additional forward passes required. We will see later how we get a torchbearer trial to use these losses.

In [6]:
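def gen_crit(state):
    # The generator is rewarded when the discriminator scores its images as real
    loss = adversarial_loss(state[DISC_MODEL](state[torchbearer.Y_PRED], state), valid)
    state[G_LOSS] = loss
    return loss


def disc_crit(state):
    # Real images (the batch inputs) should be scored as real...
    real_loss = adversarial_loss(state[DISC_MODEL](state[torchbearer.X], state), valid)
    # ...and generated images as fake; detach() keeps this loss out of the generator
    fake_loss = adversarial_loss(state[DISC_MODEL](state[torchbearer.Y_PRED].detach(), state), fake)
    loss = (real_loss + fake_loss) / 2
    state[D_LOSS] = loss
    return loss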
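It helps to know what these losses look like at a few reference points before reading the metrics. `gen_crit` scores generated images against `valid`, so the generator minimises -log D(G(z)) (the common "non-saturating" generator loss). If the discriminator is completely fooled and outputs 0.5 everywhere, both losses equal ln 2 ≈ 0.693; a `d_loss` well below that suggests the discriminator is currently ahead. The sketch applies the same formulas as `gen_crit` and `disc_crit` to hand-picked discriminator outputs rather than to real models.

```python
import math
import torch

bce = torch.nn.BCELoss()
n = 8
valid, fake = torch.ones(n, 1), torch.zeros(n, 1)

def gan_losses(d_real, d_fake):
    """The gen_crit / disc_crit formulas, applied to given discriminator outputs."""
    g_loss = bce(d_fake, valid)
    d_loss = (bce(d_real, valid) + bce(d_fake, fake)) / 2
    return g_loss.item(), d_loss.item()

# A fooled discriminator: D(x) = D(G(z)) = 0.5 for every image.
g_eq, d_eq = gan_losses(torch.full((n, 1), 0.5), torch.full((n, 1), 0.5))
# A confident discriminator: D(x) = 0.9 and D(G(z)) = 0.1.
g_conf, d_conf = gan_losses(torch.full((n, 1), 0.9), torch.full((n, 1), 0.1))

print(round(g_eq, 4), round(d_eq, 4), round(math.log(2), 4))  # 0.6931 0.6931 0.6931
print(round(g_conf, 4), round(d_conf, 4))                     # 2.3026 0.1054
```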

Metrics

We would like to follow the discriminator and generator losses during training; note that we added these to state in the criterion functions above. In torchbearer, state keys are also metrics, so we can take means and running means of them and tell torchbearer to output them as metrics. We will add this metric list to the trial when we create it.

In [7]:
from torchbearer.metrics import mean, running_mean
metrics = ['loss', mean(running_mean(D_LOSS)), mean(running_mean(G_LOSS))]
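The two reductions answer different questions: a mean averages every batch in the epoch, while a running mean only covers recent batches, which is roughly why the final `running_d_loss` and `d_loss` values reported after training differ. The plain-Python sketch below illustrates the idea only; it is not torchbearer's implementation, and the window of 3 batches and the loss values are made up.

```python
from collections import deque

def epoch_mean(values):
    """Mean over every batch seen so far."""
    return sum(values) / len(values)

def running_means(values, window=3):
    """Mean over only the last `window` batches (window size is arbitrary here)."""
    recent = deque(maxlen=window)
    out = []
    for v in values:
        recent.append(v)
        out.append(sum(recent) / len(recent))
    return out

# Hypothetical per-batch discriminator losses that fall as training proceeds.
d_losses = [0.9, 0.7, 0.5, 0.4, 0.3, 0.3]
print(round(epoch_mean(d_losses), 4))         # 0.5167
print(round(running_means(d_losses)[-1], 4))  # 0.3333
```

Early, high losses keep pulling the epoch mean up, while the running mean tracks where training currently is.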

Closures

The training loop of a GAN is a bit different from a standard model training loop. GANs require separate forward and backward passes for the generator and discriminator. To achieve this in torchbearer we can write a new closure. Since the individual training loops for the generator and discriminator are the same as a standard training loop, we can build each of them with base_closure. base_closure takes state keys for the required objects (data, model, optimiser, etc.) and returns a standard closure consisting of:

  1. Zero gradients
  2. Forward pass
  3. Loss calculation
  4. Backward pass

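We create a separate closure for the generator and for the discriminator. Each closure reads its own state keys for the model, criterion and optimiser, but both write their loss to the shared torchbearer.LOSS key, which keeps things simple. Because the discriminator's closure runs second, the reported loss metric is the discriminator loss.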

In [8]:
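from torchbearer.bases import base_closure
# Arguments are state keys for: input, model, prediction, target, criterion, loss, optimiser
closure_gen = base_closure(torchbearer.X, torchbearer.MODEL, torchbearer.Y_PRED, torchbearer.Y_TRUE, torchbearer.CRITERION, torchbearer.LOSS, GEN_OPT)
# The discriminator's input is the generated images; its prediction key is None since disc_crit runs the discriminator itself
closure_disc = base_closure(torchbearer.Y_PRED, DISC_MODEL, None, DISC_IMGS, DISC_CRIT, torchbearer.LOSS, DISC_OPT)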
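The four closure steps map onto plain PyTorch as follows. This is a hypothetical sketch (`make_closure` is not a torchbearer function) of the kind of closure base_closure builds from state keys. Like the closures above, it does not step the optimiser; that is left to the caller, just as the main closure steps GEN_OPT and DISC_OPT itself.

```python
import torch
import torch.nn as nn

def make_closure(model, optimiser, criterion, x, y):
    """Hypothetical plain-PyTorch analogue of the four closure steps."""
    def closure():
        optimiser.zero_grad()        # 1. zero gradients
        y_pred = model(x)            # 2. forward pass
        loss = criterion(y_pred, y)  # 3. loss calculation
        loss.backward()              # 4. backward pass
        return loss
    return closure

torch.manual_seed(0)
model = nn.Linear(3, 1)
optimiser = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(32, 3), torch.randn(32, 1)
closure = make_closure(model, optimiser, nn.MSELoss(), x, y)

first = closure().item()
for _ in range(20):
    closure()
    optimiser.step()  # stepping happens outside the closure
last = closure().item()
print(first > last)  # True
```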

We now create a main closure (a simple function of state) that runs both of these and steps the optimisers.

In [9]:
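def closure(state):
    closure_gen(state)        # generator: zero grads, forward, loss, backward
    state[GEN_OPT].step()     # update the generator
    closure_disc(state)       # discriminator: zero grads, forward, loss, backward
    state[DISC_OPT].step()    # update the discriminator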
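The order of these calls matters. The generator's backward pass flows through the discriminator, so it also leaves gradients on the discriminator's parameters; the discriminator's closure zeroes them before computing its own loss. In the other direction, the `detach()` in disc_crit stops the discriminator loss from reaching the generator. The sketch reproduces one such step in plain PyTorch with tiny hypothetical networks `G` and `D` (not the tutorial's models) and checks both effects.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Tiny hypothetical stand-ins for the generator and discriminator.
G = nn.Sequential(nn.Linear(4, 8), nn.Sigmoid())
D = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid())
opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
bce = nn.BCELoss()
real = torch.rand(16, 8)
valid, fake = torch.ones(16, 1), torch.zeros(16, 1)

# Generator step: roughly closure_gen followed by GEN_OPT.step().
opt_G.zero_grad()
gen_imgs = G(torch.randn(16, 4))
g_loss = bce(D(gen_imgs), valid)
g_loss.backward()
opt_G.step()
# The generator's backward pass also left gradients on D's parameters.
d_grad_norm_after_g = D[0].weight.grad.norm().item()

# Discriminator step: roughly closure_disc followed by DISC_OPT.step().
opt_D.zero_grad()  # discard the gradients the generator step left on D
g_grad_before = G[0].weight.grad.clone()
d_loss = (bce(D(real), valid) + bce(D(gen_imgs.detach()), fake)) / 2
d_loss.backward()
opt_D.step()

# detach() kept the discriminator loss from changing G's gradients.
print(d_grad_norm_after_g > 0, torch.equal(G[0].weight.grad, g_grad_before))  # True True
```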

Visualising

We borrow the image-saving method from PyTorch_GAN and put it in a callback that runs on_step_training, saving a 5x5 grid of samples every sample_interval batches. We generate from the same inputs each time to get a better visualisation.

In [10]:
from torchvision.utils import save_image
from torchbearer import callbacks
import os
os.makedirs('images', exist_ok=True)

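@callbacks.on_step_training
@callbacks.only_if(lambda state: state[torchbearer.BATCH] % sample_interval == 0)
def saver_callback(state):
    # Generator.forward samples new noise on every call, so feed the fixed noise
    # `batch` straight into the inner network to get comparable images each time
    with torch.no_grad():
        samples = state[torchbearer.MODEL].model(batch).view(-1, *img_shape)
    save_image(samples, 'images/%d.png' % state[torchbearer.BATCH], nrow=5, normalize=True)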

Training

We now create the torchbearer trial on the GPU (or the CPU if no GPU is available) in the standard way. Note that when torchbearer is passed a None optimiser, it creates a mock optimiser that just runs the closure. Since we are using the standard torchbearer state keys for the generator model and criterion, we can pass them in here.

In [11]:
trial = torchbearer.Trial(generator, None, criterion=gen_crit, metrics=metrics, callbacks=[saver_callback])
trial.with_train_generator(dataloader, steps=train_steps)
_ = trial.to(device)
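The idea behind the mock optimiser can be shown with a small stand-in. The class below is hypothetical (it is not torchbearer's mock optimiser) and assumes, as PyTorch optimisers allow, that the trial hands its closure to `optimiser.step(closure)`. The stand-in updates no parameters itself, so all the real work, including stepping GEN_OPT and DISC_OPT, must happen inside the closure.

```python
class ClosureOnlyOptimiser:
    """Hypothetical stand-in: updates nothing itself, just runs the closure."""

    def zero_grad(self):
        pass

    def step(self, closure=None):
        return closure() if closure is not None else None


calls = []

def gan_closure():
    # The tutorial's closure would run both sub-closures and step both
    # optimisers here; this stand-in only records that it was called.
    calls.append('closure ran')
    return 0.0

opt = ClosureOnlyOptimiser()
for _ in range(3):
    opt.step(gan_closure)  # assumed trial-side call: optimiser.step(closure)
print(len(calls))  # 3
```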

We now update state with the keys required by the closures (the discriminator model, both optimisers and the discriminator criterion) and add the new closure to the trial. Note that torchbearer doesn't know the discriminator is a model here, so we have to send it to the device ourselves.

In [12]:
new_keys = {DISC_MODEL: discriminator.to(device), DISC_OPT: optimizer_D, GEN_OPT: optimizer_G, DISC_CRIT: disc_crit}
trial.state.update(new_keys)
trial.with_closure(closure)
trial.run(epochs=1)

Out[12]:
[((50000, None),
  {'running_loss': 0.32024282217025757,
   'running_d_loss': 0.32024282217025757,
   'running_g_loss': 2.202518939971924,
   'loss': 0.4093643128871918,
   'd_loss': 0.4093643128871918,
   'g_loss': 1.8098881244659424})]

Here is a GIF we made of the results.

[GIF: gangif]