In [17]:
import os
import argparse
import logging
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchdiffeq import odeint
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from matplotlib import pyplot as plt
import seaborn as sns
In [18]:
sns.set()
In [2]:
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def norm(dim):
    return nn.GroupNorm(min(32, dim), dim)
In [3]:
class ResBlock(nn.Module):
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.norm1 = norm(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = norm(planes)
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        shortcut = x

        out = self.relu(self.norm1(x))

        if self.downsample is not None:
            shortcut = self.downsample(out)

        out = self.conv1(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + shortcut
In [21]:
class ConcatConv2d(nn.Module):

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(
            dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )

    def forward(self, t, x):
        tt = torch.ones_like(x[:, :1, :, :]) * t
        ttx = torch.cat([tt, x], 1)
        return self._layer(ttx)


class ODEfunc_compl(nn.Module):

    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.conv3 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm4 = norm(dim)
        self.conv4 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm5 = norm(dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm3(out)
        out = self.relu(out)
        out = self.conv3(t, out)
        out = self.norm4(out)
        out = self.relu(out)
        out = self.conv4(t, out)
        out = self.norm5(out)
        return out

    
class ODEfunc(nn.Module):

    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm3(out)
        return out


class ODEBlock(nn.Module):

    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        out = odeint(self.odefunc, x, self.integration_time, rtol=tol, atol=tol)
        return out[1]

    @property
    def nfe(self):
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value
In [5]:
class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


class RunningAverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val
In [6]:
def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0):
    if data_aug:
        transform_train = transforms.Compose([
            transforms.RandomCrop(28, padding=4),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_train), batch_size=batch_size,
        shuffle=True, num_workers=2, drop_last=True
    )

    train_eval_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
    )

    test_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=False, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
    )

    return train_loader, test_loader, train_eval_loader


def inf_generator(iterable):
    """Allows training with DataLoaders in a single infinite loop:
        for i, (x, y) in enumerate(inf_generator(train_loader)):
    """
    iterator = iterable.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = iterable.__iter__()
In [7]:
def learning_rate_with_decay(lr, batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates):
    initial_learning_rate = lr * batch_size / batch_denom

    boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
    vals = [initial_learning_rate * decay for decay in decay_rates]

    def learning_rate_fn(itr):
        lt = [itr < b for b in boundaries] + [True]
        i = np.argmax(lt)
        return vals[i]

    return learning_rate_fn


def one_hot(x, K):
    return np.array(x[:, None] == np.arange(K)[None, :], dtype=int)


def accuracy(model, dataset_loader):
    total_correct = 0
    for x, y in dataset_loader:
        x = x.to(device)
        y = one_hot(np.array(y.numpy()), 10)

        target_class = np.argmax(y, axis=1)
        predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1)
        total_correct += np.sum(predicted_class == target_class)
    return total_correct / len(dataset_loader.dataset)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
In [8]:
def makedirs(dirname):
    if not os.path.exists(dirname):
        os.makedirs(dirname)


def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False):
    logger = logging.getLogger()
    if debug:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logger.setLevel(level)
    if saving:
        info_file_handler = logging.FileHandler(logpath, mode="a")
        info_file_handler.setLevel(level)
        logger.addHandler(info_file_handler)
    if displaying:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level)
        logger.addHandler(console_handler)
    #logger.info(filepath)
    #with open(filepath, "r") as f:
    #    logger.info(f.read())

    for f in package_files:
        logger.info(f)
        with open(f, "r") as package_f:
            logger.info(package_f.read())

    return logger
In [9]:
tol = 1e-3
downsampling_method = 'res'
epochs=100
data_aug = True
lr = 1e-3
batch_size = 128
test_bs = 1024 #??

save_dir = './exp1' 
In [22]:
#from torchdiffeq import odeint_adjoint as odeint
from torchdiffeq import odeint
In [23]:
makedirs(save_dir)
logger = get_logger(logpath=os.path.join(save_dir, 'logs'), filepath=os.path.abspath('.'))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

is_odenet = True

if downsampling_method == 'conv':
    downsampling_layers = [
        nn.Conv2d(1, 64, 3, 1),
        norm(64),
        nn.ReLU(inplace=True),
        nn.Conv2d(64, 64, 4, 2, 1),
        norm(64),
        nn.ReLU(inplace=True),
        nn.Conv2d(64, 64, 4, 2, 1),
    ]
elif downsampling_method == 'res':
    downsampling_layers = [
        nn.Conv2d(1, 64, 3, 1),
        ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
        ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
    ]

#feature_layers = [ODEBlock(ODEfunc(64)) for _ in range(2)] if is_odenet else [ResBlock(64, 64) for _ in range(6)]
#feature_layers = [ODEBlock(ODEfunc_compl(64))] if is_odenet else [ResBlock(64, 64) for _ in range(6)]
feature_layers = [ODEBlock(ODEfunc(64))] if is_odenet else [ResBlock(64, 64) for _ in range(6)]

fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]

model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)

#logger.info
print(model)
#logger.info
print('Number of parameters: {}'.format(count_parameters(model)))

criterion = nn.CrossEntropyLoss().to(device)

train_loader, test_loader, train_eval_loader = get_mnist_loaders(data_aug, batch_size, test_bs)

data_gen = inf_generator(train_loader)
batches_per_epoch = len(train_loader)

lr_fn = learning_rate_with_decay( lr, batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch,
                                 boundary_epochs=[20, 40, 60],decay_rates=[1, 0.1, 0.01, 0.001])

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

best_acc = 0
field_names = list(map(lambda l:l.strip().split()[0],
                       "Epoch {:4d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} |"
                       "Train_Acc {:.4f} | Test_Acc {:.4f} | Total_Time {:.3f}".split('|')))
logs_1 = []

batch_time_meter = RunningAverageMeter()
f_nfe_meter = RunningAverageMeter()
b_nfe_meter = RunningAverageMeter()
end = time.time()
start = time.time()

for itr in range(epochs * batches_per_epoch):

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr_fn(itr)

    optimizer.zero_grad()
    x, y = data_gen.__next__()
    x = x.to(device)
    y = y.to(device)
    logits = model(x)
    loss = criterion(logits, y)

    if is_odenet:
        nfe_forward = feature_layers[0].nfe
        feature_layers[0].nfe = 0

    loss.backward()
    optimizer.step()

    if is_odenet:
        nfe_backward = feature_layers[0].nfe
        feature_layers[0].nfe = 0

    batch_time_meter.update(time.time() - end)
    if is_odenet:
        f_nfe_meter.update(nfe_forward)
        b_nfe_meter.update(nfe_backward)
    end = time.time()

    if itr % batches_per_epoch == 0:
        finish = time.time()
        with torch.no_grad():
            train_acc = accuracy(model, train_eval_loader)
            val_acc = accuracy(model, test_loader)
            if val_acc > best_acc:
                torch.save({'state_dict': model.state_dict()}, os.path.join(save_dir, 'model.pth'))
                best_acc = val_acc
            #logger.info
            print(
                "Epoch {:4d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} | "
                "Train_Acc {:.4f} | Test_Acc {:.4f} | Total_Time {:.3f}".format(
                    itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg,
                    b_nfe_meter.avg, train_acc, val_acc, finish - start
                )
            )
            logs_1.append(dict(zip(field_names, [ itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg,
                    b_nfe_meter.avg, train_acc, val_acc, finish - start])))
            start = time.time()
Sequential(
  (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
  (1): ResBlock(
    (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
    (relu): ReLU(inplace)
    (downsample): Conv2d(64, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (2): ResBlock(
    (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
    (relu): ReLU(inplace)
    (downsample): Conv2d(64, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (3): ODEBlock(
    (odefunc): ODEfunc(
      (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
      (relu): ReLU(inplace)
      (conv1): ConcatConv2d(
        (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
      (conv2): ConcatConv2d(
        (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (norm3): GroupNorm(32, 64, eps=1e-05, affine=True)
    )
  )
  (4): GroupNorm(32, 64, eps=1e-05, affine=True)
  (5): ReLU(inplace)
  (6): AdaptiveAvgPool2d(output_size=(1, 1))
  (7): Flatten()
  (8): Linear(in_features=64, out_features=10, bias=True)
)
Number of parameters: 232970
Epoch    0 | Time 1.166 (1.166) | NFE-F 26.0 | NFE-B 0.0 | Train_Acc 0.0978 | Test_Acc 0.0897 | Total_Time 1.166
Epoch    1 | Time 1.124 (0.352) | NFE-F 26.1 | NFE-B 0.0 | Train_Acc 0.9630 | Test_Acc 0.8998 | Total_Time 151.212
Epoch    2 | Time 1.169 (0.356) | NFE-F 26.2 | NFE-B 0.0 | Train_Acc 0.9720 | Test_Acc 0.9077 | Total_Time 159.991
Epoch    3 | Time 1.162 (0.354) | NFE-F 26.2 | NFE-B 0.0 | Train_Acc 0.9730 | Test_Acc 0.9052 | Total_Time 160.550
Epoch    4 | Time 1.203 (0.341) | NFE-F 25.0 | NFE-B 0.0 | Train_Acc 0.9805 | Test_Acc 0.9138 | Total_Time 157.097
Epoch    5 | Time 1.123 (0.320) | NFE-F 22.9 | NFE-B 0.0 | Train_Acc 0.9821 | Test_Acc 0.9130 | Total_Time 147.981
Epoch    6 | Time 1.177 (0.295) | NFE-F 20.9 | NFE-B 0.0 | Train_Acc 0.9808 | Test_Acc 0.9129 | Total_Time 135.925
Epoch    7 | Time 1.094 (0.286) | NFE-F 20.2 | NFE-B 0.0 | Train_Acc 0.9808 | Test_Acc 0.9141 | Total_Time 128.339
Epoch    8 | Time 1.158 (0.285) | NFE-F 20.2 | NFE-B 0.0 | Train_Acc 0.9823 | Test_Acc 0.9148 | Total_Time 128.412
Epoch    9 | Time 1.123 (0.286) | NFE-F 20.1 | NFE-B 0.0 | Train_Acc 0.9827 | Test_Acc 0.9146 | Total_Time 128.836
Epoch   10 | Time 1.143 (0.289) | NFE-F 20.2 | NFE-B 0.0 | Train_Acc 0.9813 | Test_Acc 0.9134 | Total_Time 130.564
Epoch   11 | Time 1.171 (0.296) | NFE-F 20.1 | NFE-B 0.0 | Train_Acc 0.9812 | Test_Acc 0.9130 | Total_Time 131.606
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-23-b37de11de66c> in <module>
     69     x = x.to(device)
     70     y = y.to(device)
---> 71     logits = model(x)
     72     loss = criterion(logits, y)
     73 

~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

~\Anaconda3\lib\site-packages\torch\nn\modules\container.py in forward(self, input)
     90     def forward(self, input):
     91         for module in self._modules.values():
---> 92             input = module(input)
     93         return input
     94 

~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

<ipython-input-3-032abf9aeec6> in forward(self, x)
     17             shortcut = self.downsample(out)
     18 
---> 19         out = self.conv1(out)
     20         out = self.norm2(out)
     21         out = self.relu(out)

~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

~\Anaconda3\lib\site-packages\torch\nn\modules\conv.py in forward(self, input)
    336                             _pair(0), self.dilation, self.groups)
    337         return F.conv2d(input, self.weight, self.bias, self.stride,
--> 338                         self.padding, self.dilation, self.groups)
    339 
    340 

KeyboardInterrupt: 
In [24]:
logs_1
Out[24]:
[{'Epoch': 0,
  'Time': 1.1663413047790527,
  'NFE-F': 1.1663413047790527,
  'NFE-B': 26,
  'Train_Acc': 0,
  'Test_Acc': 0.09775,
  'Total_Time': 0.0897},
 {'Epoch': 1,
  'Time': 1.1237423419952393,
  'NFE-F': 0.3518146520107296,
  'NFE-B': 26.113337826255613,
  'Train_Acc': 0.0,
  'Test_Acc': 0.9629833333333333,
  'Total_Time': 0.8998},
 {'Epoch': 2,
  'Time': 1.1690211296081543,
  'NFE-F': 0.35646780732408767,
  'NFE-B': 26.160499026848903,
  'Train_Acc': 0.0,
  'Test_Acc': 0.972,
  'Total_Time': 0.9077},
 {'Epoch': 3,
  'Time': 1.1619246006011963,
  'NFE-F': 0.3541097996338256,
  'NFE-B': 26.160926447915926,
  'Train_Acc': 0.0,
  'Test_Acc': 0.9729833333333333,
  'Total_Time': 0.9052},
 {'Epoch': 4,
  'Time': 1.202782154083252,
  'NFE-F': 0.34110854730643125,
  'NFE-B': 24.95415771513177,
  'Train_Acc': 0.0,
  'Test_Acc': 0.9805,
  'Total_Time': 0.9138},
 {'Epoch': 5,
  'Time': 1.12302827835083,
  'NFE-F': 0.3201531668305314,
  'NFE-B': 22.928928317852208,
  'Train_Acc': 0.0,
  'Test_Acc': 0.9820666666666666,
  'Total_Time': 0.913},
 {'Epoch': 6,
  'Time': 1.177217721939087,
  'NFE-F': 0.2949909653493132,
  'NFE-B': 20.920042885241745,
  'Train_Acc': 0.0,
  'Test_Acc': 0.9807833333333333,
  'Total_Time': 0.9129},
 {'Epoch': 7,
  'Time': 1.0944006443023682,
  'NFE-F': 0.28566477432615045,
  'NFE-B': 20.216727656597406,
  'Train_Acc': 0.0,
  'Test_Acc': 0.9808,
  'Total_Time': 0.9141},
 {'Epoch': 8,
  'Time': 1.1575982570648193,
  'NFE-F': 0.28541028713199734,
  'NFE-B': 20.184034851164867,
  'Train_Acc': 0.0,
  'Test_Acc': 0.9823,
  'Total_Time': 0.9148},
 {'Epoch': 9,
  'Time': 1.1227002143859863,
  'NFE-F': 0.28647260674155234,
  'NFE-B': 20.125482535545512,
  'Train_Acc': 0.0,
  'Test_Acc': 0.9827333333333333,
  'Total_Time': 0.9146},
 {'Epoch': 10,
  'Time': 1.1429314613342285,
  'NFE-F': 0.28855354211729095,
  'NFE-B': 20.15380365074734,
  'Train_Acc': 0.0,
  'Test_Acc': 0.9812666666666666,
  'Total_Time': 0.9134},
 {'Epoch': 11,
  'Time': 1.1712427139282227,
  'NFE-F': 0.29551518088596457,
  'NFE-B': 20.124064572092596,
  'Train_Acc': 0.0,
  'Test_Acc': 0.9812,
  'Total_Time': 0.913}]
In [20]:
plt.figure(figsize=(15, 10))
plt.xlabel('epoch number')
plt.ylabel('accuracy')
plt.title('ODENET complex block')
plt.plot(np.arange(len(train_ac_3[1:])) + 1, train_ac_3[1:], label='Accuracy on training set', marker='v')
plt.plot(np.arange(len(test_ac_3[1:])) + 1, test_ac_3[1:], label='Accuracy on test set', marker='o')
plt.legend()
plt.savefig('odenet_3.pdf', bbox_inches='tight')
In [15]:
test_ac_3 = [elem['Total_Time'] for elem in logs_3]
train_ac_3 = [elem['Test_Acc'] for elem in logs_3]
In [25]:
test_ac_1 = [elem['Total_Time'] for elem in logs_1]
train_ac_1 = [elem['Test_Acc'] for elem in logs_1]
In [26]:
plt.figure(figsize=(15, 10))
plt.xlabel('epoch number')
plt.ylabel('accuracy')
plt.title('ODENET with one block')
plt.plot(np.arange(len(train_ac_1[1:])) + 1, train_ac_1[1:], label='Accuracy on training set', marker='v')
plt.plot(np.arange(len(test_ac_1[1:])) + 1, test_ac_1[1:], label='Accuracy on test set', marker='o')
plt.legend()
plt.savefig('odenet_1.pdf', bbox_inches='tight')