import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets, transforms, utils
from torch.autograd import Variable
def show_batch(batch):
    """Render a batch of image tensors as a single tiled grid via matplotlib."""
    grid = utils.make_grid(batch)
    # make_grid returns CHW; imshow wants HWC, hence the transpose.
    plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
# Shared preprocessing for both splits: tensor conversion followed by
# normalization with the conventional MNIST mean/std.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../mnist_data',
                   download=True,
                   train=True,
                   transform=mnist_transform),
    batch_size=10, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../mnist_data',
                   download=True,
                   train=False,
                   transform=mnist_transform),
    batch_size=10, shuffle=True)
class CNNClassifier(nn.Module):
    """LeNet-style CNN for 10-class, single-channel 28x28 inputs.

    Returns log-probabilities (log_softmax), intended to be paired with
    F.nll_loss during training.
    """

    def __init__(self):
        super(CNNClassifier, self).__init__()
        # Registration order kept stable (conv1, conv2, dropout, fc1, fc2)
        # so seeded weight initialization stays reproducible.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.dropout = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        out = F.relu(F.max_pool2d(self.conv1(x), 2))
        out = self.conv2(out)
        out = self.dropout(out)          # channel-wise dropout before pooling
        out = F.relu(F.max_pool2d(out, 2))
        out = out.view(-1, 320)          # flatten 20 channels x 4 x 4 spatial
        out = F.relu(self.fc1(out))
        out = F.dropout(out, training=self.training)
        return F.log_softmax(self.fc2(out), dim=1)
# Evaluation-loop limits: train() processes at most train_size + 1 batches,
# check() scores at most test_size + 1 batches per call.
train_size = 1000
test_size = 100

# Model and optimizer used by train()/check() below.
clf = CNNClassifier()
opt = optim.SGD(clf.parameters(), lr=0.01, momentum=0.9)
def train():
    """Train the global `clf` on `train_loader`, evaluating every 100 batches.

    Processes at most ``train_size + 1`` batches. Every 100 batches it calls
    ``check()`` on both splits and records the metrics.

    Returns:
        (step_id, test_losses, test_accuracies, train_losses, train_accuracies)
        — parallel lists, one entry per evaluation point.
    """
    clf.train()
    step_id = []
    test_losses = []
    test_accuracies = []
    train_losses = []
    train_accuracies = []
    for batch_id, (data, target) in enumerate(train_loader):
        if batch_id > train_size:
            break
        opt.zero_grad()
        preds = clf(data)
        loss = F.nll_loss(preds, target)
        loss.backward()
        opt.step()
        if batch_id % 100 == 0:
            step_id.append(batch_id)
            test_loss, test_acc = check(batch_id)
            test_losses.append(test_loss)
            test_accuracies.append(test_acc)
            train_loss, train_acc = check(batch_id, train=True)
            train_losses.append(train_loss)
            train_accuracies.append(train_acc)
            # BUG FIX: check() switches the model to eval mode and never
            # restores it, so the original trained with dropout disabled
            # after the very first evaluation. Re-enter training mode here.
            clf.train()
    return step_id, test_losses, test_accuracies, train_losses, train_accuracies
def check(step_id, train=False):
    """Evaluate the global `clf` and print/return (mean loss, accuracy %).

    Scores at most ``test_size + 1`` batches from `train_loader` if `train`
    is True, otherwise from `test_loader`. Leaves the model in eval mode;
    callers that keep training must call ``clf.train()`` afterwards.

    Args:
        step_id: training step number, used only for the log line.
        train: evaluate on the training split instead of the test split.

    Returns:
        (loss, acc): mean per-batch NLL loss and accuracy in percent.
    """
    clf.eval()  # inference mode (disables dropout)
    total_loss = 0.0
    num_batches = 0
    correct = 0
    loader = train_loader if train else test_loader
    keyword = "train" if train else "test"
    with torch.no_grad():  # hoisted out of the loop: no grads needed anywhere
        for data, target in loader:
            if num_batches > test_size:
                break
            output = clf(data)
            correct += (torch.max(output, 1)[1] == target).sum().item()
            total_loss += F.nll_loss(output, target).item()
            num_batches += 1
    loss = total_loss / num_batches  # loss function already averages over batch size
    # BUG FIX: the original computed correct / batches * batch_size, which
    # equals the percentage only because batch_size == 10 (100/bs == bs).
    # This is the correct formula for any batch size (identical here).
    acc = 100.0 * correct / (num_batches * loader.batch_size)
    print("{0} step: {1} | loss: {2:.4f} | accuracy: {3:.4f}".format(
        keyword,
        step_id,
        loss,
        acc
    ))
    return loss, acc
# Train on the clean MNIST loaders and keep the logged metric histories.
steps, test_losses, test_accuracies, train_losses, train_accuracies = train()
# Notebook output (commented out — raw cell output is not valid Python):
# test step: 0 | loss: 2.3156 | accuracy: 2.6733 train step: 0 | loss: 2.3125 | accuracy: 3.1683 test step: 100 | loss: 0.5973 | accuracy: 80.0990 train step: 100 | loss: 0.5346 | accuracy: 83.7624 test step: 200 | loss: 0.4255 | accuracy: 85.8416 train step: 200 | loss: 0.4913 | accuracy: 82.6733 test step: 300 | loss: 0.3554 | accuracy: 88.4158 train step: 300 | loss: 0.3324 | accuracy: 89.4059 test step: 400 | loss: 0.2737 | accuracy: 90.4950 train step: 400 | loss: 0.2891 | accuracy: 91.3861 test step: 500 | loss: 0.2359 | accuracy: 92.0792 train step: 500 | loss: 0.2491 | accuracy: 91.9802 test step: 600 | loss: 0.2778 | accuracy: 92.2772 train step: 600 | loss: 0.2523 | accuracy: 92.5743 test step: 700 | loss: 0.2092 | accuracy: 93.5644 train step: 700 | loss: 0.2833 | accuracy: 92.2772 test step: 800 | loss: 0.3277 | accuracy: 91.8812 train step: 800 | loss: 0.3238 | accuracy: 90.3960 test step: 900 | loss: 0.2016 | accuracy: 94.6535 train step: 900 | loss: 0.1831 | accuracy: 94.6535 test step: 1000 | loss: 0.1458 | accuracy: 95.7426 train step: 1000 | loss: 0.1661 | accuracy: 94.8515
# Plot the clean-data run: accuracy on the left, loss on the right.
plt.figure(figsize=(16, 6))

ax1 = plt.subplot(1, 2, 2)
ax1.plot(steps, train_losses, label="train loss")
ax1.plot(steps, test_losses, label="test loss")
plt.ylabel("NLL loss")  # typo fix: was "NLE loss"
plt.xlabel("step num")
plt.legend()

ax2 = plt.subplot(1, 2, 1)
ax2.plot(steps, train_accuracies, label="train")
ax2.plot(steps, test_accuracies, label="test")
plt.ylabel("accuracy")
plt.xlabel("step num")
plt.legend()  # BUG FIX: original said `plt.legend` without calling it

plt.show()
def noised_data(disp):
    """Build MNIST loaders whose images carry additive Gaussian noise.

    Args:
        disp: standard deviation of the zero-mean noise added to each
              (already normalized) image tensor.

    Returns:
        (train_loader, test_loader): DataLoaders over the noisy MNIST splits,
        batch_size=10, shuffled.
    """
    # One shared pipeline instead of two duplicated copies. The noise is
    # resampled per image access; torch.randn(x.shape) * disp is the same
    # distribution as the original Tensor(x.shape).normal_(0, disp), without
    # relying on torch.Tensor interpreting a torch.Size as dimensions.
    noisy_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        transforms.Lambda(lambda x: x + torch.randn(x.shape) * disp),
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../mnist_data',
                       download=True,
                       train=True,
                       transform=noisy_transform),
        batch_size=10, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../mnist_data',
                       download=True,
                       train=False,
                       transform=noisy_transform),
        batch_size=10, shuffle=True)
    return train_loader, test_loader
# Noise standard deviations to sweep over.
disps = [0.1, 0.2, 0.4, 0.8, 1, 10]
results = []
# Retrain a fresh model per noise level. Note: rebinding the global loaders
# here is intentional — train()/check() read them from module scope.
for disp in disps:  # fix: original enumerated but never used disp_id
    train_loader, test_loader = noised_data(disp)
    clf = CNNClassifier()
    opt = optim.SGD(clf.parameters(), lr=0.01, momentum=0.9)
    results.append(train())
# Notebook output (commented out — raw cell output is not valid Python):
# test step: 0 | loss: 2.3050 | accuracy: 7.8218 train step: 0 | loss: 2.3051 | accuracy: 7.2277 test step: 100 | loss: 0.7230 | accuracy: 79.2079 train step: 100 | loss: 0.7597 | accuracy: 78.8119 test step: 200 | loss: 0.3936 | accuracy: 88.2178 train step: 200 | loss: 0.4733 | accuracy: 85.9406 test step: 300 | loss: 0.2314 | accuracy: 92.2772 train step: 300 | loss: 0.2357 | accuracy: 92.9703 test step: 400 | loss: 0.2481 | accuracy: 93.3663 train step: 400 | loss: 0.2706 | accuracy: 91.7822 test step: 500 | loss: 0.1727 | accuracy: 95.5446 train step: 500 | loss: 0.1336 | accuracy: 95.5446 test step: 600 | loss: 0.1631 | accuracy: 94.6535 train step: 600 | loss: 0.1964 | accuracy: 93.4653 test step: 700 | loss: 0.1804 | accuracy: 94.0594 train step: 700 | loss: 0.2348 | accuracy: 92.1782 test step: 800 | loss: 0.1886 | accuracy: 94.0594 train step: 800 | loss: 0.2076 | accuracy: 93.8614 test step: 900 | loss: 0.1270 | accuracy: 96.1386 train step: 900 | loss: 0.1459 | accuracy: 95.5446 test step: 1000 | loss: 0.1585 | accuracy: 95.3465 train step: 1000 | loss: 0.1463 | accuracy: 96.6337 test step: 0 | loss: 2.3074 | accuracy: 10.0000 train step: 0 | loss: 2.3076 | accuracy: 8.9109 test step: 100 | loss: 0.8277 | accuracy: 71.4851 train step: 100 | loss: 0.8025 | accuracy: 71.3861 test step: 200 | loss: 0.4096 | accuracy: 88.8119 train step: 200 | loss: 0.3751 | accuracy: 89.4059 test step: 300 | loss: 0.2473 | accuracy: 92.9703 train step: 300 | loss: 0.2716 | accuracy: 92.3762 test step: 400 | loss: 0.2613 | accuracy: 92.4752 train step: 400 | loss: 0.2333 | accuracy: 92.0792 test step: 500 | loss: 0.2486 | accuracy: 93.4653 train step: 500 | loss: 0.2278 | accuracy: 93.0693 test step: 600 | loss: 0.1431 | accuracy: 95.3465 train step: 600 | loss: 0.1984 | accuracy: 94.1584 test step: 700 | loss: 0.2199 | accuracy: 92.5743 train step: 700 | loss: 0.2144 | accuracy: 93.8614 test step: 800 | loss: 0.2306 | accuracy: 93.1683 train step: 800 | loss: 0.2037 |
# accuracy: 92.4752 test step: 900 | loss: 0.2168 | accuracy: 93.8614 train step: 900 | loss: 0.2203 | accuracy: 92.7723 test step: 1000 | loss: 0.1638 | accuracy: 94.5545 train step: 1000 | loss: 0.1754 | accuracy: 94.7525 test step: 0 | loss: 2.3041 | accuracy: 17.0297 train step: 0 | loss: 2.3074 | accuracy: 16.9307 test step: 100 | loss: 0.6731 | accuracy: 78.3168 train step: 100 | loss: 0.7055 | accuracy: 76.3366 test step: 200 | loss: 0.4103 | accuracy: 85.8416 train step: 200 | loss: 0.3921 | accuracy: 84.7525 test step: 300 | loss: 0.2703 | accuracy: 91.1881 train step: 300 | loss: 0.3240 | accuracy: 89.9010 test step: 400 | loss: 0.2650 | accuracy: 92.4752 train step: 400 | loss: 0.2231 | accuracy: 92.9703 test step: 500 | loss: 0.2657 | accuracy: 91.3861 train step: 500 | loss: 0.2938 | accuracy: 91.1881 test step: 600 | loss: 0.1974 | accuracy: 93.2673 train step: 600 | loss: 0.1841 | accuracy: 93.9604 test step: 700 | loss: 0.2167 | accuracy: 93.6634 train step: 700 | loss: 0.1536 | accuracy: 94.8515 test step: 800 | loss: 0.1682 | accuracy: 94.4554 train step: 800 | loss: 0.1462 | accuracy: 96.0396 test step: 900 | loss: 0.1138 | accuracy: 96.3366 train step: 900 | loss: 0.1446 | accuracy: 95.3465 test step: 1000 | loss: 0.2650 | accuracy: 93.0693 train step: 1000 | loss: 0.2711 | accuracy: 93.2673 test step: 0 | loss: 2.3199 | accuracy: 10.0000 train step: 0 | loss: 2.3223 | accuracy: 9.5050 test step: 100 | loss: 0.7317 | accuracy: 76.3366 train step: 100 | loss: 0.8359 | accuracy: 73.2673 test step: 200 | loss: 0.4111 | accuracy: 86.9307 train step: 200 | loss: 0.4354 | accuracy: 84.6535 test step: 300 | loss: 0.3767 | accuracy: 88.7129 train step: 300 | loss: 0.3117 | accuracy: 89.4059 test step: 400 | loss: 0.2833 | accuracy: 90.8911 train step: 400 | loss: 0.3485 | accuracy: 89.7030 test step: 500 | loss: 0.2704 | accuracy: 92.1782 train step: 500 | loss: 0.2820 | accuracy: 91.4851 test step: 600 | loss: 0.2025 | accuracy: 93.5644 train step: 600 |
# loss: 0.2516 | accuracy: 92.0792 test step: 700 | loss: 0.2398 | accuracy: 92.6733 train step: 700 | loss: 0.2198 | accuracy: 92.8713 test step: 800 | loss: 0.2110 | accuracy: 93.3663 train step: 800 | loss: 0.2289 | accuracy: 92.4752 test step: 900 | loss: 0.2384 | accuracy: 92.0792 train step: 900 | loss: 0.2323 | accuracy: 94.0594 test step: 1000 | loss: 0.2373 | accuracy: 93.1683 train step: 1000 | loss: 0.2530 | accuracy: 92.2772 test step: 0 | loss: 2.3243 | accuracy: 6.3366 train step: 0 | loss: 2.3209 | accuracy: 4.7525 test step: 100 | loss: 0.6502 | accuracy: 77.5248 train step: 100 | loss: 0.6559 | accuracy: 77.9208 test step: 200 | loss: 0.4712 | accuracy: 83.7624 train step: 200 | loss: 0.4907 | accuracy: 85.1485 test step: 300 | loss: 0.4522 | accuracy: 84.1584 train step: 300 | loss: 0.4241 | accuracy: 85.8416 test step: 400 | loss: 0.3996 | accuracy: 87.3267 train step: 400 | loss: 0.4431 | accuracy: 85.3465 test step: 500 | loss: 0.3483 | accuracy: 88.9109 train step: 500 | loss: 0.4539 | accuracy: 87.5248 test step: 600 | loss: 0.4028 | accuracy: 87.0297 train step: 600 | loss: 0.3446 | accuracy: 89.1089 test step: 700 | loss: 0.3051 | accuracy: 90.7921 train step: 700 | loss: 0.3136 | accuracy: 89.0099 test step: 800 | loss: 0.4096 | accuracy: 87.3267 train step: 800 | loss: 0.4506 | accuracy: 86.1386 test step: 900 | loss: 0.4134 | accuracy: 88.5149 train step: 900 | loss: 0.3879 | accuracy: 88.1188 test step: 1000 | loss: 0.2437 | accuracy: 93.1683 train step: 1000 | loss: 0.2751 | accuracy: 92.2772 test step: 0 | loss: 2.8023 | accuracy: 10.1980 train step: 0 | loss: 2.8325 | accuracy: 10.1980 test step: 100 | loss: 2.3001 | accuracy: 11.7822 train step: 100 | loss: 2.3000 | accuracy: 13.0693 test step: 200 | loss: 2.2979 | accuracy: 10.7921 train step: 200 | loss: 2.3063 | accuracy: 10.8911 test step: 300 | loss: 2.3019 | accuracy: 10.3960 train step: 300 | loss: 2.3042 | accuracy: 11.5842 test step: 400 | loss: 2.3052 | accuracy: 9.7030
# train step: 400 | loss: 2.3013 | accuracy: 11.3861 test step: 500 | loss: 2.3087 | accuracy: 9.0099 train step: 500 | loss: 2.3053 | accuracy: 12.3762 test step: 600 | loss: 2.3042 | accuracy: 10.6931 train step: 600 | loss: 2.3034 | accuracy: 10.4950 test step: 700 | loss: 2.2989 | accuracy: 11.5842 train step: 700 | loss: 2.3035 | accuracy: 12.3762 test step: 800 | loss: 2.3012 | accuracy: 11.0891 train step: 800 | loss: 2.3032 | accuracy: 11.7822 test step: 900 | loss: 2.3038 | accuracy: 9.2079 train step: 900 | loss: 2.3002 | accuracy: 10.2970 test step: 1000 | loss: 2.3041 | accuracy: 9.2079 train step: 1000 | loss: 2.3057 | accuracy: 9.6040
# %matplotlib inline  (IPython cell magic — invalid in a plain Python file)
# One row per noise level: losses on the left, accuracies on the right.
plt.figure(figsize=(16, 16))
for disp_id, disp in enumerate(disps):
    steps, test_losses, test_accuracies, train_losses, train_accuracies = results[disp_id]

    # Left column (odd subplot indices): losses.
    plt.subplot(len(disps), 2, 2 * disp_id + 1)
    plt.plot(steps, train_losses, label="train loss")
    plt.plot(steps, test_losses, label="test loss")
    plt.title("losses, disp = {0}".format(disp))
    plt.ylabel("NLL loss")  # typo fix: was "NLE loss"
    plt.xlabel("step num")
    plt.legend()

    # Right column (even subplot indices): accuracies.
    plt.subplot(len(disps), 2, 2 * disp_id + 2)
    plt.plot(steps, train_accuracies, label="train")
    plt.plot(steps, test_accuracies, label="test")
    plt.title("accuracy, disp = {0}".format(disp))
    plt.ylabel("accuracy")
    plt.xlabel("step num")
    plt.legend()
plt.subplots_adjust(hspace=0.6)