Fitting a toy distribution with a GAN

Implementations of the vanilla GAN [1], the least-squares GAN [2], the Wasserstein GAN [3] with gradient penalty (WGAN-GP) [4], and the hinge-loss GAN [5].

[1] Goodfellow, Ian, et al. "Generative adversarial nets." 
    Advances in neural information processing systems. 2014.
[2] Mao, Xudong, et al. "Least squares generative adversarial networks." 
    Proceedings of the IEEE international conference on computer vision. 2017.
[3] Arjovsky, Martin, Soumith Chintala, and Léon Bottou. "Wasserstein 
    generative adversarial networks." Proceedings of the 34th International 
    Conference on Machine Learning-Volume 70. 2017.
[4] Gulrajani, Ishaan, et al. "Improved training of Wasserstein GANs." 
    Advances in neural information processing systems. 2017.
[5] Lim, Jae Hyun, and Jong Chul Ye. "Geometric GAN."
    arXiv preprint arXiv:1705.02894 (2017).
[6] Zhang, Han, et al. "Self-attention generative adversarial networks." 
    International Conference on Machine Learning. PMLR, 2019.
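
For quick reference, the objectives implemented below are summarized here, with $D$ denoting the raw (pre-activation) discriminator output and the non-saturating generator loss used for the vanilla GAN:

vanilla:  $L_D = -\mathbb{E}_x[\log\sigma(D(x))] - \mathbb{E}_z[\log(1-\sigma(D(G(z))))]$, $\quad L_G = -\mathbb{E}_z[\log\sigma(D(G(z)))]$
lsgan:    $L_D = \mathbb{E}_x[(D(x)-1)^2] + \mathbb{E}_z[D(G(z))^2]$, $\quad L_G = \tfrac{1}{2}\,\mathbb{E}_z[(D(G(z))-1)^2]$
wgan-gp:  $L_D = \mathbb{E}_z[D(G(z))] - \mathbb{E}_x[D(x)] + \lambda\,\mathbb{E}_{\hat{x}}[(\lVert\nabla_{\hat{x}} D(\hat{x})\rVert_2 - 1)^2]$, $\quad L_G = -\mathbb{E}_z[D(G(z))]$
hinge:    $L_D = \mathbb{E}_x[\max(0,\,1-D(x))] + \mathbb{E}_z[\max(0,\,1+D(G(z)))]$, $\quad L_G = -\mathbb{E}_z[D(G(z))]$

(The training code below additionally multiplies $L_D$ by $\tfrac{1}{2}$ for the non-Wasserstein variants.)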

Set up the notebook

In [1]:
from typing import *

from functools import partial
from glob import glob
import math
import os
import random
import sys

gpu_id = 0
os.environ["CUDA_VISIBLE_DEVICES"] = f'{gpu_id}'

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
sns.set_style('white')

import torch
from torch import nn
import torch.distributions as D
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

Report versions

In [2]:
print('numpy version: {}'.format(np.__version__))
from matplotlib import __version__ as mplver
print('matplotlib version: {}'.format(mplver))
print(f'pytorch version: {torch.__version__}')
numpy version: 1.19.1
matplotlib version: 3.3.1
pytorch version: 1.6.0
In [3]:
pv = sys.version_info
print('python version: {}.{}.{}'.format(pv.major, pv.minor, pv.micro))
python version: 3.8.5

Check GPU(s)

In [4]:
!nvidia-smi | head -n 4
Sun Sep 20 14:02:12 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.40       Driver Version: 430.40       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
In [5]:
assert torch.cuda.is_available()
device = torch.device('cuda')
torch.backends.cudnn.benchmark = True

Set seeds for better reproducibility. See the PyTorch reproducibility notes before using multiprocessing (e.g. DataLoader workers).

In [6]:
seed = 9
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

Create dataset

In [7]:
loc = torch.tensor([0.,0.])
scale = torch.tensor([1.,1.])
base_dist = D.Normal(loc, scale)
In [8]:
mix_comp = D.Bernoulli(torch.tensor([0.5]))
g1 = D.MultivariateNormal(torch.tensor([-5.,-3.]), 
                          torch.tensor([[2., 0.5],[0.5, 3.]]))
g2 = D.MultivariateNormal(torch.tensor([ 4., 4.]),
                          torch.tensor([[1.,-0.5],[-0.5,2.]]))

def p_data_mix(n_samp:int):
    mix_comps = mix_comp.sample((n_samp,))
    mc0 = (mix_comps < 1)[:,0]
    mc1 = (mix_comps > 0)[:,0]
    samples = torch.zeros(n_samp, 2)
    samples[mc0,:] = g1.sample((mc0.sum(),))
    samples[mc1,:] = g2.sample((mc1.sum(),))
    return samples

def p_data_single(n_samp:int):
    return g2.sample((n_samp,))

sample_p_data = p_data_mix
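
As an aside, the same 50/50 mixture could be built directly with torch.distributions.MixtureSameFamily; a minimal sketch (not run here, and requiring a PyTorch release newer than the 1.6.0 used in this notebook):

# Sketch: equivalent two-component Gaussian mixture via MixtureSameFamily
mix = D.Categorical(torch.tensor([0.5, 0.5]))
comps = D.MultivariateNormal(
    loc=torch.tensor([[-5., -3.], [4., 4.]]),
    covariance_matrix=torch.stack([torch.tensor([[2., 0.5], [0.5, 3.]]),
                                   torch.tensor([[1., -0.5], [-0.5, 2.]])]))
p_data_alt = D.MixtureSameFamily(mix, comps)
# p_data_alt.sample((n,)) has shape (n, 2), matching p_data_mix(n)
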
In [9]:
n_samples = 10000
x_real = sample_p_data(n_samples)
xr, yr = x_real[:,0], x_real[:,1]
In [10]:
g = sns.jointplot(x=xr, y=yr, kind="hex");
g.ax_marg_x.set_title('True distribution');
plt.savefig('true_dist.svg')
In [11]:
xlim = g.ax_joint.get_xlim()
ylim = g.ax_joint.get_ylim()

Define a generator and a discriminator

In [12]:
def G_layer_BN(in_c:int, out_c:int):
    return nn.Sequential(
            nn.Linear(in_c, out_c, bias=False),
            nn.BatchNorm1d(out_c),
            nn.LeakyReLU(inplace=True))

def G_layer_SN(in_c:int, out_c:int):
    return nn.Sequential(
            nn.utils.spectral_norm(nn.Linear(in_c, out_c)),
            nn.LeakyReLU(inplace=True))

class Generator(nn.Sequential):
    _layer = staticmethod(G_layer_SN)
    def __init__(self, in_dim:int, out_dim:int, n_layers:int=5, 
                 hidden_dim:int=128, dropout_rate:float=0.):
        super().__init__()
        self.add_module('h1', self._layer(in_dim, hidden_dim))
        for i in range(2, n_layers):
            self.add_module(f'h{i}', self._layer(hidden_dim, hidden_dim))
        if dropout_rate > 0.:
            self.add_module('dropout', nn.Dropout(dropout_rate))
        self.add_module(f'h{n_layers}', nn.Linear(hidden_dim, out_dim))

def D_layer_BN(in_c:int, out_c:int):
    return nn.Sequential(
            nn.Linear(in_c, out_c, bias=False),
            nn.BatchNorm1d(out_c),
            nn.LeakyReLU(0.2, inplace=True))

def D_layer_SN(in_c:int, out_c:int):
    return nn.Sequential(
            nn.utils.spectral_norm(nn.Linear(in_c, out_c)),
            nn.LeakyReLU(0.1, inplace=True))

class Discriminator(Generator):
    _layer = staticmethod(D_layer_SN)
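
The _layer class attribute makes it easy to swap normalization schemes; for example, batch-norm variants of either network could be defined as follows (a sketch, not used below):

class GeneratorBN(Generator):
    _layer = staticmethod(G_layer_BN)

class DiscriminatorBN(Discriminator):
    _layer = staticmethod(D_layer_BN)
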
In [13]:
hidden_dim = 128
n_layers = 4
dropout_rate = 0.
g_args = (2, 2, n_layers, hidden_dim, dropout_rate)
d_args = (2, 1, n_layers, hidden_dim, dropout_rate)
generator = Generator(*g_args).to(device)
discriminator = Discriminator(*d_args).to(device)
In [14]:
D_final_activation = None
if D_final_activation == 'tanh':
    discriminator.add_module('tanh', nn.Tanh())
elif D_final_activation == 'sigmoid':
    discriminator.add_module('sigmoid', nn.Sigmoid())
In [15]:
def num_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
In [16]:
print(f'Number of trainable parameters in generator: {num_params(generator)}')
print(f'Number of trainable parameters in discriminator: {num_params(discriminator)}')
Number of trainable parameters in generator: 33666
Number of trainable parameters in discriminator: 33537
In [17]:
def weights_init(m):
    name = m.__class__.__name__
    if 'Linear' in name or 'BatchNorm' in name:
        nn.init.normal_(m.weight.data, 0., 0.02)
        if hasattr(m, 'bias'):
            if m.bias is not None:
                nn.init.constant_(m.bias.data, 0.)

generator.apply(weights_init);
discriminator.apply(weights_init);
In [18]:
def gradient_penalty(y, x):
    weight = torch.ones_like(y)
    grad = torch.autograd.grad(outputs=y,
                               inputs=x,
                               grad_outputs=weight,
                               retain_graph=True,
                               create_graph=True,
                               only_inputs=True)[0]
    return torch.mean((grad.norm(dim=1) - 1) ** 2)
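
A quick sanity check of gradient_penalty (a sketch, not part of the training run): for a linear map whose gradient has unit norm, the penalty should be approximately zero.

_x = torch.randn(8, 2, requires_grad=True)
_y = (_x * torch.tensor([0.6, 0.8])).sum(dim=1, keepdim=True)  # gradient [0.6, 0.8] has norm 1
print(gradient_penalty(_y, _x))  # prints a value close to 0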

Train the generator and discriminator

In [19]:
n_epochs = 5000
n_samples_per_epoch = n_samples
print_loss_rate = 500
batch_size = n_samples_per_epoch
use_minibatches = batch_size != n_samples_per_epoch
In [20]:
x_real = x_real.to(device)
In [21]:
gan_type = 'hinge'
betas = (0.,0.999)  # parameters from self-attention GAN [6]
G_lr = 1e-4
D_lr = 4e-4
D_steps = 1
use_gp = False
gp_weight = 10.
G_opt = torch.optim.Adam(generator.parameters(), lr=G_lr, betas=betas)
D_opt = torch.optim.Adam(discriminator.parameters(), lr=D_lr, betas=betas)
In [22]:
real_label = 1.
fake_label = 0.
real_labels = real_label * torch.ones((n_samples_per_epoch, 1))
fake_labels = fake_label * torch.ones((n_samples_per_epoch, 1))
real_labels = real_labels.to(device)
fake_labels = fake_labels.to(device)
In [23]:
if use_minibatches:
    MyDataLoader = partial(DataLoader, batch_size=batch_size, shuffle=True)
    x_real_dataset = TensorDataset(x_real)
    x_real_dataloader = MyDataLoader(x_real_dataset)
In [24]:
def reset_grad():
    D_opt.zero_grad()
    G_opt.zero_grad()
In [25]:
def train_discriminator(x_real, z):
    for _ in range(D_steps):
        reset_grad()
        
        # discriminate real samples
        D_real = discriminator(x_real)
        if gan_type == 'vanilla':
            loss_real = F.binary_cross_entropy_with_logits(D_real, real_labels)
            D_x = torch.sigmoid(D_real).mean().item()
        elif gan_type == 'lsgan':
            loss_real = torch.mean((D_real - real_label)**2)
            D_x = D_real.mean().item()
        elif gan_type == 'wgan-gp':
            loss_real = D_real.mean()
            # note: for wgan-gp and hinge, the D(x) / D(G(z)) values logged below
            # are loss terms, not probabilities
            D_x = loss_real.item()
        elif gan_type == 'hinge':
            loss_real = F.relu(1. - D_real).mean()
            D_x = loss_real.item()
        else:
            raise NotImplementedError(f'{gan_type} not implemented.')

        # discriminate fake samples
        with torch.no_grad():
            x_fake = generator(z)
        D_fake = discriminator(x_fake)
        if gan_type == 'vanilla':
            loss_fake = F.binary_cross_entropy_with_logits(D_fake, fake_labels)
            D_G_z_1 = torch.sigmoid(D_fake).mean().item()
        elif gan_type == 'lsgan':
            loss_fake = torch.mean((D_fake - fake_label)**2)
            D_G_z_1 = D_fake.mean().item()
        elif gan_type == 'wgan-gp':
            loss_fake = D_fake.mean()
            D_G_z_1 = loss_fake.item()
        elif gan_type == 'hinge':
            loss_fake = F.relu(1. + D_fake).mean()
            D_G_z_1 = loss_fake.item()
        else:
            raise NotImplementedError(f'{gan_type} not implemented.')
        
        if use_gp or gan_type == 'wgan-gp':
            eps = torch.rand(batch_size,1).to(device)
            x_hat = (eps*x_real + (1.-eps)*x_fake)
            x_hat.requires_grad_(True)
            D_x_hat = discriminator(x_hat)
            gp = gradient_penalty(D_x_hat, x_hat)

        if gan_type != 'wgan-gp':
            D_loss = 0.5 * (loss_fake + loss_real)
            if use_gp:
                D_loss += gp_weight * gp
        else:
            D_loss = loss_fake - loss_real + gp_weight * gp

        D_loss.backward()
        D_opt.step()
    
    return D_loss.item(), D_x, D_G_z_1

def train_generator(z):
    reset_grad()
    x_fake = generator(z)
    D_fake = discriminator(x_fake)
    if gan_type == 'vanilla':
        G_loss = F.binary_cross_entropy_with_logits(D_fake, real_labels)
        D_G_z_2 = torch.sigmoid(D_fake).mean().item()
    elif gan_type == 'lsgan':
        G_loss = 0.5 * torch.mean((D_fake - real_label)**2)  # pull fakes toward the real label
        D_G_z_2 = D_fake.mean().item()
    elif gan_type == 'wgan-gp':
        G_loss = -D_fake.mean()
        D_G_z_2 = G_loss.item()
    elif gan_type == 'hinge':
        G_loss = -D_fake.mean()
        D_G_z_2 = G_loss.item()
    else:
        raise NotImplementedError(f'{gan_type} not implemented.')

    G_loss.backward()
    G_opt.step()
    
    return G_loss.item(), D_G_z_2
In [26]:
for i in range(1, n_epochs+1):
    z_full = base_dist.sample((n_samples_per_epoch,)).to(device)
    if use_minibatches:
        z_dataset = TensorDataset(z_full)
        z_loader = MyDataLoader(z_dataset)
    
    # train discriminator
    generator.eval();
    discriminator.train();
    
    if use_minibatches:
        for x_r, z in zip(x_real_dataloader, z_loader):
            x_r, z = x_r[0], z[0]
            D_loss, D_x, D_G_z_1 = train_discriminator(x_r, z)
    else:
        D_loss, D_x, D_G_z_1 = train_discriminator(x_real, z_full)
            
    # train generator
    generator.train();
    discriminator.eval();
    
    if use_minibatches:
        for z in z_loader:
            z = z[0]
            G_loss, D_G_z_2 = train_generator(z)
    else:
        G_loss, D_G_z_2 = train_generator(z_full)

    if i % print_loss_rate == 0:
        print(f'Epoch {i}: D_loss={D_loss:0.4f}, G_loss={G_loss:0.4f}, '
              f'D(x): {D_x:0.3f}, D(G(z)): {D_G_z_1:0.3f}/{D_G_z_2:0.3f}')
Epoch 500: D_loss=0.3533, G_loss=1.0013, D(x): 0.528, D(G(z)): 0.179/1.001
Epoch 1000: D_loss=0.4411, G_loss=0.9532, D(x): 0.762, D(G(z)): 0.120/0.953
Epoch 1500: D_loss=0.4974, G_loss=0.9421, D(x): 0.969, D(G(z)): 0.026/0.942
Epoch 2000: D_loss=0.4967, G_loss=0.9665, D(x): 0.989, D(G(z)): 0.004/0.967
Epoch 2500: D_loss=0.4981, G_loss=0.9756, D(x): 0.995, D(G(z)): 0.002/0.976
Epoch 3000: D_loss=0.4981, G_loss=0.9996, D(x): 0.991, D(G(z)): 0.006/1.000
Epoch 3500: D_loss=0.4980, G_loss=0.9862, D(x): 0.995, D(G(z)): 0.001/0.986
Epoch 4000: D_loss=0.4985, G_loss=0.9906, D(x): 0.996, D(G(z)): 0.001/0.991
Epoch 4500: D_loss=0.4981, G_loss=1.0054, D(x): 0.990, D(G(z)): 0.006/1.005
Epoch 5000: D_loss=0.4988, G_loss=1.0052, D(x): 0.990, D(G(z)): 0.007/1.005
In [27]:
z = base_dist.sample((n_samples_per_epoch,)).to(device)
generator.eval()
with torch.no_grad():
    x_fake = generator(z).detach().cpu().numpy()
xf, yf = x_fake[:,0], x_fake[:,1]
In [28]:
sns.jointplot(x=xr, y=yr, kind="hex");
g = sns.jointplot(x=xf, y=yf, kind="hex", xlim=xlim, ylim=ylim);
g.ax_marg_x.set_title('Fit distribution');
plt.savefig('fit_dist.svg')
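
As a rough quantitative check (an addition to the plots above; the variable names here are illustrative), the first two moments of the real and generated samples can be compared:

xr_np = x_real.cpu().numpy()
print('real mean:', xr_np.mean(axis=0), ' fake mean:', x_fake.mean(axis=0))
print('real cov:\n', np.cov(xr_np, rowvar=False))
print('fake cov:\n', np.cov(x_fake, rowvar=False))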