import os
import argparse
import logging
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchdiffeq import odeint
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
import seaborn as sns
sns.set()
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
def norm(dim):
    # GroupNorm with at most 32 groups; dim must be divisible by the group count
    return nn.GroupNorm(min(32, dim), dim)
class ResBlock(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(ResBlock, self).__init__()
self.norm1 = norm(inplanes)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.conv1 = conv3x3(inplanes, planes, stride)
self.norm2 = norm(planes)
self.conv2 = conv3x3(planes, planes)
def forward(self, x):
shortcut = x
out = self.relu(self.norm1(x))
if self.downsample is not None:
shortcut = self.downsample(out)
out = self.conv1(out)
out = self.norm2(out)
out = self.relu(out)
out = self.conv2(out)
return out + shortcut
class ConcatConv2d(nn.Module):
def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
super(ConcatConv2d, self).__init__()
module = nn.ConvTranspose2d if transpose else nn.Conv2d
self._layer = module(
dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
bias=bias
)
def forward(self, t, x):
tt = torch.ones_like(x[:, :1, :, :]) * t
ttx = torch.cat([tt, x], 1)
return self._layer(ttx)
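# Quick sanity check (illustrative, not part of the training run): ConcatConv2d
# broadcasts the scalar time t into one extra constant channel, so the wrapped conv
# sees dim_in + 1 input channels while the output shape matches a plain conv.
_cc = ConcatConv2d(4, 8, ksize=3, stride=1, padding=1)
_out = _cc(torch.tensor(0.5), torch.zeros(2, 4, 5, 5))
assert _out.shape == (2, 8, 5, 5) and _cc._layer.in_channels == 5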
class ODEfunc_compl(nn.Module):
    """Deeper ODE dynamics than ODEfunc: four ConcatConv2d layers instead of two."""
    def __init__(self, dim):
        super(ODEfunc_compl, self).__init__()
self.norm1 = norm(dim)
self.relu = nn.ReLU(inplace=True)
self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
self.norm2 = norm(dim)
self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
self.norm3 = norm(dim)
self.conv3 = ConcatConv2d(dim, dim, 3, 1, 1)
self.norm4 = norm(dim)
self.conv4 = ConcatConv2d(dim, dim, 3, 1, 1)
self.norm5 = norm(dim)
self.nfe = 0
def forward(self, t, x):
self.nfe += 1
out = self.norm1(x)
out = self.relu(out)
out = self.conv1(t, out)
out = self.norm2(out)
out = self.relu(out)
out = self.conv2(t, out)
out = self.norm3(out)
out = self.relu(out)
out = self.conv3(t, out)
out = self.norm4(out)
out = self.relu(out)
out = self.conv4(t, out)
out = self.norm5(out)
return out
class ODEfunc(nn.Module):
def __init__(self, dim):
super(ODEfunc, self).__init__()
self.norm1 = norm(dim)
self.relu = nn.ReLU(inplace=True)
self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
self.norm2 = norm(dim)
self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
self.norm3 = norm(dim)
self.nfe = 0
def forward(self, t, x):
self.nfe += 1
out = self.norm1(x)
out = self.relu(out)
out = self.conv1(t, out)
out = self.norm2(out)
out = self.relu(out)
out = self.conv2(t, out)
out = self.norm3(out)
return out
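# Illustration of the nfe counter (a throwaway sketch): the adaptive solver calls
# forward() once per function evaluation, so after a single odeint solve nfe equals
# the number of evaluations dopri5 needed at this tolerance (a few dozen here,
# matching the NFE-F column in the training log below).
_f = ODEfunc(64)
odeint(_f, torch.zeros(1, 64, 4, 4), torch.tensor([0., 1.]), rtol=1e-3, atol=1e-3)
print('NFE for one solve:', _f.nfe)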
class ODEBlock(nn.Module):
    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()
    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        # integrate the dynamics from t=0 to t=1; `tol` is the module-level solver tolerance set below
        out = odeint(self.odefunc, x, self.integration_time, rtol=tol, atol=tol)
        return out[1]
@property
def nfe(self):
return self.odefunc.nfe
@nfe.setter
def nfe(self, value):
self.odefunc.nfe = value
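# A memory-saving variant (a sketch; mirrors the commented-out import near the
# hyperparameters below): torchdiffeq's odeint_adjoint has the same call signature
# but recovers gradients by solving an adjoint ODE backwards in time, giving O(1)
# memory in solver depth at the cost of a second solve during backward().
from torchdiffeq import odeint_adjoint

class AdjointODEBlock(ODEBlock):
    def forward(self, x):
        t = self.integration_time.type_as(x)
        return odeint_adjoint(self.odefunc, x, t, rtol=tol, atol=tol)[1]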
class Flatten(nn.Module):
    """Flattens every dimension after the batch dimension (nn.Flatten in newer PyTorch)."""
    def forward(self, x):
        return x.view(x.size(0), -1)
class RunningAverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self, momentum=0.99):
self.momentum = momentum
self.reset()
def reset(self):
self.val = None
self.avg = 0
def update(self, val):
if self.val is None:
self.avg = val
else:
self.avg = self.avg * self.momentum + val * (1 - self.momentum)
self.val = val
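# Tiny illustration (hypothetical values): the first update seeds the average, after
# which each new value is blended in with weight (1 - momentum).
_m = RunningAverageMeter(momentum=0.5)
for _v in (1.0, 0.0, 0.0):
    _m.update(_v)
assert _m.avg == 0.25  # seed 1.0 -> 0.5*1.0 + 0.5*0.0 = 0.5 -> 0.5*0.5 + 0.5*0.0 = 0.25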
def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000):
if data_aug:
transform_train = transforms.Compose([
transforms.RandomCrop(28, padding=4),
transforms.ToTensor(),
])
else:
transform_train = transforms.Compose([
transforms.ToTensor(),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
])
train_loader = DataLoader(
datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_train), batch_size=batch_size,
shuffle=True, num_workers=2, drop_last=True
)
train_eval_loader = DataLoader(
datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_test),
batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
)
test_loader = DataLoader(
datasets.MNIST(root='.data/mnist', train=False, download=True, transform=transform_test),
batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
)
return train_loader, test_loader, train_eval_loader
def inf_generator(iterable):
"""Allows training with DataLoaders in a single infinite loop:
for i, (x, y) in enumerate(inf_generator(train_loader)):
"""
    iterator = iter(iterable)
    while True:
        try:
            yield next(iterator)
        except StopIteration:
            iterator = iter(iterable)
def learning_rate_with_decay(lr, batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates):
initial_learning_rate = lr * batch_size / batch_denom
boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
vals = [initial_learning_rate * decay for decay in decay_rates]
def learning_rate_fn(itr):
lt = [itr < b for b in boundaries] + [True]
i = np.argmax(lt)
return vals[i]
return learning_rate_fn
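# Worked example of the schedule (illustrative numbers; 100 batches per epoch is
# assumed): boundaries are expressed in iterations, so with boundary_epochs=[20, 40, 60]
# the rate steps down at iterations 2000, 4000 and 6000.
_fn = learning_rate_with_decay(1e-3, 128, batch_denom=128, batches_per_epoch=100,
                               boundary_epochs=[20, 40, 60], decay_rates=[1, 0.1, 0.01, 0.001])
for _it, _expected in [(0, 1e-3), (1999, 1e-3), (2000, 1e-4), (4000, 1e-5), (6000, 1e-6)]:
    assert np.isclose(_fn(_it), _expected)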
def one_hot(x, K):
return np.array(x[:, None] == np.arange(K)[None, :], dtype=int)
def accuracy(model, dataset_loader):
total_correct = 0
for x, y in dataset_loader:
x = x.to(device)
y = one_hot(np.array(y.numpy()), 10)
target_class = np.argmax(y, axis=1)
predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1)
total_correct += np.sum(predicted_class == target_class)
return total_correct / len(dataset_loader.dataset)
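# An equivalent, slightly leaner formulation (a sketch, same semantics assumed): compare
# the argmax of the logits to the integer labels directly, skipping the one-hot round trip
# through numpy. Uses the module-level `device` defined below, like accuracy() above.
def accuracy_torch(model, dataset_loader):
    total_correct = 0
    with torch.no_grad():
        for x, y in dataset_loader:
            pred = model(x.to(device)).argmax(dim=1).cpu()
            total_correct += (pred == y).sum().item()
    return total_correct / len(dataset_loader.dataset)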
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def makedirs(dirname):
if not os.path.exists(dirname):
os.makedirs(dirname)
def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False):
logger = logging.getLogger()
if debug:
level = logging.DEBUG
else:
level = logging.INFO
logger.setLevel(level)
if saving:
info_file_handler = logging.FileHandler(logpath, mode="a")
info_file_handler.setLevel(level)
logger.addHandler(info_file_handler)
if displaying:
console_handler = logging.StreamHandler()
console_handler.setLevel(level)
logger.addHandler(console_handler)
for f in package_files:
logger.info(f)
with open(f, "r") as package_f:
logger.info(package_f.read())
return logger
tol = 1e-3              # rtol/atol for the adaptive ODE solver
downsampling_method = 'res'
epochs = 100
data_aug = True
lr = 1e-3
batch_size = 128
test_bs = 1024
save_dir = './exp1'
# For O(1)-memory backprop, swap the solver: from torchdiffeq import odeint_adjoint as odeint
makedirs(save_dir)
logger = get_logger(logpath=os.path.join(save_dir, 'logs'), filepath=os.path.abspath('.'))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
is_odenet = True
if downsampling_method == 'conv':
downsampling_layers = [
nn.Conv2d(1, 64, 3, 1),
norm(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 4, 2, 1),
norm(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 4, 2, 1),
]
elif downsampling_method == 'res':
downsampling_layers = [
nn.Conv2d(1, 64, 3, 1),
ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
]
#feature_layers = [ODEBlock(ODEfunc(64)) for _ in range(2)] if is_odenet else [ResBlock(64, 64) for _ in range(6)]
#feature_layers = [ODEBlock(ODEfunc_compl(64))] if is_odenet else [ResBlock(64, 64) for _ in range(6)]
feature_layers = [ODEBlock(ODEfunc(64))] if is_odenet else [ResBlock(64, 64) for _ in range(6)]
fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]
model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)
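# Shape sanity check (a sketch): the 3x3 stem has no padding (28 -> 26), and each
# stride-2 ResBlock roughly halves the grid (26 -> 13 -> 7), so the ODE block
# operates on (N, 64, 7, 7) feature maps before global pooling.
with torch.no_grad():
    _feat = nn.Sequential(*downsampling_layers)(torch.zeros(2, 1, 28, 28, device=device))
print('feature shape after downsampling:', _feat.shape)  # expect torch.Size([2, 64, 7, 7])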
logger.info(model)
logger.info('Number of parameters: {}'.format(count_parameters(model)))
criterion = nn.CrossEntropyLoss().to(device)
train_loader, test_loader, train_eval_loader = get_mnist_loaders(data_aug, batch_size, test_bs)
data_gen = inf_generator(train_loader)
batches_per_epoch = len(train_loader)
lr_fn = learning_rate_with_decay(
    lr, batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch,
    boundary_epochs=[20, 40, 60], decay_rates=[1, 0.1, 0.01, 0.001]
)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
best_acc = 0
# Explicit field names for the per-epoch records. The printed format string carries eight
# values (Time appears twice: latest and running average), so deriving names by splitting
# it on '|' yielded only seven and silently shifted every stored column by one.
field_names = ['Epoch', 'Time', 'Time_Avg', 'NFE-F', 'NFE-B', 'Train_Acc', 'Test_Acc', 'Total_Time']
logs_1 = []
batch_time_meter = RunningAverageMeter()
f_nfe_meter = RunningAverageMeter()
b_nfe_meter = RunningAverageMeter()
end = time.time()
start = time.time()
for itr in range(epochs * batches_per_epoch):
for param_group in optimizer.param_groups:
param_group['lr'] = lr_fn(itr)
optimizer.zero_grad()
    x, y = next(data_gen)
x = x.to(device)
y = y.to(device)
logits = model(x)
loss = criterion(logits, y)
if is_odenet:
nfe_forward = feature_layers[0].nfe
feature_layers[0].nfe = 0
loss.backward()
optimizer.step()
if is_odenet:
nfe_backward = feature_layers[0].nfe
feature_layers[0].nfe = 0
batch_time_meter.update(time.time() - end)
if is_odenet:
f_nfe_meter.update(nfe_forward)
b_nfe_meter.update(nfe_backward)
end = time.time()
if itr % batches_per_epoch == 0:
finish = time.time()
with torch.no_grad():
train_acc = accuracy(model, train_eval_loader)
val_acc = accuracy(model, test_loader)
if val_acc > best_acc:
torch.save({'state_dict': model.state_dict()}, os.path.join(save_dir, 'model.pth'))
best_acc = val_acc
        logger.info(
            "Epoch {:4d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} | "
            "Train_Acc {:.4f} | Test_Acc {:.4f} | Total_Time {:.3f}".format(
                itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg,
                b_nfe_meter.avg, train_acc, val_acc, finish - start
            )
        )
        logs_1.append(dict(zip(field_names, [itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg,
                                             f_nfe_meter.avg, b_nfe_meter.avg, train_acc, val_acc,
                                             finish - start])))
start = time.time()
Sequential(
  (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
  (1): ResBlock(
    (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
    (relu): ReLU(inplace)
    (downsample): Conv2d(64, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (2): ResBlock(
    (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
    (relu): ReLU(inplace)
    (downsample): Conv2d(64, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (3): ODEBlock(
    (odefunc): ODEfunc(
      (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
      (relu): ReLU(inplace)
      (conv1): ConcatConv2d(
        (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
      (conv2): ConcatConv2d(
        (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (norm3): GroupNorm(32, 64, eps=1e-05, affine=True)
    )
  )
  (4): GroupNorm(32, 64, eps=1e-05, affine=True)
  (5): ReLU(inplace)
  (6): AdaptiveAvgPool2d(output_size=(1, 1))
  (7): Flatten()
  (8): Linear(in_features=64, out_features=10, bias=True)
)
Number of parameters: 232970
Epoch    0 | Time 1.166 (1.166) | NFE-F 26.0 | NFE-B 0.0 | Train_Acc 0.0978 | Test_Acc 0.0897 | Total_Time 1.166
Epoch    1 | Time 1.124 (0.352) | NFE-F 26.1 | NFE-B 0.0 | Train_Acc 0.9630 | Test_Acc 0.8998 | Total_Time 151.212
Epoch    2 | Time 1.169 (0.356) | NFE-F 26.2 | NFE-B 0.0 | Train_Acc 0.9720 | Test_Acc 0.9077 | Total_Time 159.991
Epoch    3 | Time 1.162 (0.354) | NFE-F 26.2 | NFE-B 0.0 | Train_Acc 0.9730 | Test_Acc 0.9052 | Total_Time 160.550
Epoch    4 | Time 1.203 (0.341) | NFE-F 25.0 | NFE-B 0.0 | Train_Acc 0.9805 | Test_Acc 0.9138 | Total_Time 157.097
Epoch    5 | Time 1.123 (0.320) | NFE-F 22.9 | NFE-B 0.0 | Train_Acc 0.9821 | Test_Acc 0.9130 | Total_Time 147.981
Epoch    6 | Time 1.177 (0.295) | NFE-F 20.9 | NFE-B 0.0 | Train_Acc 0.9808 | Test_Acc 0.9129 | Total_Time 135.925
Epoch    7 | Time 1.094 (0.286) | NFE-F 20.2 | NFE-B 0.0 | Train_Acc 0.9808 | Test_Acc 0.9141 | Total_Time 128.339
Epoch    8 | Time 1.158 (0.285) | NFE-F 20.2 | NFE-B 0.0 | Train_Acc 0.9823 | Test_Acc 0.9148 | Total_Time 128.412
Epoch    9 | Time 1.123 (0.286) | NFE-F 20.1 | NFE-B 0.0 | Train_Acc 0.9827 | Test_Acc 0.9146 | Total_Time 128.836
Epoch   10 | Time 1.143 (0.289) | NFE-F 20.2 | NFE-B 0.0 | Train_Acc 0.9813 | Test_Acc 0.9134 | Total_Time 130.564
Epoch   11 | Time 1.171 (0.296) | NFE-F 20.1 | NFE-B 0.0 | Train_Acc 0.9812 | Test_Acc 0.9130 | Total_Time 131.606
(KeyboardInterrupt traceback elided: training was interrupted manually during epoch 12's forward pass.)
logs_1
[{'Epoch': 0, 'Time': 1.1663413047790527, 'Time_Avg': 1.1663413047790527, 'NFE-F': 26, 'NFE-B': 0, 'Train_Acc': 0.09775, 'Test_Acc': 0.0897, 'Total_Time': 1.166},
 {'Epoch': 1, 'Time': 1.1237423419952393, 'Time_Avg': 0.3518146520107296, 'NFE-F': 26.113337826255613, 'NFE-B': 0.0, 'Train_Acc': 0.9629833333333333, 'Test_Acc': 0.8998, 'Total_Time': 151.212},
 {'Epoch': 2, 'Time': 1.1690211296081543, 'Time_Avg': 0.35646780732408767, 'NFE-F': 26.160499026848903, 'NFE-B': 0.0, 'Train_Acc': 0.972, 'Test_Acc': 0.9077, 'Total_Time': 159.991},
 {'Epoch': 3, 'Time': 1.1619246006011963, 'Time_Avg': 0.3541097996338256, 'NFE-F': 26.160926447915926, 'NFE-B': 0.0, 'Train_Acc': 0.9729833333333333, 'Test_Acc': 0.9052, 'Total_Time': 160.55},
 {'Epoch': 4, 'Time': 1.202782154083252, 'Time_Avg': 0.34110854730643125, 'NFE-F': 24.95415771513177, 'NFE-B': 0.0, 'Train_Acc': 0.9805, 'Test_Acc': 0.9138, 'Total_Time': 157.097},
 {'Epoch': 5, 'Time': 1.12302827835083, 'Time_Avg': 0.3201531668305314, 'NFE-F': 22.928928317852208, 'NFE-B': 0.0, 'Train_Acc': 0.9820666666666666, 'Test_Acc': 0.913, 'Total_Time': 147.981},
 {'Epoch': 6, 'Time': 1.177217721939087, 'Time_Avg': 0.2949909653493132, 'NFE-F': 20.920042885241745, 'NFE-B': 0.0, 'Train_Acc': 0.9807833333333333, 'Test_Acc': 0.9129, 'Total_Time': 135.925},
 {'Epoch': 7, 'Time': 1.0944006443023682, 'Time_Avg': 0.28566477432615045, 'NFE-F': 20.216727656597406, 'NFE-B': 0.0, 'Train_Acc': 0.9808, 'Test_Acc': 0.9141, 'Total_Time': 128.339},
 {'Epoch': 8, 'Time': 1.1575982570648193, 'Time_Avg': 0.28541028713199734, 'NFE-F': 20.184034851164867, 'NFE-B': 0.0, 'Train_Acc': 0.9823, 'Test_Acc': 0.9148, 'Total_Time': 128.412},
 {'Epoch': 9, 'Time': 1.1227002143859863, 'Time_Avg': 0.28647260674155234, 'NFE-F': 20.125482535545512, 'NFE-B': 0.0, 'Train_Acc': 0.9827333333333333, 'Test_Acc': 0.9146, 'Total_Time': 128.836},
 {'Epoch': 10, 'Time': 1.1429314613342285, 'Time_Avg': 0.28855354211729095, 'NFE-F': 20.15380365074734, 'NFE-B': 0.0, 'Train_Acc': 0.9812666666666666, 'Test_Acc': 0.9134, 'Total_Time': 130.564},
 {'Epoch': 11, 'Time': 1.1712427139282227, 'Time_Avg': 0.29551518088596457, 'NFE-F': 20.124064572092596, 'NFE-B': 0.0, 'Train_Acc': 0.9812, 'Test_Acc': 0.913, 'Total_Time': 131.606}]
# logs_3 is assumed to come from the analogous run with the deeper ODEfunc_compl block
# (that run is not reproduced here). Extract the series before plotting.
test_ac_3 = [elem['Test_Acc'] for elem in logs_3]
train_ac_3 = [elem['Train_Acc'] for elem in logs_3]
plt.figure(figsize=(15, 10))
plt.xlabel('epoch number')
plt.ylabel('accuracy')
plt.title('ODENET complex block')
plt.plot(np.arange(len(train_ac_3[1:])) + 1, train_ac_3[1:], label='Accuracy on training set', marker='v')
plt.plot(np.arange(len(test_ac_3[1:])) + 1, test_ac_3[1:], label='Accuracy on test set', marker='o')
plt.legend()
plt.savefig('odenet_3.pdf', bbox_inches='tight')
test_ac_1 = [elem['Test_Acc'] for elem in logs_1]
train_ac_1 = [elem['Train_Acc'] for elem in logs_1]
plt.figure(figsize=(15, 10))
plt.xlabel('epoch number')
plt.ylabel('accuracy')
plt.title('ODENET with one block')
plt.plot(np.arange(len(train_ac_1[1:])) + 1, train_ac_1[1:], label='Accuracy on training set', marker='v')
plt.plot(np.arange(len(test_ac_1[1:])) + 1, test_ac_1[1:], label='Accuracy on test set', marker='o')
plt.legend()
plt.savefig('odenet_1.pdf', bbox_inches='tight')
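# A companion plot worth adding (a sketch, using the corrected field names above):
# forward NFE per epoch, i.e. how hard the adaptive solver works as training proceeds.
plt.figure(figsize=(15, 10))
plt.xlabel('epoch number')
plt.ylabel('forward NFE (running average)')
plt.title('ODENET solver cost')
plt.plot([elem['NFE-F'] for elem in logs_1], marker='s')
plt.savefig('odenet_nfe.pdf', bbox_inches='tight')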