In [1]:
import os
import argparse
import logging 
import time
import numpy as np
import torch
import torch.nn as nn
from import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm_notebook as tqdm

from fixed_grid import Euler, Midpoint, RK4
from dopri5 import Dopri5Solver

from misc import _decreasing, _check_inputs,  _flatten, _flatten_convert_none_to_zeros

ODEINT и odeint_adjoint

Функции ODEINT и odeint_adjoint применяют численные методы для вычисления выхода модели. Причем в odeint_adjoint используются сопряженные переменные, поэтому объем его используемой памяти - константа O(1), в отличие от ODEINT.

In [2]:
def ODEINT(func, y0, t, rtol=1e-7, atol=1e-9, method=None, options=None):
    """Integrate a system of ordinary differential equations.
    Solves the initial value problem for a non-stiff system of first order ODEs:
        dy/dt = func(t, y), y(t[0]) = y0
    where y is a Tensor of any shape.
    Output dtypes and numerical precision are based on the dtypes of the inputs `y0`.
        func: Function that maps a Tensor holding the state `y` and a scalar Tensor
            `t` into a Tensor of state derivatives with respect to time.
        y0: N-D Tensor giving starting value of `y` at time point `t[0]`. May
            have any floating point or complex dtype.
        t: 1-D Tensor holding a sequence of time points for which to solve for
            `y`. The initial time point should be the first element of this sequence,
            and each time must be larger than the previous time. May have any floating
            point dtype. Converted to a Tensor with float64 dtype.
        rtol: optional float64 Tensor specifying an upper bound on relative error,
            per element of `y`.
        atol: optional float64 Tensor specifying an upper bound on absolute error,
            per element of `y`.
        method: optional string indicating the integration method to use.
        options: optional dict of configuring options for the indicated integration
            method. Can only be provided if a `method` is explicitly set.
        name: Optional name for this operation.
        y: Tensor, where the first dimension corresponds to different
            time points. Contains the solved value of y for each desired time point in
            `t`, with the initial value `y0` being the first element along the first
        ValueError: if an invalid `method` is provided.
        TypeError: if `options` is supplied without `method`, or if `t` or `y0` has
            an invalid dtype.

    tensor_input, func, y0, t = _check_inputs(func, y0, t)

    if options is None:
        options = {}
    elif method is None:
        raise ValueError('cannot supply `options` without specifying `method`')

    if method is None:
        method = 'dopri5' # ЧМЫ

    solver = SOLVERS[method](func, y0, rtol=rtol, atol=atol, **options)
    solution = solver.integrate(t)

    if tensor_input:
        solution = solution[0]
    return solution

class OdeintAdjointMethod(torch.autograd.Function):

    def forward(ctx, *args):
        assert len(args) >= 8, 'Internal error: all arguments required.'
        y0, func, t, flat_params, rtol, atol, method, options = \
            args[:-7], args[-7], args[-6], args[-5], args[-4], args[-3], args[-2], args[-1]

        ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options = func, rtol, atol, method, options

        with torch.no_grad():
            ans = ODEINT(func, y0, t, rtol=rtol, atol=atol, method=method, options=options)
        ctx.save_for_backward(t, flat_params, *ans)
        return ans

    def backward(ctx, *grad_output):

        t, flat_params, *ans = ctx.saved_tensors
        ans = tuple(ans)
        func, rtol, atol, method, options = ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options
        n_tensors = len(ans)
        f_params = tuple(func.parameters())

        # TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives.
        def augmented_dynamics(t, y_aug):
            # Dynamics of the original system augmented with
            # the adjoint wrt y, and an integrator wrt t and args.
            y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]  # Ignore adj_time and adj_params.

            with torch.set_grad_enabled(True):
                t =[0].device).detach().requires_grad_(True)
                y = tuple(y_.detach().requires_grad_(True) for y_ in y)
                func_eval = func(t, y)
                vjp_t, *vjp_y_and_params = torch.autograd.grad(
                    func_eval, (t,) + y + f_params,
                    tuple(-adj_y_ for adj_y_ in adj_y), allow_unused=True, retain_graph=True
            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_params = vjp_y_and_params[n_tensors:]

            # autograd.grad returns None if no gradient, set to zero.
            vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
            vjp_y = tuple(torch.zeros_like(y_) if vjp_y_ is None else vjp_y_ for vjp_y_, y_ in zip(vjp_y, y))
            vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

            if len(f_params) == 0:
                vjp_params = torch.tensor(0.).to(vjp_y[0])
            return (*func_eval, *vjp_y, vjp_t, vjp_params)

        T = ans[0].shape[0]
        with torch.no_grad():
            adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)
            adj_params = torch.zeros_like(flat_params)
            adj_time = torch.tensor(0.).to(t)
            time_vjps = []
            for i in range(T - 1, 0, -1):

                ans_i = tuple(ans_[i] for ans_ in ans)
                grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output)
                func_i = func(t[i], ans_i)

                # Compute the effect of moving the current time measurement point.
                dLd_cur_t = sum(
          , grad_output_i_.reshape(-1)).reshape(1)
                    for func_i_, grad_output_i_ in zip(func_i, grad_output_i)
                adj_time = adj_time - dLd_cur_t

                # Run the augmented system backwards in time.
                if adj_params.numel() == 0:
                    adj_params = torch.tensor(0.).to(adj_y[0])
                aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)
                aug_ans = ODEINT(
                    augmented_dynamics, aug_y0,
                    torch.tensor([t[i], t[i - 1]]), rtol=rtol, atol=atol, method=method, options=options

                # Unpack aug_ans.
                adj_y = aug_ans[n_tensors:2 * n_tensors]
                adj_time = aug_ans[2 * n_tensors]
                adj_params = aug_ans[2 * n_tensors + 1]

                adj_y = tuple(adj_y_[1] if len(adj_y_) > 0 else adj_y_ for adj_y_ in adj_y)
                if len(adj_time) > 0: adj_time = adj_time[1]
                if len(adj_params) > 0: adj_params = adj_params[1]

                adj_y = tuple(adj_y_ + grad_output_[i - 1] for adj_y_, grad_output_ in zip(adj_y, grad_output))

                del aug_y0, aug_ans

            time_vjps =[::-1])

            return (*adj_y, None, time_vjps, adj_params, None, None, None, None, None)

def odeint_adjoint(func, y0, t, rtol=1e-6, atol=1e-12, method=None, options=None):

    # We need this in order to access the variables inside this module,
    # since we have no other way of getting variables along the execution path.
    if not isinstance(func, nn.Module):
        raise ValueError('func is required to be an instance of nn.Module.')

    tensor_input = False
    if torch.is_tensor(y0):

        class TupleFunc(nn.Module):

            def __init__(self, base_func):
                super(TupleFunc, self).__init__()
                self.base_func = base_func

            def forward(self, t, y):
                return (self.base_func(t, y[0]),)

        tensor_input = True
        y0 = (y0,)
        func = TupleFunc(func)

    flat_params = _flatten(func.parameters())
    ys = OdeintAdjointMethod.apply(*y0, func, t, flat_params, rtol, atol, method, options)

    if tensor_input:
        ys = ys[0]
    return ys

Вспомогательные функции

In [3]:
class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape =[1:])).item()
        return x.view(-1, shape)

class RunningAverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, momentum=0.99):
        self.momentum = momentum

    def reset(self):
        self.val = None
        self.avg = 0

    def update(self, val):
        if self.val is None:
            self.avg = val
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val

def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0): 
    # загружает данные
    if data_aug:
        transform_train = transforms.Compose([
            transforms.RandomCrop(28, padding=4),
        transform_train = transforms.Compose([

    transform_test = transforms.Compose([

    train_loader = DataLoader(
        datasets.MNIST(root='data/mnist', train=True, download=True, transform=transform_train), batch_size=batch_size,
        shuffle=True, num_workers=0, drop_last=True

    train_eval_loader = DataLoader(
        datasets.MNIST(root='data/mnist', train=True, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=0, drop_last=True

    test_loader = DataLoader(
        datasets.MNIST(root='data/mnist', train=False, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=0, drop_last=True

    return train_loader, test_loader, train_eval_loader #loader

def inf_generator(iterable): 
    """Allows training with DataLoaders in a single infinite loop:
        for i, (x, y) in enumerate(inf_generator(train_loader)):
    iterator = iterable.__iter__()
    while True:
            yield iterator.__next__()
        except StopIteration:
            iterator = iterable.__iter__()

def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates):
    # реализует затухание lr
    initial_learning_rate = LR * batch_size / batch_denom

    boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
    vals = [initial_learning_rate * decay for decay in decay_rates]

    def learning_rate_fn(itr):
        lt = [itr < b for b in boundaries] + [True]
        i = np.argmax(lt)
        return vals[i]

    return learning_rate_fn

def one_hot(x, K):
    #one hot кодирование
    return np.array(x[:, None] == np.arange(K)[None, :], dtype=int)

def accuracy(model, dataset_loader):
    total_correct = 0
    for x, y in dataset_loader:
        x =
        y = one_hot(np.array(y.numpy()), 10)

        target_class = np.argmax(y, axis=1)
        predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1)
        total_correct += np.sum(predicted_class == target_class)
    return (total_correct / len(dataset_loader.dataset)) * 100 # просто точность

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad) # число  параметров модели,что показывает её сложность

def makedirs(dirname): # создаёт дирректорию
    if not os.path.exists(dirname):

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

def norm(dim):
    return nn.GroupNorm(min(32, dim), dim)

Res блоки задают архитектуру остаточной сети

In [4]:
class ResBlock(nn.Module):
    """Блок для ResNet"""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.norm1 = norm(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample 
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = norm(planes)
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        shortcut = x

        out = self.relu(self.norm1(x))

        if self.downsample is not None:
            shortcut = self.downsample(out) 

        out = self.conv1(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + shortcut # F(x) + x

ODE блоки Задают архитектуру ODE сети

In [5]:
class ConcatConv2d(nn.Module):
    """Особая свёртка"""
    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(
            dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,

    def forward(self, t, x):
        tt = torch.ones_like(x[:, :1, :, :]) * t 
        ttx =[tt, x], 1)
        return self._layer(ttx)

class ODEfunc(nn.Module):

    def __init__(self, dim): 
        super(ODEfunc, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0 # number of function evaluations 
        # объем вычислений, зависит от размера сетки в численном методе

    def forward(self, t, x):
        self.nfe += 1 
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(t, out) # t участвует в свертках, тех самых concat conv
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm3(out)
        return out

class ODEBlock(nn.Module):

    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc # класс ODEfunc описан выше
        self.integration_time = torch.tensor([0, 1]).float()

    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        out = odeint(self.odefunc, x, self.integration_time, rtol= TOL, atol = TOL, method  = METHOD)
        return out[1]

    def nfe(self):
        return self.odefunc.nfe

    def nfe(self, value):
        self.odefunc.nfe = value

Параметры эксперимента

TOL(tolernce), параметр для функций ODEINT и odeint_adjoint. Этот параметр позволяет адаптировать модель, изменяя его, можно получить более быструю но менее точную модель

METHOD, параметр для функций ODEINT и odeint_adjoint. Этот параметр определяет какой именно численный метод будет использоваться. От него зависит точность и время работы метода.

Параметры эксперимента

1) Чтобы удостовериться в влиянии TOL можно его увеличить и уменьшить и посмотреть соответствеено результат.

2) Также можно поменять параметр METHOD, выбрав его из SOLVERS.

3) Можно сравнить ODE Net с Res Net, посмотрев на число параметров, на скорость сходимости, для этого следует переключать параметр is_odenet.

4) Для того чтобы Проверить утверждение о крнстантной памяти при использовании odeint_adjoint, нужно изменить параметр odeint на odeint_adjoint и значительно увеличить BTCH_SZ например до 200. При таком размере BTCH_SZ и спользовании метода odeint, ноутбук скорее всего упадёт.

5) Если хочется провести точный эксперимент такой же как в статье то можно выставить число эпох (NEPOCHS) на 128, но это будет довольно долго.

In [6]:
TOL = 1e-3 # испоьзуется в odeint
SOLVERS = {'dopri5': Dopri5Solver, 'euler': Euler, 'midpoint': Midpoint, 'rk4': RK4} # Численные методы
METHOD = 'rk4'# Выбрать из 'dopri5' 'euler' 'midpoint' 'rk4'
LR = 0.1 # используется в learning_rate_with_decay
odeint = ODEINT# ODEINT или odeint_adjoint
is_odenet = True # тут может быть False, тогда будет в эксперименте участвовать resnet а не ODE-Net
BTCH_SZ = 50


In [7]:
device = 'cpu'
In [8]:
downsampling_layers = [
    nn.Conv2d(1, 64, 3, 1),
    nn.Conv2d(64, 64, 4, 2, 1),
    nn.Conv2d(64, 64, 4, 2, 1),

feature_layers = [ODEBlock(ODEfunc(64))] if is_odenet else [ResBlock(64, 64) for _ in range(6)]
fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]

инициализация модели

In [9]:
# сама модель
model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)

# Для оптимизации
criterion = nn.CrossEntropyLoss().to(device)

# поучаем loader-ы
train_loader, test_loader, train_eval_loader = get_mnist_loaders(
    data_aug=True, batch_size = BTCH_SZ, test_batch_size=1000)

# Получаем бесконечный итератор

data_gen = inf_generator(train_loader)
batches_per_epoch = len(train_loader)

# Уменьшение lr от итерации

lr_fn = learning_rate_with_decay(
        batch_size = BTCH_SZ, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[60, 100, 140],
        decay_rates=[1, 0.1, 0.01, 0.001]

# Оптимизатор

optimizer = torch.optim.SGD(model.parameters(), lr= LR, momentum=0.9)

# Инициализируем 3 различных RunningAverageMeter
best_acc = 0
batch_time_meter = RunningAverageMeter()
f_nfe_meter = RunningAverageMeter()
b_nfe_meter = RunningAverageMeter()
end = time.time()

число параметров

можно заметить что у ResNet их больше

In [10]:
      count_parameters(nn.Sequential(*feature_layers)), count_parameters(nn.Sequential(*fc_layers)))

print('Number of parameters: {}'.format(count_parameters(model)) )
132096 75392 778
Number of parameters: 208266

запуск эксперимента

In [11]:
# Главный цикл
for itr in tqdm(range(1, NEPOCHS * batches_per_epoch + 1)):

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_fn(itr)

        x, y = data_gen.__next__()
        x =
        y =
        logits = model(x)
        loss = criterion(logits, y)

        if is_odenet:
            nfe_forward = feature_layers[0].nfe
            feature_layers[0].nfe = 0
        if is_odenet:
            nfe_backward = feature_layers[0].nfe
            feature_layers[0].nfe = 0

        batch_time_meter.update(time.time() - end)
        if is_odenet:
        end = time.time()
        if itr % batches_per_epoch == 0:
            with torch.no_grad():
                val_acc = accuracy(model, test_loader)
                print("epoch: ", itr/ batches_per_epoch,"te_err (%) : ", 100 - val_acc)
epoch:  1.0 te_err (%) :  2.969999999999999
epoch:  2.0 te_err (%) :  0.9399999999999977
epoch:  3.0 te_err (%) :  1.2199999999999989
epoch:  4.0 te_err (%) :  0.7600000000000051
epoch:  5.0 te_err (%) :  1.0100000000000051