# simple generative adversarial network
# this version uses the celebrity faces dataset celeba (aligned)
# mount Google Drive to access the HDF5 dataset file (Colab-only dependency)
from google.colab import drive
drive.mount('./my_data')
# conventional PyTorch imports
import torch
import torch.nn as nn
#import torch.nn.functional as F
from torch.utils.data import Dataset
import random
import pandas
import numpy
import matplotlib.pyplot as plt
import h5py
# GPU: make new float tensors live on the GPU by default when one is available,
# and keep a `device` handle for explicit placement
if torch.cuda.is_available():
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
    print("using cuda:", torch.cuda.get_device_name(0))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# notebook output:
# using cuda: Tesla T4
# device(type='cuda')
# function to generate random data for the GAN
# size parameter is the length of the returned 1-d tensor
def generate_random(size):
    """Return a 1-d tensor of `size` samples drawn from N(0, 1)."""
    # randn (normal) rather than rand (uniform): a normal latent distribution
    # generally trains GANs better
    return torch.randn(size)

# notebook check:
# generate_random(5)
# tensor([ 1.2414, 0.1590, -2.7007, 0.4053, -1.2089])
# crop centre of image, based on https://stackoverflow.com/questions/43463523/center-crop-a-numpy-array
def crop_center(img, cropx, cropy):
    """Return the central (cropy, cropx) window of an (H, W, C) image array.

    img: numpy array of shape (height, width, channels).
    cropx, cropy: width and height of the crop window.
    """
    y, x, c = img.shape
    startx = x // 2 - cropx // 2
    starty = y // 2 - cropy // 2
    return img[starty:starty + cropy, startx:startx + cropx, :]
# crops out 128x128 centre of image
class CelebADataSet(torch.utils.data.Dataset):
    """Dataset of aligned CelebA face images stored in an HDF5 file.

    Each item is the 128x128 centre crop of one image, rescaled from
    [0, 255] to (-1, +1) and returned as a float32 tensor (allocated on
    the default device, i.e. the GPU when the cuda default is set).
    """

    def __init__(self, file):
        self.fh = h5py.File(file, 'r')
        self.dataset = self.fh['img_align_celeba']
        # cache the key list once: rebuilding list(keys()) per access is O(n)
        self.keys = list(self.dataset.keys())

    def __len__(self):
        return len(self.dataset)

    def _load_cropped(self, index):
        # load the raw image and crop its 128x128 centre (shared helper)
        img = numpy.array(self.dataset[self.keys[index]])
        return crop_center(img, 128, 128)

    def __getitem__(self, index):
        img = self._load_cropped(index)
        # rescale pixel values to (-1, +1) to match the generator's tanh output
        img = (img / 127.5) - 1.0
        # torch.tensor honours the default tensor type (cuda when set),
        # unlike the original torch.cuda.FloatTensor which required a GPU
        return torch.tensor(img, dtype=torch.float32)

    def plot_image(self, index):
        plt.imshow(self._load_cropped(index), interpolation='nearest')

    def plot_images(self):
        # plot a 3 column, 2 row array of sample images from this dataset
        f, axarr = plt.subplots(2, 3, figsize=(16, 8))
        for i in range(2):
            for j in range(3):
                # randrange fixes the off-by-one of randint(0, len(...)),
                # which could return len() and raise IndexError
                idx = random.randrange(len(self))
                axarr[i, j].imshow(self._load_cropped(idx), interpolation='nearest')
# instantiate the Dataset subclass; loads the HDF5 file of aligned face images
# (path assumes Google Drive was mounted at ./my_data above)
celeba_dataset = CelebADataSet('my_data/My Drive/Colab Notebooks/gan/celeba_dataset/celeba_aligned_small.h5py')
# images are of size cropped to (128,128); show a sample grid as a sanity check
celeba_dataset.plot_images()
# reshape module so a Tensor.view can sit inside nn.Sequential
# (from https://github.com/pytorch/vision/issues/720)
class View(nn.Module):
    """Reshape incoming tensors to a fixed target shape."""

    def __init__(self, shape):
        super().__init__()
        # target shape, stored as given (may include -1 wildcards)
        self.shape = shape

    def forward(self, x):
        target_shape = self.shape
        return x.view(*target_shape)
# custom weights initialisation, applied via .apply() to netG and netD
def weights_init(m):
    """DCGAN-style init: conv weights ~ N(0, 0.02); batch-norm weights
    ~ N(1, 0.02) with zero bias. Other layer types are left untouched."""
    layer_name = m.__class__.__name__
    if 'Conv' in layer_name:
        # also matches ConvTranspose layers
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in layer_name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
# discriminator class
class Discriminator(nn.Module):
    """Convolutional discriminator.

    Maps an (N, 3, 128, 128) image batch to an (N, 1) probability in (0, 1)
    that each image is real.
    """

    def __init__(self):
        # initialise parent pytorch class
        super().__init__()
        # spatial sizes: 128 -> 61 -> 27 -> 10, so the final map is 3x10x10
        self.model = nn.Sequential(
            nn.Conv2d(3, 256, kernel_size=8, stride=2, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(0.3),
            nn.Conv2d(256, 256, kernel_size=8, stride=2, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(0.3),
            nn.Conv2d(256, 3, kernel_size=8, stride=2, bias=False),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(0.3),
            # -1 generalises over batch size (original hard-coded batch 1)
            View((-1, 3 * 10 * 10)),
            nn.Linear(3 * 10 * 10, 1),
            nn.Sigmoid()
        )
        # binary cross-entropy pairs with the final sigmoid
        self.error_function = torch.nn.BCELoss()
        # Adam with the standard DCGAN hyper-parameters
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0002, betas=(0.5, 0.999))
        # update counter and per-step loss history for plot_progress()
        self.counter = 0
        self.progress = []

    def forward(self, inputs):
        # simply run the model
        return self.model(inputs)

    def train(self, inputs, targets):
        # NOTE(review): shadows nn.Module.train(mode); name kept because the
        # training loops elsewhere in this file call D.train(inputs, targets)
        outputs = self.forward(inputs)
        loss = self.error_function(outputs, targets)
        # record the loss of every update step
        self.counter += 1
        self.progress.append(loss.item())
        # zero gradients, perform a backward pass, and update the weights
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def save(self, path):
        torch.save(self.state_dict(), path)

    def load(self, path):
        self.load_state_dict(torch.load(path))

    def plot_progress(self):
        # scatter-style plot of the recorded per-step losses
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0, 1.0), figsize=(16, 8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))
# test Discriminator (error messages are useful to get convolution sizes right)
D = Discriminator()
D.to(device)
# initialise weights
D.apply(weights_init)
# untrained output on random noise should be some probability in (0, 1)
D.forward(generate_random(3*128*128).view(1, 3, 128, 128)).item()
# notebook output: 0.7492928504943848
# %%time  (IPython cell magic; invalid in a plain Python script, so commented)
# create Discriminator and pre-train it: real images -> 1.0, random noise -> 0.0
D = Discriminator()
D.to(device)
# initialise weights
D.apply(weights_init)
# train Discriminator
epochs = 1
for i in range(epochs):
    print('training epoch', i + 1, "of", epochs)
    count = 1
    for image_data_tensor in celeba_dataset:
        # train discriminator on real data (HWC -> CHW, add batch dimension);
        # torch.tensor(..., device=device) also works on CPU-only machines,
        # unlike the original torch.cuda.FloatTensor
        D.train(image_data_tensor.permute(2, 0, 1).view(1, 3, 128, 128),
                torch.tensor([[1.0]], device=device))
        # train discriminator on false (random) data
        D.train(generate_random(3 * 128 * 128).view(1, 3, 128, 128),
                torch.tensor([[0.0]], device=device))
        if count % 1000 == 0:
            print("count = ", count)
        count += 1

# notebook output:
# training epoch 1 of 1 count = 1000 ... count = 19000
# CPU times: user 8min 13s, sys: 1min 33s, total: 9min 47s Wall time: 9min 51s
# plot discriminator error
D.plot_progress()

# manually check D can indeed discriminate between real and fake data:
# random noise should score near 0, real images near 1
for _ in range(4):
    print(D.forward(generate_random(3 * 128 * 128).view(1, 3, 128, 128)).item())
for _ in range(4):
    # randrange fixes the off-by-one of randint(0, len(...)), which could
    # return len(celeba_dataset) and raise IndexError
    idx = random.randrange(len(celeba_dataset))
    print(D.forward(celeba_dataset[idx].permute(2, 0, 1).view(1, 3, 128, 128)).item())

# notebook output:
# 1.908766989799915e-06 5.13494824850284e-16 9.621245315206972e-14 7.653628286696801e-10
# 1.0 1.0 1.0 1.0
# generator class
class Generator(nn.Module):
    """Transposed-convolution generator.

    Maps a 100-element latent vector to a (1, 3, 128, 128) image with
    values in (-1, +1), matching the dataset's rescaling.
    """

    def __init__(self):
        # initialise parent pytorch class
        super().__init__()
        self.model = nn.Sequential(
            # project the 1-d latent vector up to a 3x28x28 feature map
            nn.Linear(100, 3 * 28 * 28),
            nn.LeakyReLU(0.2),
            # reshape to 2d
            View((1, 3, 28, 28)),
            # spatial sizes: 28 -> 62 -> 128
            nn.ConvTranspose2d(3, 256, kernel_size=8, stride=2, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 3, kernel_size=8, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(3),
            nn.LeakyReLU(0.2),
            View((1, 3, 128, 128)),
            # tanh bounds the output to (-1, +1), matching the real images
            nn.Tanh()
        )
        # kept for symmetry with Discriminator; training actually uses D's loss
        self.error_function = torch.nn.BCELoss()
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0002, betas=(0.5, 0.999))
        # update counter and per-step loss history for plot_progress()
        self.counter = 0
        self.progress = []

    def forward(self, inputs):
        # simply run the model
        return self.model(inputs)

    def train(self, D, inputs, targets):
        # NOTE(review): shadows nn.Module.train(mode); name kept because the
        # training loop elsewhere in this file calls G.train(D, inputs, targets)
        g_output = self.forward(inputs)
        # pass the generated image on to the Discriminator
        d_output = D.forward(g_output)
        # loss comes from D's error function, but only G's weights are updated
        loss = D.error_function(d_output, targets)
        self.counter += 1
        self.progress.append(loss.item())
        if self.counter % 1000 == 0:
            print("counter = ", self.counter)
        # zero gradients, perform a backward pass, and update the weights
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def save(self, path):
        torch.save(self.state_dict(), path)

    def load(self, path):
        self.load_state_dict(torch.load(path))

    def plot_images(self):
        # plot a 3 column, 2 row array of generated samples
        f, axarr = plt.subplots(2, 3, figsize=(16, 8))
        for i in range(2):
            for j in range(3):
                # use self, not the global G (bug fix: the original broke if
                # this method was called on any instance not named G)
                img = self.forward(generate_random(100)).view(3, 128, 128).permute(1, 2, 0).detach().cpu().numpy()
                # rescale from (-1, +1) to (0, 1) for imshow
                img = (img + 1.0) / 2.0
                axarr[i, j].imshow(img, interpolation='none')

    def plot_progress(self):
        # scatter-style plot of the recorded per-step losses
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0, 5.0), figsize=(16, 8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0))
## scratch: sanity-check an untrained Generator's output shape and one sample
G = Generator()
G.to(device)
# initialise weights
G.apply(weights_init)
print(G(generate_random(100)).view(3, 128, 128).permute(1, 2, 0).shape)
plt.figure(figsize=(16, 8))
img = G.forward(generate_random(100)).view(3, 128, 128).permute(1, 2, 0).detach().cpu().numpy()
# rescale from (-1, +1) to (0, 1) for imshow
img = (img + 1.0) / 2.0
plt.imshow(img, interpolation='none')
# notebook output:
# torch.Size([128, 128, 3])
# <matplotlib.image.AxesImage at 0x7efb82061358>
# create a fresh Discriminator and Generator for combined adversarial training
D = Discriminator()
D.to(device)
# initialise weights (DCGAN-style, via weights_init)
D.apply(weights_init)
G = Generator()
G.to(device)
# initialise weights (DCGAN-style, via weights_init)
G.apply(weights_init)
# free up GPU memory left over from the earlier discriminator-only run
torch.cuda.empty_cache()
# %%time  (IPython cell magic; invalid in a plain Python script, so commented)
# adversarial training: alternate Discriminator and Generator updates
epochs = 1
for i in range(epochs):
    print('training epoch', i + 1, "of", epochs)
    for image_data_tensor in celeba_dataset:
        # soft real label (0.9); occasionally (5%) flip the labels so the
        # discriminator doesn't saturate
        real = 0.9
        fake = 0.0
        if random.random() < 0.05:
            real = 0.0
            fake = 0.9
        # jitter applied to the real label (label smoothing)
        d = 0.1
        # train discriminator on real data (HWC -> CHW, add batch dimension);
        # torch.tensor(..., device=device) also works on CPU-only machines
        D.train(image_data_tensor.permute(2, 0, 1).view(1, 3, 128, 128),
                torch.tensor([[real + random.uniform(-d, d)]], device=device))
        # train discriminator on generated data;
        # detach() so only D is updated, not G
        # NOTE: label softening doesn't apply to 0 labels
        D.train(G.forward(generate_random(100)).detach(),
                torch.tensor([[fake]], device=device))
        # train generator to fool the discriminator (target 1.0)
        G.train(D, generate_random(100), torch.tensor([[1.0]], device=device))

# notebook output:
# training epoch 1 of 1 counter = 60000 ... counter = 79000
# CPU times: user 10min 59s, sys: 2min 47s, total: 13min 47s Wall time: 13min 54s
# plot discriminator loss history accumulated during training
D.plot_progress()
# plot generator loss history accumulated during training
G.plot_progress()
# show a 2x3 grid of images sampled from the trained generator
G.plot_images()
## expts: drive the generator with a one-hot latent vector instead of noise
seed = torch.zeros(100)
# set a single random component to 1 (one-hot seed)
seed[random.randint(0, 99)] = 1
plt.figure(figsize=(16, 8))
img = G.forward(seed).view(3, 128, 128).permute(1, 2, 0).detach().cpu().numpy()
# rescale from (-1, +1) to (0, 1) for imshow
img = (img + 1.0) / 2.0
plt.imshow(img, interpolation='none')
# notebook output:
# <matplotlib.image.AxesImage at 0x7efb7fe5e3c8>