# conditional generative adversarial network using conv/deconv layers
# this version uses sythetic images at 64x64, and no dense layers in the networks
# mount drive to access csv files
from google.colab import drive
drive.mount('./data')
from google.colab import files
Drive already mounted at ./data; to attempt to forcibly remount, call drive.mount("./data", force_remount=True).
# conventional PyTorch imports
import torch
import torch.nn as nn
from torch.utils.data import Dataset
# GPU
#torch.cuda.is_available()
#torch.cuda.get_device_name(0)
if torch.cuda.is_available():
torch.set_default_tensor_type(torch.cuda.FloatTensor)
print("using cuda:", torch.cuda.get_device_name(0))
pass
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
using cuda: Tesla T4
device(type='cuda')
import random
import pandas
import numpy
import cv2
import matplotlib.pyplot as plt
import datetime
import time
# function to generate uniform random data
# size parameter is length of tensor
def generate_random(size):
#return torch.rand(size)
return torch.randn((1,size))
generate_random(5)
tensor([[-0.5400, 1.2863, 1.7966, 0.6033, 0.1145]])
# function to generate one-hot random data used as a random target label
# size parameter is length of tensor
def generate_random_target(size):
label_tensor = torch.zeros((1,size))
random_idx = random.randint(0, size-1)
label_tensor[0,random_idx] = 1.0
return label_tensor
generate_random_target(5)
tensor([[0., 0., 0., 0., 1.]])
# emotions
list_of_emotions = ['happy', 'angry', 'sad', 'joy']
### generate images (monochrome)
def generate_real(emotion):
# start with blank image
img = numpy.zeros((64,64,3), numpy.uint8) + 255
####
if (emotion == 'happy'):
# happy colour
r = 95
g = 114
b = 231
for i in range(5):
x = random.randint(10, 64-10)
y = random.randint(10, 64-10)
rad = random.randint(5, 5)
r2 = r + random.randint(-50,50)
g2 = g + random.randint(-50,50)
b2 = b + random.randint(-50,50)
cv2.circle(img, (x,y), rad, (r2, g2, b2), 1, cv2.LINE_AA)
pass
pass
####
if (emotion == 'angry'):
# angry colour
r = 169
g = 20
b = 54
for i in range(5):
x1 = random.randint(10, 64-10-30)
y1 = random.randint(10, 64-10-30)
x2 = x1 + 20
y2 = y1 + 20
rad = random.randint(5, 10)
r2 = r + random.randint(-50,50)
g2 = g + random.randint(-50,50)
b2 = b + random.randint(-50,50)
cv2.line(img, (x1,y1), (x2, y2), (r2, g2, b2), 1, cv2.LINE_AA)
pass
####
if (emotion == 'sad'):
# sad colour
r = 0
g = 200
b = 0
startx = random.randint(5,20)
lengthx = 40
starty = random.randint(10,40)
for midy in range(starty, starty + 30, 10):
r2 = r + random.randint(-50,50)
g2 = g + random.randint(-50,50)
b2 = b + random.randint(-50,50)
xvalues = numpy.linspace(0,lengthx,100)
yvalues = numpy.sin(xvalues/4) * 5
for (x,y) in zip(xvalues, yvalues):
x = startx + int(x)
y = midy - int(y)
cv2.rectangle(img,(x,y),(x+1,y+1),(r2,g2,b2), 1, cv2.LINE_AA)
pass
pass
####
if (emotion == 'joy'):
# joy colour
r = 250
g = 250
b = 0
for i in range(3):
cx = random.randint(10, 50)
cy = random.randint(10, 50)
angles = numpy.linspace(0,2*numpy.pi,10)
xvalues = numpy.cos(angles) * 10
yvalues = numpy.sin(angles) * 10
for (x,y) in zip(xvalues, yvalues):
x = int(x)
y = int(y)
r2 = r + random.randint(-20,20)
g2 = g + random.randint(-20,20)
b2 = b + random.randint(-20,20)
cv2.line(img, (cx, cy), (cx + x, cy - y), (r2, g2, b2), 1, cv2.LINE_AA)
pass
####
# convert to float
img = img.astype(numpy.float)
# add random noise
img = img + numpy.random.normal(0, 1, (64, 64,3))*10
img = numpy.clip(img, 0.0, 255.0)
####
img_tensor = torch.cuda.FloatTensor(img).permute(2,0,1).view(1,3,64,64)
# rescale to (-1,1)
img_tensor = (img_tensor / 127.5) - 1.0
label_tensor = torch.zeros((1,len(list_of_emotions)))
label_tensor[0,list_of_emotions.index(emotion)] = 1.0
return label_tensor, img_tensor
### plot images
def plot_real(emotion):
# plot a 3 column, 2 row array of sample images
f, axarr = plt.subplots(2,3, figsize=(16,8))
for i in range(2):
for j in range(3):
label_tensor, img_tensor = generate_real(emotion)
img = img_tensor.permute(0,2,3,1).view(64,64,3).cpu().numpy()
img = (img + 1.0)/2.0
axarr[i,j].imshow(img, interpolation='none', cmap='gray')
axarr[i,j].set_title(emotion)
pass
pass
plot_real('happy')
plot_real('angry')
plot_real('sad')
plot_real('joy')
# from https://github.com/pytorch/vision/issues/720
class View(nn.Module):
def __init__(self, shape):
super().__init__()
self.shape = shape
def forward(self, x):
return x.view(*self.shape)
# discriminator class
class Discriminator(nn.Module):
def __init__(self):
# initialise parent pytorch class
super().__init__()
# define neural network layer
# input shape is (1, 1, height, width)
# conv layers
self.model = nn.Sequential(
nn.Conv2d(3, 256, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(256),
nn.Conv2d(256, 256, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(256),
nn.Conv2d(256, 256, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(256),
nn.Conv2d(256, 256, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(256),
nn.Conv2d(256, 1, kernel_size=4, stride=2, bias=False),
nn.LeakyReLU(0.2),
# no final batch norm
View((1,1)),
nn.Sigmoid()
)
# create error function
self.error_function = torch.nn.BCELoss()
# create optimiser, using Adam for better gradient descent
self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0002, betas=(0.5, 0.999))
# counter and accumulator for progress
self.counter = 0;
self.progress = []
pass
def forward(self, image_tensor, target_tensor):
# add condition tensor to image tensor
conditioned_image_tensor = image_tensor + target_tensor.repeat(1, 3, 64, 16)
return self.model(conditioned_image_tensor)
def train(self, image_tensor, condition_tensor, targets):
# calculate the output of the network
outputs = self.forward(image_tensor, condition_tensor)
# calculate error
loss = self.error_function(outputs, targets)
# increase counter and accumulate error every 10
self.counter += 1;
if (self.counter % 50 == 0):
self.progress.append(loss.item())
pass
if (self.counter % 5000 == 0):
print("counter = ", self.counter)
pass
# zero gradients, perform a backward pass, and update the weights.
self.optimiser.zero_grad()
loss.backward()
self.optimiser.step()
pass
def save(self, path):
torch.save(self.state_dict(), path)
pass
def load(self, path):
self.load_state_dict(torch.load(path))
#self.eval()
pass
def plot_progress(self):
df = pandas.DataFrame(self.progress, columns=['loss'])
df.plot(ylim=(0, 2.0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0))
pass
pass
%%time
# create Discriminator and test it
D = Discriminator()
D.to(device)
# train Discriminator
for i in range(2000):
# random emotion from list
emotion = random.choice(list_of_emotions)
# train discriminator on real data
label_tensor, image_tensor = generate_real(emotion)
D.train(image_tensor, label_tensor, torch.cuda.FloatTensor([1.0]).view(1,1))
# train discriminator on false (random) data
D.train(generate_random(3*64*64).view(1, 3, 64, 64), generate_random_target(len(list_of_emotions)), torch.cuda.FloatTensor([0.0]).view(1,1))
if(i%500 == 0):
print("i = ", i)
pass
pass
i = 0 i = 500 i = 1000 i = 1500 CPU times: user 16.7 s, sys: 1.18 s, total: 17.9 s Wall time: 18 s
# plot discriminator error
D.plot_progress()
# manually check D can indeed discriminate between real and fake data
for i in range(4):
print(D.forward(generate_random(3*64*64).view(1, 3, 64, 64), generate_random_target(4)).item())
pass
print()
for i in range(4):
label_tensor, image_data_tensor = generate_real('happy')
print(D.forward(image_data_tensor, label_tensor).item())
pass
print()
for i in range(4):
label_tensor, image_data_tensor = generate_real('angry')
print(D.forward(image_data_tensor, label_tensor).item())
pass
print()
for i in range(4):
label_tensor, image_data_tensor = generate_real('sad')
print(D.forward(image_data_tensor, label_tensor).item())
pass
print()
for i in range(4):
label_tensor, image_data_tensor = generate_real('joy')
print(D.forward(image_data_tensor, label_tensor).item())
pass
0.0007793445256538689 0.0008070397889241576 0.0008125404128804803 0.0008216124260798097 1.0 1.0 0.9999996423721313 1.0 0.999998927116394 1.0 0.9999996423721313 0.9999998807907104 1.0 0.9999998807907104 1.0 1.0 1.0 1.0 0.9999963045120239 0.9999995231628418
# generator class
# this one uses only fully connected (nn.Linear) layers
class Generator(nn.Module):
def __init__(self):
# initialise parent pytorch class
super().__init__()
# define neural network layers
self.model = nn.Sequential(
# reshape seed to tensor
View((1, 100, 1, 1)),
# reshape tensor to 256 filters
nn.ConvTranspose2d(100, 256, kernel_size=4, stride=1, bias=False),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(256),
nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(256),
nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(256),
nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(256),
nn.ConvTranspose2d(256, 3, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(3),
View((1, 3, 64, 64)),
nn.Tanh()
)
# create error function
self.error_function = torch.nn.BCELoss()
# create optimiser, using Adam for better gradient descent
self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0002, betas=(0.5, 0.999))
# counter and accumulator for progress
self.counter = 0;
self.progress = []
pass
def forward(self, noise_tensor, condition_tensor):
# add condition tensor to noise tensor
conditioned_noise_tensor = noise_tensor + condition_tensor.repeat(1,25)
return self.model(conditioned_noise_tensor)
def train(self, D, noise_tensor, condition_tensor, targets):
# calculate the output of the network
g_output = self.forward(noise_tensor, condition_tensor)
# pass onto Discriminator
d_output = D.forward(g_output, condition_tensor)
# calculate error
loss = D.error_function(d_output, targets)
# increase counter and accumulate error every 10
self.counter += 1;
if (self.counter % 100 == 0):
self.progress.append(loss.item())
pass
# zero gradients, perform a backward pass, and update the weights.
self.optimiser.zero_grad()
loss.backward()
self.optimiser.step()
pass
def save(self, path):
torch.save(self.state_dict(), path)
pass
def load(self, path):
self.load_state_dict(torch.load(path))
#self.eval()
pass
def plot_images(self):
# plot a 3 column, 4 row array of sample images
f, axarr = plt.subplots(len(list_of_emotions), 3, figsize=(16,16))
for i in range(len(list_of_emotions)):
for j in range(3):
condition_tensor = torch.zeros((1,len(list_of_emotions)))
condition_tensor[0,i] = 1.0
img = G.forward(generate_random(100), condition_tensor).permute(0,2,3,1).view(64,64,3).detach().cpu().numpy()
img = (img + 1.0)/2.0
axarr[i,j].imshow(img, interpolation='none', cmap='gray')
axarr[i,j].set_title(list_of_emotions[i])
pass
pass
pass
def plot_progress(self):
df = pandas.DataFrame(self.progress, columns=['loss'])
df.plot(ylim=(0, 10.0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0))
pass
pass
# scratch
G = Generator()
img = G(generate_random(100), generate_random_target(4))
img = (img + 1.0) / 2.0
print(img.shape)
print(img.min(), img.mean(), img.max())
plt.figure(figsize = (16,8))
plt.imshow(img.detach().permute(0,2,3,1).view(64,64,3).cpu().numpy(), interpolation='none', cmap='gray')
torch.Size([1, 3, 64, 64]) tensor(0.0332, grad_fn=<MinBackward1>) tensor(0.4848, grad_fn=<MeanBackward0>) tensor(0.9999, grad_fn=<MaxBackward1>)
<matplotlib.image.AxesImage at 0x7f493e6e4da0>
# create Discriminator and Generator
D = Discriminator()
D.to(device)
G = Generator()
G.to(device)
# free up GPU memory
torch.cuda.empty_cache()
%%time
# train Discriminator and Generator
epochs = 8
for epoch in range(epochs):
print("epoch .. ", epoch+1)
for i in range(1000):
# swap labels sometimes
real = 0.9
fake = 0.0
if (random.random() < 0.05):
real = 0.0
fake = 0.9
pass
# random emotion from list
emotion = random.choices(list_of_emotions, weights=[1.0,1.0,1.0, 1.0], k=1)[0]
# repeat for each emotion (batching classes)
for e in range(10):
# train discriminator on real data
label_tensor, image_tensor = generate_real(emotion)
D.train(image_tensor, label_tensor, torch.cuda.FloatTensor([real]).view(1,1))
# train discriminator on false
# use detach() so only D is updated, not G
# label softening doesn't apply to 0 labels
condition_tensor = generate_random_target(len(list_of_emotions))
D.train(G.forward(generate_random(100), condition_tensor).detach(), condition_tensor, torch.cuda.FloatTensor([fake]).view(1,1))
# train generator
condition_tensor = generate_random_target(len(list_of_emotions))
G.train(D, generate_random(100), condition_tensor, torch.cuda.FloatTensor([1.0]).view(1,1))
pass
pass
# plot 12-image samples, save to file
f, axarr = plt.subplots(len(list_of_emotions), 3, figsize=(16,16))
for i in range(len(list_of_emotions)):
for j in range(3):
condition_tensor = torch.zeros((1,len(list_of_emotions)))
condition_tensor[0,i] = 1.0
img = G.forward(generate_random(100), condition_tensor).permute(0,2,3,1).view(64,64,3).detach().cpu().numpy()
img = (img + 1.0)/2.0
axarr[i,j].imshow(img, interpolation='none', cmap='gray')
axarr[i,j].set_title(list_of_emotions[i])
pass
pass
t = datetime.datetime.fromtimestamp(time.time())
fname = "img_%d_%s.png" % (epoch+1, t.strftime('%H-%M-%S'))
plt.savefig(fname)
files.download(fname)
plt.close(f)
pass
epoch .. 1 counter = 5000 counter = 10000 counter = 15000 counter = 20000 epoch .. 2 counter = 25000 counter = 30000 counter = 35000 counter = 40000 epoch .. 3 counter = 45000 counter = 50000 counter = 55000 counter = 60000 epoch .. 4 counter = 65000 counter = 70000 counter = 75000 counter = 80000 epoch .. 5 counter = 85000 counter = 90000 counter = 95000 counter = 100000 epoch .. 6 counter = 105000 counter = 110000 counter = 115000 counter = 120000 epoch .. 7 counter = 125000 counter = 130000 counter = 135000 counter = 140000 epoch .. 8 counter = 145000 counter = 150000 counter = 155000 counter = 160000 CPU times: user 19min 55s, sys: 1min 31s, total: 21min 26s Wall time: 21min 38s
# plot discriminator error
D.plot_progress()
# plot generator error
G.plot_progress()
# show generator outputs as they evolve
G.plot_images()
!ls
data img_3_17-43-55.png img_6_17-52-01.png sample_data img_1_17-38-30.png img_4_17-46-36.png img_7_17-54-43.png img_2_17-41-13.png img_5_17-49-19.png img_8_17-57-24.png