#!/usr/bin/env python
# coding: utf-8

# In[1]:

import os
import argparse
import logging
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm_notebook as tqdm

from fixed_grid import Euler, Midpoint, RK4
from dopri5 import Dopri5Solver
from misc import _decreasing, _check_inputs, _flatten, _flatten_convert_none_to_zeros


# **ODEINT and odeint_adjoint**
#
# The ODEINT and odeint_adjoint functions apply numerical methods to compute the model's
# output. odeint_adjoint additionally uses adjoint variables, so its memory footprint is
# constant, O(1), unlike ODEINT.

# In[2]:

def ODEINT(func, y0, t, rtol=1e-7, atol=1e-9, method=None, options=None):
    """Integrate a system of ordinary differential equations.

    Solves the initial value problem for a non-stiff system of first order ODEs:

        dy/dt = func(t, y), y(t[0]) = y0

    where y is a Tensor of any shape. Output dtypes and numerical precision are
    based on the dtype of the input `y0`.

    Args:
        func: Function that maps a Tensor holding the state `y` and a scalar Tensor
            `t` into a Tensor of state derivatives with respect to time.
        y0: N-D Tensor giving the starting value of `y` at time point `t[0]`. May
            have any floating point or complex dtype.
        t: 1-D Tensor holding a sequence of time points for which to solve for `y`.
            The initial time point should be the first element of this sequence,
            and each time must be larger than the previous time. May have any
            floating point dtype. Converted to a Tensor with float64 dtype.
        rtol: optional float64 Tensor specifying an upper bound on relative error,
            per element of `y`.
        atol: optional float64 Tensor specifying an upper bound on absolute error,
            per element of `y`.
        method: optional string indicating the integration method to use.
        options: optional dict of configuring options for the indicated integration
            method. Can only be provided if a `method` is explicitly set.

    Returns:
        y: Tensor, where the first dimension corresponds to different time points.
            Contains the solved value of y for each desired time point in `t`, with
            the initial value `y0` being the first element along the first dimension.

    Raises:
        ValueError: if an invalid `method` is provided.
        TypeError: if `options` is supplied without `method`, or if `t` or `y0`
            has an invalid dtype.
    """
    tensor_input, func, y0, t = _check_inputs(func, y0, t)

    if options is None:
        options = {}
    elif method is None:
        raise ValueError('cannot supply `options` without specifying `method`')

    if method is None:
        method = 'dopri5'

    # pick the numerical solver
    solver = SOLVERS[method](func, y0, rtol=rtol, atol=atol, **options)
    solution = solver.integrate(t)

    if tensor_input:
        solution = solution[0]
    return solution
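# The adjoint implementation below relies on vector-Jacobian products computed with
# `torch.autograd.grad`. The next cell is an added minimal sketch (not part of the
# original notebook) of that building block: for dynamics f(y) and a cofactor `a` it
# evaluates a * df/dy without ever materializing the full Jacobian; `augmented_dynamics`
# below uses exactly this call to propagate the adjoint backwards in time.

# In[ ]:

_y = torch.randn(3, requires_grad=True)   # toy state
_a = torch.randn(3)                       # cofactor, plays the role of -adj_y below
_f = _y ** 2                              # toy dynamics f(y)
(_vjp,) = torch.autograd.grad(_f, _y, grad_outputs=_a)
print(_vjp, 2 * _y.detach() * _a)         # both equal a_i * df_i/dy_i = 2 * y_i * a_i


# In[ ]: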
class OdeintAdjointMethod(torch.autograd.Function):

    @staticmethod
    def forward(ctx, *args):
        assert len(args) >= 8, 'Internal error: all arguments required.'
        y0, func, t, flat_params, rtol, atol, method, options = \
            args[:-7], args[-7], args[-6], args[-5], args[-4], args[-3], args[-2], args[-1]

        ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options = func, rtol, atol, method, options

        with torch.no_grad():
            ans = ODEINT(func, y0, t, rtol=rtol, atol=atol, method=method, options=options)
        ctx.save_for_backward(t, flat_params, *ans)
        return ans

    @staticmethod
    def backward(ctx, *grad_output):
        t, flat_params, *ans = ctx.saved_tensors
        ans = tuple(ans)
        func, rtol, atol, method, options = ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options
        n_tensors = len(ans)
        f_params = tuple(func.parameters())

        # TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives.
        def augmented_dynamics(t, y_aug):
            # Dynamics of the original system augmented with
            # the adjoint wrt y, and an integrator wrt t and args.
            y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]  # Ignore adj_time and adj_params.

            with torch.set_grad_enabled(True):
                t = t.to(y[0].device).detach().requires_grad_(True)
                y = tuple(y_.detach().requires_grad_(True) for y_ in y)
                func_eval = func(t, y)
                vjp_t, *vjp_y_and_params = torch.autograd.grad(
                    func_eval, (t,) + y + f_params,
                    tuple(-adj_y_ for adj_y_ in adj_y), allow_unused=True, retain_graph=True
                )
            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_params = vjp_y_and_params[n_tensors:]

            # autograd.grad returns None if no gradient, set to zero.
            vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
            vjp_y = tuple(torch.zeros_like(y_) if vjp_y_ is None else vjp_y_
                          for vjp_y_, y_ in zip(vjp_y, y))
            vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

            if len(f_params) == 0:
                vjp_params = torch.tensor(0.).to(vjp_y[0])
            return (*func_eval, *vjp_y, vjp_t, vjp_params)

        T = ans[0].shape[0]
        with torch.no_grad():
            adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)
            adj_params = torch.zeros_like(flat_params)
            adj_time = torch.tensor(0.).to(t)
            time_vjps = []
            for i in range(T - 1, 0, -1):

                ans_i = tuple(ans_[i] for ans_ in ans)
                grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output)
                func_i = func(t[i], ans_i)

                # Compute the effect of moving the current time measurement point.
                dLd_cur_t = sum(
                    torch.dot(func_i_.reshape(-1), grad_output_i_.reshape(-1)).reshape(1)
                    for func_i_, grad_output_i_ in zip(func_i, grad_output_i)
                )
                adj_time = adj_time - dLd_cur_t
                time_vjps.append(dLd_cur_t)

                # Run the augmented system backwards in time.
                if adj_params.numel() == 0:
                    adj_params = torch.tensor(0.).to(adj_y[0])
                aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)
                aug_ans = ODEINT(
                    augmented_dynamics, aug_y0,
                    torch.tensor([t[i], t[i - 1]]), rtol=rtol, atol=atol, method=method, options=options
                )

                # Unpack aug_ans.
                adj_y = aug_ans[n_tensors:2 * n_tensors]
                adj_time = aug_ans[2 * n_tensors]
                adj_params = aug_ans[2 * n_tensors + 1]

                adj_y = tuple(adj_y_[1] if len(adj_y_) > 0 else adj_y_ for adj_y_ in adj_y)
                if len(adj_time) > 0:
                    adj_time = adj_time[1]
                if len(adj_params) > 0:
                    adj_params = adj_params[1]

                adj_y = tuple(adj_y_ + grad_output_[i - 1]
                              for adj_y_, grad_output_ in zip(adj_y, grad_output))

                del aug_y0, aug_ans

            time_vjps.append(adj_time)
            time_vjps = torch.cat(time_vjps[::-1])

            # One gradient per forward input: *y0, func, t, flat_params, rtol, atol, method, options.
            return (*adj_y, None, time_vjps, adj_params, None, None, None, None)


def odeint_adjoint(func, y0, t, rtol=1e-6, atol=1e-12, method=None, options=None):
    # We need `func` to be an nn.Module in order to access the parameters inside it,
    # since we have no other way of collecting the variables along the execution path.
    if not isinstance(func, nn.Module):
        raise ValueError('func is required to be an instance of nn.Module.')

    tensor_input = False
    if torch.is_tensor(y0):

        class TupleFunc(nn.Module):

            def __init__(self, base_func):
                super(TupleFunc, self).__init__()
                self.base_func = base_func

            def forward(self, t, y):
                return (self.base_func(t, y[0]),)

        tensor_input = True
        y0 = (y0,)
        func = TupleFunc(func)

    flat_params = _flatten(func.parameters())
    ys = OdeintAdjointMethod.apply(*y0, func, t, flat_params, rtol, atol, method, options)

    if tensor_input:
        ys = ys[0]
    return ys
# Helper functions

# In[3]:

class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


class RunningAverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val


def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0):
    # loads the MNIST data and builds the train / eval / test loaders
    if data_aug:
        transform_train = transforms.Compose([
            transforms.RandomCrop(28, padding=4),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_loader = DataLoader(
        datasets.MNIST(root='data/mnist', train=True, download=True, transform=transform_train),
        batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True
    )

    train_eval_loader = DataLoader(
        datasets.MNIST(root='data/mnist', train=True, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=0, drop_last=True
    )

    test_loader = DataLoader(
        datasets.MNIST(root='data/mnist', train=False, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=0, drop_last=True
    )

    return train_loader, test_loader, train_eval_loader


def inf_generator(iterable):
    """Allows training with DataLoaders in a single infinite loop:
        for i, (x, y) in enumerate(inf_generator(train_loader)):
    """
    iterator = iterable.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = iterable.__iter__()


def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates):
    # implements a piecewise-constant learning-rate decay schedule
    initial_learning_rate = LR * batch_size / batch_denom

    boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
    vals = [initial_learning_rate * decay for decay in decay_rates]

    def learning_rate_fn(itr):
        lt = [itr < b for b in boundaries] + [True]
        i = np.argmax(lt)
        return vals[i]

    return learning_rate_fn


def one_hot(x, K):
    # one-hot encoding
    return np.array(x[:, None] == np.arange(K)[None, :], dtype=int)


def accuracy(model, dataset_loader):
    # plain classification accuracy, in percent
    total_correct = 0
    for x, y in dataset_loader:
        x = x.to(device)
        y = one_hot(np.array(y.numpy()), 10)

        target_class = np.argmax(y, axis=1)
        predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1)
        total_correct += np.sum(predicted_class == target_class)
    return (total_correct / len(dataset_loader.dataset)) * 100


def count_parameters(model):
    # number of trainable parameters, a rough measure of model complexity
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def makedirs(dirname):
    # creates the directory if it does not already exist
    if not os.path.exists(dirname):
        os.makedirs(dirname)
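# An added sketch (not part of the original notebook) of the piecewise-constant lookup
# used inside `learning_rate_with_decay`: `np.argmax` over a boolean list returns the
# index of the first True, i.e. the first boundary the iteration counter has not yet
# reached. All numbers below are hypothetical.

# In[ ]:

_boundaries = [100, 200, 300]        # iterations at which the learning rate drops
_vals = [0.1, 0.01, 0.001, 0.0001]   # learning rate on each interval

for _itr in (50, 150, 250, 350):
    _i = np.argmax([_itr < b for b in _boundaries] + [True])
    print(_itr, '->', _vals[_i])


# In[ ]: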
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def norm(dim):
    """Group normalization: https://arxiv.org/pdf/1803.08494.pdf"""
    return nn.GroupNorm(min(32, dim), dim)


# **Res** blocks define the architecture of the residual network

# In[4]:

class ResBlock(nn.Module):
    """Building block of the ResNet variant"""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.norm1 = norm(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = norm(planes)
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        shortcut = x

        out = self.relu(self.norm1(x))

        if self.downsample is not None:
            shortcut = self.downsample(out)

        out = self.conv1(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + shortcut  # F(x) + x


# **ODE** blocks define the architecture of the ODE network

# In[5]:

class ConcatConv2d(nn.Module):
    """Convolution that concatenates the scalar time t as an extra input channel"""

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1,
                 groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(
            dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding,
            dilation=dilation, groups=groups, bias=bias
        )

    def forward(self, t, x):
        tt = torch.ones_like(x[:, :1, :, :]) * t
        ttx = torch.cat([tt, x], 1)
        return self._layer(ttx)


class ODEfunc(nn.Module):

    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        # number of function evaluations; reflects the amount of computation and
        # depends on the step size / grid of the numerical method
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(t, out)  # t enters the convolutions via ConcatConv2d
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm3(out)
        return out


class ODEBlock(nn.Module):

    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc  # an ODEfunc instance, defined above
        self.integration_time = torch.tensor([0, 1]).float()

    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        out = odeint(self.odefunc, x, self.integration_time, rtol=TOL, atol=TOL, method=METHOD)
        return out[1]

    @property
    def nfe(self):
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value
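# A quick shape check (added; not in the original notebook): ConcatConv2d broadcasts t
# as an extra channel, and ODEfunc maps a state of shape (N, C, H, W) to a derivative
# of the same shape, which is what the ODE solver requires. The sizes below are arbitrary.

# In[ ]:

_x = torch.randn(2, 64, 6, 6)
_t = torch.tensor(0.5)
print(ConcatConv2d(64, 64, 3, 1, 1)(_t, _x).shape)  # torch.Size([2, 64, 6, 6])
print(ODEfunc(64)(_t, _x).shape)                    # torch.Size([2, 64, 6, 6])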
# # Experiment parameters
#
# **TOL** (tolerance) is a parameter of **ODEINT** and **odeint_adjoint**. It lets you
# trade accuracy for speed: loosening it gives a faster but less accurate model.
#
# **METHOD** is a parameter of **ODEINT** and **odeint_adjoint** that selects the
# numerical integration method, which determines both accuracy and runtime.

# **Suggested experiments**
#
# 1) To see the influence of **TOL**, increase or decrease it and compare the results.
#
# 2) Change **METHOD**, choosing it from **SOLVERS** (a toy sanity check of TOL and
# METHOD is added as the last cell of the notebook).
#
# 3) Compare the **ODE-Net** with the **ResNet** (number of parameters, convergence
# speed) by toggling **is_odenet**.
#
# 4) To check the claim about constant memory with odeint_adjoint, change **odeint**
# to **odeint_adjoint** and increase **BTCH_SZ** substantially, e.g. to 200. With that
# batch size and plain **ODEINT**, the notebook will most likely run out of memory.
#
# 5) To reproduce the paper's experiment exactly, set the number of epochs
# (**NEPOCHS**) to 128, but this takes quite a while.

# In[6]:

TOL = 1e-3  # used by odeint
SOLVERS = {'dopri5': Dopri5Solver, 'euler': Euler, 'midpoint': Midpoint, 'rk4': RK4}  # numerical solvers
METHOD = 'rk4'  # choose one of 'dopri5', 'euler', 'midpoint', 'rk4'
LR = 0.1  # used by learning_rate_with_decay
odeint = ODEINT  # ODEINT or odeint_adjoint
is_odenet = True  # set to False to run the experiment with the ResNet instead of the ODE-Net
BTCH_SZ = 50
NEPOCHS = 5


# # Experiment

# In[7]:

makedirs("./experiment")
device = 'cpu'


# In[8]:

downsampling_layers = [
    nn.Conv2d(1, 64, 3, 1),
    norm(64),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, 4, 2, 1),
    norm(64),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, 4, 2, 1),
]

feature_layers = [ODEBlock(ODEfunc(64))] if is_odenet else [ResBlock(64, 64) for _ in range(6)]
fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]


# Model initialization

# In[9]:

# the model itself
model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)

# loss used for optimization
criterion = nn.CrossEntropyLoss().to(device)

# get the data loaders
train_loader, test_loader, train_eval_loader = get_mnist_loaders(
    data_aug=True, batch_size=BTCH_SZ, test_batch_size=1000)

# infinite iterator over the training data
data_gen = inf_generator(train_loader)
batches_per_epoch = len(train_loader)

# learning-rate decay as a function of the iteration
lr_fn = learning_rate_with_decay(
    batch_size=BTCH_SZ, batch_denom=128, batches_per_epoch=batches_per_epoch,
    boundary_epochs=[60, 100, 140], decay_rates=[1, 0.1, 0.01, 0.001]
)

# optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9)

# initialize three RunningAverageMeters (batch time, forward NFE, backward NFE)
best_acc = 0
batch_time_meter = RunningAverageMeter()
f_nfe_meter = RunningAverageMeter()
b_nfe_meter = RunningAverageMeter()
end = time.time()


# **Number of parameters**
#
# Note that the ResNet variant has more of them.

# In[10]:

print(count_parameters(nn.Sequential(*downsampling_layers)),
      count_parameters(nn.Sequential(*feature_layers)),
      count_parameters(nn.Sequential(*fc_layers)))
print('Number of parameters: {}'.format(count_parameters(model)))


# **Running the experiment**

# In[11]:

# main training loop
for itr in tqdm(range(1, NEPOCHS * batches_per_epoch + 1)):

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr_fn(itr)

    optimizer.zero_grad()
    x, y = data_gen.__next__()
    x = x.to(device)
    y = y.to(device)
    logits = model(x)
    loss = criterion(logits, y)

    if is_odenet:
        nfe_forward = feature_layers[0].nfe
        feature_layers[0].nfe = 0

    loss.backward()
    optimizer.step()

    if is_odenet:
        nfe_backward = feature_layers[0].nfe
        feature_layers[0].nfe = 0

    batch_time_meter.update(time.time() - end)
    if is_odenet:
        f_nfe_meter.update(nfe_forward)
        b_nfe_meter.update(nfe_backward)
    end = time.time()

    if itr % batches_per_epoch == 0:
        with torch.no_grad():
            val_acc = accuracy(model, test_loader)
            print("epoch:", itr // batches_per_epoch, "test error (%):", 100 - val_acc)
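# An added sanity check (not part of the original notebook): ODEINT on the toy problem
# dy/dt = -y, y(0) = 1, whose exact solution is exp(-t). It assumes the local
# fixed_grid / dopri5 modules follow the torchdiffeq-style solver interface used by
# SOLVERS above. Changing METHOD or TOL here shows their effect on the numerical error.

# In[ ]:

class _Decay(nn.Module):
    def forward(self, t, y):
        return -y

_t_grid = torch.linspace(0., 1., 11)
_y_num = ODEINT(_Decay(), torch.tensor([1.0]), _t_grid, rtol=TOL, atol=TOL, method=METHOD)
print('max abs error vs exp(-t):', (_y_num.squeeze() - torch.exp(-_t_grid)).abs().max().item())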