# simple generative adversarial network
# this version uses simple images, the MNIST dataset
# conventional PyTorch imports
import torch
import torch.nn as nn
#import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import random
import pandas
import numpy
import matplotlib.pyplot as plt
# dataset class
class MnistDataset(Dataset):
    """MNIST digits loaded from a CSV file whose rows are: label, 784 pixel values."""

    def __init__(self, csv_file):
        # no header row in the CSV; columns are label + 784 raw pixel bytes
        self.data_df = pandas.read_csv(csv_file, header=None)

    def __len__(self):
        # one record per CSV row
        return len(self.data_df)

    def __getitem__(self, index):
        # the digit label sits in column 0
        label = self.data_df.iloc[index, 0]
        # one-hot encode the label as the training target
        target = torch.zeros(10)
        target[label] = 1.0
        # pixel values rescaled from 0-255 into 0-1
        pixels = torch.FloatTensor(self.data_df.iloc[index, 1:].values) / 255.0
        # (label, image tensor, one-hot target)
        return label, pixels, target

    def plot_image(self, index):
        """Render record *index* as a 28x28 image, titled with its label."""
        pixels = self.data_df.iloc[index, 1:].values.reshape(28, 28)
        plt.title("label = " + str(self.data_df.iloc[index, 0]))
        plt.imshow(pixels, interpolation='none', cmap='Blues')
# subclass PyTorch dataset class, loads actual data, parses it into targets and pixel data
# load the MNIST training data from CSV (expects label + 784 pixel columns per row)
mnist_dataset = MnistDataset('mnist_data/mnist_train.csv')
# sanity check: render record 0 as a 28x28 image with its label
mnist_dataset.plot_image(0)
# classifier class
class Classifier(nn.Module):
    """Two-layer sigmoid network that classifies 784-pixel MNIST images into 10 digits."""

    def __init__(self):
        # initialise parent pytorch class
        super().__init__()

        # 784 inputs -> 200 hidden -> 10 outputs, sigmoid activations throughout
        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.Sigmoid(),
            nn.Linear(200, 10),
            nn.Sigmoid()
        )

        # binary cross-entropy loss against the one-hot target
        self.error_function = torch.nn.BCELoss()
        # plain stochastic gradient descent optimiser
        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)

        # step counter plus a record of the loss every 10th step
        self.counter = 0
        self.progress = []

    def forward(self, inputs):
        # delegate straight to the sequential model
        return self.model(inputs)

    # NOTE(review): this shadows nn.Module.train(mode=True), so the usual
    # .train()/.eval() mode switching will not work on this class — confirm
    # before adding dropout/batch-norm layers
    def train(self, inputs, targets):
        """Run one optimisation step on a single (input, target) pair."""
        out = self.forward(inputs)
        loss = self.error_function(out, targets)

        # record the loss every 10 steps; report progress every 10000
        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())
        if self.counter % 10000 == 0:
            print("counter = ", self.counter)

        # zero gradients, backpropagate, update the weights
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def save(self, path):
        """Persist the model weights to *path*."""
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Restore model weights previously written by save()."""
        self.load_state_dict(torch.load(path))

    def plot_progress(self):
        """Plot the recorded loss values against training steps."""
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0, 1.0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))
# train classifier
# train the classifier with several full passes over the training set
C = Classifier()

epochs = 3
for epoch in range(1, epochs + 1):
    print('training epoch', epoch, "of", epochs)
    # one optimisation step per training record
    for label, image_data_tensor, target_tensor in mnist_dataset:
        C.train(image_data_tensor, target_tensor)
# notebook output: training epoch 1 of 3 counter = 10000 counter = 20000 counter = 30000 counter = 40000 counter = 50000 counter = 60000 training epoch 2 of 3 counter = 70000 counter = 80000 counter = 90000 counter = 100000 counter = 110000 counter = 120000 training epoch 3 of 3 counter = 130000 counter = 140000 counter = 150000 counter = 160000 counter = 170000 counter = 180000
# visualise the loss curve recorded during training
C.plot_progress()

# spot-check the trained classifier on one training record
record = 16
# display the image alongside its true label
mnist_dataset.plot_image(record)
# chart the network's ten output scores for the same image
answer = C.forward(mnist_dataset[record][1])
pandas.DataFrame(answer.detach().numpy()).plot(kind='bar', legend=False)
# notebook output: <matplotlib.axes._subplots.AxesSubplot at 0x125391a90>
# evaluate the trained network on the held-out test set
# (same CSV layout as the training data: label + 784 pixel values per row)
mnist_test_dataset = MnistDataset('mnist_data/mnist_test.csv')

score = 0
items = 0
for label, image_data_tensor, target_tensor in mnist_test_dataset:
    # the predicted digit is the index of the largest of the ten outputs
    answer = C.forward(image_data_tensor).detach().numpy()
    if answer.argmax() == label:
        score += 1
    items += 1

# fraction of test records classified correctly
print(score, items, score/items)
# notebook output: 9037 10000 0.9037