In [1]:
import os
import argparse
import logging
import time
import numpy as np
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm_notebook as tqdm

from fixed_grid import Euler, Midpoint, RK4
from dopri5 import Dopri5Solver

from misc import _decreasing, _check_inputs,  _flatten, _flatten_convert_none_to_zeros


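Функции ODEINT и odeint_adjoint применяют численные методы для вычисления выхода модели. Причём odeint_adjoint использует сопряжённые (adjoint) переменные: градиенты вычисляются повторным решением ОДУ в обратном времени, а не хранением всех промежуточных шагов решателя. Поэтому объём используемой им памяти не растёт с числом шагов — O(1), в отличие от ODEINT.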

In [2]:
def ODEINT(func, y0, t, rtol=1e-7, atol=1e-9, method=None, options=None):
    """Integrate a system of ordinary differential equations.

    Solves the initial value problem for a non-stiff system of first order ODEs::

        dy/dt = func(t, y),  y(t[0]) = y0

    Args:
        func: callable ``func(t, y)`` returning the time derivative of the state ``y``.
        y0: state at time ``t[0]``; a Tensor, or a tuple of Tensors (inputs are
            normalised by ``_check_inputs``).
        t: 1-D Tensor of time points; the first element is the initial time.
        rtol: relative error tolerance, forwarded to the solver.
        atol: absolute error tolerance, forwarded to the solver.
        method: key of the module-level ``SOLVERS`` dict; defaults to ``'dopri5'``.
        options: extra keyword arguments for the solver constructor. Only allowed
            when ``method`` is given explicitly.

    Returns:
        The solution at every time in ``t``, with time along the first dimension
        and ``y0`` as the first entry. It is a single Tensor when ``y0`` was a
        Tensor, otherwise whatever ``solver.integrate`` returns.

    Raises:
        ValueError: if ``options`` is supplied without ``method``, or if ``method``
            is not a key of ``SOLVERS``.
    """
    # Validate arguments before touching the inputs so misuse fails fast.
    if options is None:
        options = {}
    elif method is None:
        raise ValueError('cannot supply options without specifying method')

    if method is None:
        method = 'dopri5'  # default solver (Dopri5Solver in SOLVERS)
    if method not in SOLVERS:
        # Previously this surfaced as a bare KeyError from SOLVERS[method].
        raise ValueError('Invalid method "{}". Must be one of {}'.format(method, sorted(SOLVERS)))

    tensor_input, func, y0, t = _check_inputs(func, y0, t)

    solver = SOLVERS[method](func, y0, rtol=rtol, atol=atol, **options)
    solution = solver.integrate(t)

    # A bare-Tensor y0 was presumably wrapped into a 1-tuple by _check_inputs;
    # unwrap so the caller gets back the same kind of object it passed in.
    if tensor_input:
        solution = solution[0]
    return solution

@staticmethod
def forward(ctx, *args):
    """Solve the ODE forward in time and keep what the adjoint backward pass needs.

    ``args`` is laid out as ``(*y0, func, t, flat_params, rtol, atol, method, options)``,
    which is how ``odeint_adjoint`` calls ``apply``. Every leading entry is one
    state tensor; the last seven entries are the fixed trailing arguments.
    """
    assert len(args) >= 8, 'Internal error: all arguments required.'
    *state, func, t, flat_params, rtol, atol, method, options = args
    y0 = tuple(state)

    # save_for_backward is meant for tensors; the remaining settings ride on ctx.
    ctx.func = func
    ctx.rtol = rtol
    ctx.atol = atol
    ctx.method = method
    ctx.options = options

    solution = ODEINT(func, y0, t, rtol=rtol, atol=atol, method=method, options=options)
    ctx.save_for_backward(t, flat_params, *solution)
    return solution

@staticmethod
# NOTE(review): this block is damaged in this export and cannot run as shown:
#   * the `def ...(ctx, ...)` line that should follow the decorator is missing;
#   * inside augmented_dynamics, `y` is read before any visible assignment, and
#     only the tail of the call that should bind `vjp_t` / `vjp_y_and_params`
#     survives (`func_eval, (t,) + y + f_params, )`);
#   * `sum(` below has no arguments, `aug_y0` is used but never assigned, and
#     there is no return statement.
# TODO: restore the missing lines (presumably from the adjoint implementation
# this code was adapted from) before relying on odeint_adjoint.

# Recover what forward() saved: the time points, the flattened parameters and
# the forward solution (one entry per state tensor).
t, flat_params, *ans = ctx.saved_tensors
ans = tuple(ans)
func, rtol, atol, method, options = ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options
n_tensors = len(ans)
f_params = tuple(func.parameters())

# TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives.
def augmented_dynamics(t, y_aug):
# Dynamics of the original system augmented with
# the adjoint wrt y, and an integrator wrt t and args.

# Fresh leaf copies of the state, so gradients can be taken w.r.t. them.
y = tuple(y_.detach().requires_grad_(True) for y_ in y)
func_eval = func(t, y)
func_eval, (t,) + y + f_params,
)
# Split the vector-Jacobian products into the state part and the parameter part.
vjp_y = vjp_y_and_params[:n_tensors]
vjp_params = vjp_y_and_params[n_tensors:]

# Gradients that come back as None are replaced by zeros of the matching shape.
vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
vjp_y = tuple(torch.zeros_like(y_) if vjp_y_ is None else vjp_y_ for vjp_y_, y_ in zip(vjp_y, y))
vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

# With no parameters, use a scalar zero so the parameter slot of the
# augmented state still exists.
if len(f_params) == 0:
vjp_params = torch.tensor(0.).to(vjp_y[0])
# Augmented derivative layout: [dy (n_tensors entries), adjoint-y terms
# (n_tensors entries), time term, flattened-parameter term].
return (*func_eval, *vjp_y, vjp_t, vjp_params)

# Number of time points (time is the first dimension of each solution tensor).
T = ans[0].shape[0]
time_vjps = []
# Walk backwards over the time grid, one interval [t[i - 1], t[i]] at a time.
for i in range(T - 1, 0, -1):

ans_i = tuple(ans_[i] for ans_ in ans)
func_i = func(t[i], ans_i)

# Compute the effect of moving the current time measurement point.
dLd_cur_t = sum(
)
time_vjps.append(dLd_cur_t)

# Run the augmented system backwards in time.
aug_ans = ODEINT(
augmented_dynamics, aug_y0,
torch.tensor([t[i], t[i - 1]]), rtol=rtol, atol=atol, method=method, options=options
)

# Unpack aug_ans.
# Index 2 * n_tensors + 1 is the parameter slot of the augmented layout
# returned by augmented_dynamics above.
adj_params = aug_ans[2 * n_tensors + 1]

# Free the per-interval augmented state early.
del aug_y0, aug_ans

# The per-time-point terms were collected last-to-first; restore chronological order.
time_vjps = torch.cat(time_vjps[::-1])

def odeint_adjoint(func, y0, t, rtol=1e-6, atol=1e-12, method=None, options=None):
    """Solve dy/dt = func(t, y) like ``ODEINT``, but differentiate via the adjoint method.

    The solve runs inside ``OdeintAdjointMethod`` (a custom autograd function),
    so gradients come from its backward pass instead of from autograd tracing
    every solver step. That backward pass needs ``func``'s parameters explicitly,
    which is why ``func`` must be an ``nn.Module``.

    Args:
        func: ``nn.Module`` called as ``func(t, y)``.
        y0: initial state; a Tensor or a tuple of Tensors.
        t: 1-D Tensor of time points.
        rtol, atol, method, options: forwarded to the solver.

    Returns:
        The solution at every time in ``t``. It is a single Tensor when ``y0``
        is a Tensor, otherwise a tuple.

    Raises:
        ValueError: if ``func`` is not an ``nn.Module``.
    """
    # We need this in order to access the variables inside this module,
    # since we have no other way of getting variables along the execution path.
    if not isinstance(func, nn.Module):
        raise ValueError('func is required to be an instance of nn.Module.')

    tensor_input = torch.is_tensor(y0)
    if tensor_input:
        # Adapter: the adjoint machinery works on tuples of tensors, so wrap a
        # tensor -> tensor func into a 1-tuple -> 1-tuple one.
        class TupleFunc(nn.Module):

            def __init__(self, base_func):
                super(TupleFunc, self).__init__()
                self.base_func = base_func

            def forward(self, t, y):
                return (self.base_func(t, y[0]),)

        func = TupleFunc(func)
        y0 = (y0,)

    params_flat = _flatten(func.parameters())
    solution = OdeintAdjointMethod.apply(*y0, func, t, params_flat, rtol, atol, method, options)
    return solution[0] if tensor_input else solution


Вспомогательные функции

In [3]:
class Flatten(nn.Module):
    """Flatten every dimension after the first (batch) one: (N, d1, d2, ...) -> (N, d1*d2*...)."""

    def forward(self, x):
        # torch.flatten copies when it has to. Unlike .view it therefore accepts
        # non-contiguous inputs, and it needs no shape tensor or .item() call.
        return torch.flatten(x, start_dim=1)

class RunningAverageMeter:
    """Exponential moving average of a stream of values.

    After each ``update``, ``val`` holds the latest value and ``avg`` the moving
    average. The first update seeds ``avg`` with the value itself.
    """

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        """Forget all history."""
        self.val = None
        self.avg = 0

    def update(self, val):
        """Fold ``val`` into the moving average and remember it as the latest value."""
        seeded = self.val is not None
        self.avg = self.avg * self.momentum + val * (1 - self.momentum) if seeded else val
        self.val = val

# Loads the data
# NOTE(review): this is the body of a loader-building function whose `def`
# line is missing: `data_aug` and `test_batch_size` are read but never defined
# here (the model-init cell passes data_aug=True, batch_size=..., test_batch_size=1000).
# NOTE(review): both branches below build the same transform, so data_aug
# currently has no effect. TODO: add the intended augmentation to the first branch.
if data_aug:
transform_train = transforms.Compose([
transforms.ToTensor(),
])
else:
transform_train = transforms.Compose([
transforms.ToTensor(),
])

transform_test = transforms.Compose([
transforms.ToTensor(),
])

# Only the tails of three loader-constructor calls survive below; their heads
# (dataset, leading arguments, assignment targets) are missing.
# First loader: shuffled, and drops the last incomplete batch.
shuffle=True, num_workers=0, drop_last=True
)

# Second and third loaders: not shuffled, batch size test_batch_size.
# NOTE(review): drop_last=True on evaluation loaders discards the final partial
# batch whenever the dataset size is not a multiple of test_batch_size.
batch_size=test_batch_size, shuffle=False, num_workers=0, drop_last=True
)

batch_size=test_batch_size, shuffle=False, num_workers=0, drop_last=True
)

def inf_generator(iterable):
    """Allows training with DataLoaders in a single infinite loop:
    for i, (x, y) in enumerate(inf_generator(train_loader)):

    Re-iterates ``iterable`` from the start every time it is exhausted.
    Generation stops (instead of spinning forever) as soon as a fresh pass yields
    nothing. That happens for an empty iterable, or for a one-shot iterator that
    has been used up.
    """
    while True:
        produced = False
        for item in iterable:
            produced = True
            yield item
        if not produced:
            # An empty pass will never become non-empty; stop instead of busy-looping.
            return

def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates, base_lr=None):
    """Build a piecewise-constant learning-rate schedule.

    The initial rate is ``base_lr * batch_size / batch_denom``. For iteration
    ``itr`` the schedule returns ``initial * decay_rates[k]``, where ``k`` is:
      * the index of the first boundary (``boundary_epochs`` converted to
        iterations) that is still greater than ``itr``; or
      * ``len(boundary_epochs)`` once every boundary has been passed.

    Args:
        batch_size: training batch size.
        batch_denom: reference batch size that ``base_lr`` is defined for.
        batches_per_epoch: iterations per epoch, used to convert epochs to iterations.
        boundary_epochs: epochs at which the rate steps down.
        decay_rates: multipliers, one per interval. At least
            ``len(boundary_epochs) + 1`` entries; extra entries are ignored.
        base_lr: base learning rate. ``None`` (default) uses the notebook-level ``LR``.

    Returns:
        Callable mapping an iteration number to a learning rate.

    Raises:
        ValueError: if ``decay_rates`` is too short to cover every interval.
    """
    if len(decay_rates) < len(boundary_epochs) + 1:
        # Otherwise this would only surface as an IndexError late in training.
        raise ValueError('decay_rates needs at least {} entries, got {}'.format(
            len(boundary_epochs) + 1, len(decay_rates)))
    if base_lr is None:
        base_lr = LR
    initial_learning_rate = base_lr * batch_size / batch_denom

    boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
    vals = [initial_learning_rate * decay for decay in decay_rates]

    def learning_rate_fn(itr):
        # First boundary still ahead of itr; past all boundaries -> last interval.
        i = next((k for k, b in enumerate(boundaries) if itr < b), len(boundaries))
        return vals[i]

    return learning_rate_fn

def one_hot(x, K):
    """One-hot encode a 1-D array of integer labels into a (len(x), K) int array.

    Labels outside ``range(K)`` produce an all-zero row.
    """
    classes = np.arange(K)[np.newaxis, :]
    matches = x[:, np.newaxis] == classes
    return np.array(matches, dtype=int)

# NOTE(review): this fragment has lost its function header (it reads `model`
# and `dataset_loader` and ends in `return`) and the loop over the loader that
# binds `x` and `y` per batch; it cannot run as shown.
total_correct = 0
x = x.to(device)
# One-hot encode the labels for 10 classes.
y = one_hot(np.array(y.numpy()), 10)

# argmax over the one-hot rows just recovers the integer labels (for labels
# 0-9); comparing against the labels directly would skip the round trip.
target_class = np.argmax(y, axis=1)
# NOTE(review): consider wrapping evaluation in torch.no_grad(). As written,
# if parameters require grad, the forward pass builds an autograd graph that
# .detach() then throws away.
predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1)
total_correct += np.sum(predicted_class == target_class)
# Divides by the full dataset size, so samples skipped by a loader with
# drop_last=True count as misclassified.
return (total_correct / len(dataset_loader.dataset)) * 100 # plain accuracy, in percent

def count_parameters(model):
    """Return the number of trainable scalars in ``model`` (a rough measure of its capacity)."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def makedirs(dirname):
    """Create ``dirname``, including missing parents; do nothing if the directory already exists.

    ``exist_ok=True`` removes the race in a separate exists()/makedirs() check.
    With that check, another process creating the directory in between would
    make this call fail.

    Raises:
        FileExistsError: if ``dirname`` exists but is not a directory.
    """
    os.makedirs(dirname, exist_ok=True)

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding 1 and no bias term."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=(3, 3),
        stride=stride,
        padding=1,
        bias=False,
    )

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution without padding or bias (pointwise channel mixing)."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=(1, 1),
        stride=stride,
        bias=False,
    )

def norm(dim):
    """Group normalisation for ``dim`` channels (https://arxiv.org/pdf/1803.08494.pdf).

    Uses up to 32 groups. ``nn.GroupNorm`` requires ``dim`` to be divisible by
    the group count, so a plain ``min(32, dim)`` fails for e.g. ``dim=48``. The
    largest divisor of ``dim`` that does not exceed 32 is used instead. This is
    identical to ``min(32, dim)`` whenever that value already divides ``dim``.

    Raises:
        ValueError: if ``dim`` is not a positive channel count.
    """
    if dim < 1:
        raise ValueError('norm() needs a positive channel count, got {}'.format(dim))
    groups = min(32, dim)
    while dim % groups:
        groups -= 1
    return nn.GroupNorm(groups, dim)


Res-блоки задают архитектуру остаточной сети (ResNet).

In [4]:
class ResBlock(nn.Module):
    """Pre-activation residual block for ResNet.

    out = conv2(relu(norm2(conv1(relu(norm1(x)))))) + shortcut, where the
    shortcut is ``x`` itself, or ``downsample`` applied to the pre-activated input.
    """
    expansion = 1  # class attribute; not used elsewhere in the visible code

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        # Assignment order fixes parameter order and state_dict layout; keep it.
        self.norm1 = norm(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = norm(planes)
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        pre = self.relu(self.norm1(x))
        # The optional projection sees the pre-activated input, not raw x.
        skip = x if self.downsample is None else self.downsample(pre)
        residual = self.conv2(self.relu(self.norm2(self.conv1(pre))))
        return residual + skip  # F(x) + x


ODE-блоки задают архитектуру ODE-сети.

In [5]:
class ConcatConv2d(nn.Module):
    """Time-conditioned convolution.

    ``forward(t, x)`` prepends one channel filled with the scalar time ``t`` to
    ``x``, then applies a convolution (or transposed convolution). The dynamics
    can therefore depend on time explicitly. Because of the extra time plane,
    the wrapped layer takes ``dim_in + 1`` input channels.
    """

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        # Keyword arguments because Conv2d and ConvTranspose2d order their
        # trailing parameters differently.
        self._layer = module(
            dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding,
            dilation=dilation, groups=groups, bias=bias
        )

    def forward(self, t, x):
        # One channel shaped like x's first channel, filled with t.
        tt = torch.ones_like(x[:, :1, :, :]) * t
        ttx = torch.cat([tt, x], 1)
        return self._layer(ttx)

class ODEfunc(nn.Module):
    """Right-hand side f(t, x) of the ODE solved by the ODE block.

    Two rounds of norm -> ReLU -> time-conditioned 3x3 conv (stride 1,
    padding 1), followed by a final norm. ``nfe`` counts calls to ``forward``:
    the number of function evaluations, which depends on how finely the
    numerical method discretises time.
    """

    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        # Assignment order fixes parameter order and state_dict layout; keep it.
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, ksize=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, ksize=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        # t enters both convolutions through ConcatConv2d's time channel.
        h = self.conv1(t, self.relu(self.norm1(x)))
        h = self.conv2(t, self.relu(self.norm2(h)))
        return self.norm3(h)

class ODEBlock(nn.Module):
    """Feature block that integrates ``odefunc`` from t=0 to t=1 and returns the state at t=1.

    Args:
        odefunc: module called as ``odefunc(t, x)`` that exposes an ``nfe`` counter
            (an ``ODEfunc`` in this notebook).
        rtol, atol: solver tolerances. ``None`` (default) uses the notebook-level ``TOL``.
        method: solver name. ``None`` (default) uses the notebook-level ``METHOD``.
        solver: integration function with the ``ODEINT`` calling convention.
            ``None`` (default) uses the notebook-level ``odeint``.

    Defaults are resolved at call time, so edits to the config cell take effect
    without rebuilding the model.
    """

    def __init__(self, odefunc, rtol=None, atol=None, method=None, solver=None):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()
        self.rtol = rtol
        self.atol = atol
        self.method = method
        self.solver = solver

    def forward(self, x):
        # integration_time is a plain attribute (not a registered buffer), so
        # model.to(...) does not move it; match x's dtype and device here.
        self.integration_time = self.integration_time.type_as(x)
        solver = odeint if self.solver is None else self.solver
        rtol = TOL if self.rtol is None else self.rtol
        atol = TOL if self.atol is None else self.atol
        method = METHOD if self.method is None else self.method
        out = solver(self.odefunc, x, self.integration_time, rtol=rtol, atol=atol, method=method)
        # The solution has one entry per time point; keep the state at t=1.
        return out[1]

    @property
    def nfe(self):
        """Number of function evaluations recorded by the wrapped odefunc."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value


# Параметры эксперимента

TOL (tolerance) — допуск, параметр для функций ODEINT и odeint_adjoint. Он позволяет адаптировать модель: изменяя его, можно получить более быструю, но менее точную модель.

METHOD — параметр для функций ODEINT и odeint_adjoint, определяющий, какой именно численный метод будет использоваться. От него зависят точность и время работы.

Параметры эксперимента

1) Чтобы убедиться во влиянии TOL, можно увеличить или уменьшить его и посмотреть на соответствующий результат.

2) Также можно поменять параметр METHOD, выбрав его из SOLVERS.

3) Можно сравнить ODE Net с Res Net, посмотрев на число параметров, на скорость сходимости, для этого следует переключать параметр is_odenet.

4) Для того чтобы проверить утверждение о константной памяти при использовании odeint_adjoint, нужно заменить значение параметра odeint на odeint_adjoint и значительно увеличить BTCH_SZ, например до 200. При таком размере BTCH_SZ и использовании ODEINT ноутбук, скорее всего, упадёт.

5) Если хочется провести точно такой же эксперимент, как в статье, можно выставить число эпох (NEPOCHS) равным 128, но это займёт довольно много времени.

In [6]:
# Seed every RNG the experiment uses, so weight initialisation and data
# shuffling are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

TOL = 1e-3  # rtol/atol that ODEBlock passes to the ODE solver
SOLVERS = {'dopri5': Dopri5Solver, 'euler': Euler, 'midpoint': Midpoint, 'rk4': RK4}  # numerical integrators
METHOD = 'rk4'  # choose from 'dopri5', 'euler', 'midpoint', 'rk4'
LR = 0.1  # base learning rate, used in learning_rate_with_decay
odeint = ODEINT  # ODEINT or odeint_adjoint
is_odenet = True  # False runs the ResNet baseline instead of the ODE-Net
BTCH_SZ = 50  # training batch size
NEPOCHS = 5  # number of training epochs


# Эксперимент

In [7]:
# Create the experiment output directory, then pick the device that the
# .to(device) calls below target.
makedirs("./experiment")
device = 'cpu'

In [8]:
# Stem: 1-channel input -> 64 channels; the two stride-2 convolutions reduce
# the spatial size.
downsampling_layers = [
    nn.Conv2d(1, 64, kernel_size=3, stride=1),
    norm(64),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1),
    norm(64),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1),
]

# Feature extractor: a single ODE block, or six residual blocks for the ResNet baseline.
if is_odenet:
    feature_layers = [ODEBlock(ODEfunc(64))]
else:
    feature_layers = [ResBlock(64, 64) for _ in range(6)]

# Classifier head: norm + ReLU, global average pooling, flatten, 10-way linear layer.
fc_layers = [
    norm(64),
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d((1, 1)),
    Flatten(),
    nn.Linear(64, 10),
]


инициализация модели

In [9]:
# The model itself
model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)

# Loss used for optimization
criterion = nn.CrossEntropyLoss().to(device)

# NOTE(review): only the tail of a call survives on the next line; its head
# (the assignment target and the loader-building function) is missing.
data_aug=True, batch_size = BTCH_SZ, test_batch_size=1000)

# Get an infinite iterator over the training data
# NOTE(review): the statements that belong here are missing. `data_gen` (used
# by the training loop) and `batches_per_epoch` (used below) are not assigned
# anywhere in this cell.

# Learning-rate decay as a function of the iteration.
# The schedule starts at LR * BTCH_SZ / 128. With the configured NEPOCHS = 5,
# the first boundary (epoch 60) is never reached, so the rate never decays.

lr_fn = learning_rate_with_decay(
batch_size = BTCH_SZ, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[60, 100, 140],
decay_rates=[1, 0.1, 0.01, 0.001]
)

# Optimizer. lr=LR is only the starting value: the training loop overwrites
# every param group's lr with lr_fn(itr) on each iteration.

optimizer = torch.optim.SGD(model.parameters(), lr= LR, momentum=0.9)

# Initialise three separate RunningAverageMeters: batch time, forward NFE, backward NFE.
# best_acc is not used in the visible code.
best_acc = 0
batch_time_meter = RunningAverageMeter()
f_nfe_meter = RunningAverageMeter()
b_nfe_meter = RunningAverageMeter()
# Start of the timing window for the first training iteration.
end = time.time()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Processing...
Done!


число параметров

Можно заметить, что у ResNet (is_odenet = False) их больше.

In [10]:
# Parameter counts per stage (downsampling, features, classifier), then the whole model.
print(*(count_parameters(nn.Sequential(*stage))
        for stage in (downsampling_layers, feature_layers, fc_layers)))
print('Number of parameters: {}'.format(count_parameters(model)))

132096 75392 778
Number of parameters: 208266


запуск эксперимента

In [11]:
# Main training loop
# NOTE(review): optimizer.zero_grad() is never called in this loop. Every
# loss.backward() therefore adds to the .grad left over from all previous
# iterations, and each optimizer.step() applies the accumulated sum.
# Call optimizer.zero_grad() before computing the loss.
for itr in tqdm(range(1, NEPOCHS * batches_per_epoch + 1)):

# Apply the scheduled learning rate for this iteration.
for param_group in optimizer.param_groups:
param_group['lr'] = lr_fn(itr)

x, y = data_gen.__next__()
x = x.to(device)
y = y.to(device)
logits = model(x)
loss = criterion(logits, y)

# Record the ODE function evaluations of the forward pass, then reset the
# counter so the backward count below is measured separately.
if is_odenet:
nfe_forward = feature_layers[0].nfe
feature_layers[0].nfe = 0

loss.backward()
optimizer.step()

# Function evaluations counted since the reset above, i.e. during backward.
if is_odenet:
nfe_backward = feature_layers[0].nfe
feature_layers[0].nfe = 0

# Wall-clock time of this iteration, measured from the previous iteration's end.
batch_time_meter.update(time.time() - end)
if is_odenet:
f_nfe_meter.update(nfe_forward)
b_nfe_meter.update(nfe_backward)
end = time.time()

# End of an epoch.
# NOTE(review): the body of this `if` is missing from this export. Judging by
# the cell output, it printed the epoch number and the test error (%).
if itr % batches_per_epoch == 0:

epoch:  1.0 te_err (%) :  2.969999999999999