#!/usr/bin/env python
# coding: utf-8

# This notebook contains the code for an experiment that investigates whether a
# neural ODE model can explicitly control the trade-off between numerical
# accuracy and computational cost.

# In[1]:


import os
import time

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchdiffeq import odeint_adjoint as odeint


# In[2]:


from utils import norm, Flatten, get_mnist_loaders, one_hot, ConcatConv2d


# In[3]:


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# In[4]:


downsampling_layers = [
    nn.Conv2d(1, 64, 3, 1),
    norm(64),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, 4, 2, 1),
    norm(64),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, 4, 2, 1),
]

fc_layers = [
    norm(64),
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d((1, 1)),
    Flatten(),
    nn.Linear(64, 10),
]


# In[5]:


class ODEfunc(nn.Module):
    """Dynamics f(t, x) of the ODE block; counts solver function evaluations (NFE)."""

    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0  # incremented on every call the solver makes to forward()

    def forward(self, t, x):
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm3(out)
        return out


class ODEBlock(nn.Module):
    """Integrates the dynamics from t=0 to t=1 at the given solver tolerance."""

    def __init__(self, odefunc, tol=1e-3, method=None):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()
        self.tol = tol
        self.method = method  # None lets torchdiffeq fall back to its default adaptive solver

    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        out = odeint(self.odefunc, x, self.integration_time,
                     rtol=self.tol, atol=self.tol, method=self.method)
        return out[1]  # the state at t=1

    @property
    def nfe(self):
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value


# Load the weights of a pre-trained ODE-Net model (training used a maximum
# allowed absolute error of the numerical method of $tol = 10^{-3}$):

# In[6]:


checkpoint = torch.load('ODEnet_mnist.pth', map_location=device)
batch_size = test_batch_size = 1000
data_aug = False


# In[7]:


all_acc = []
all_times = []
nfes = []
tols = [1e-4, 1e-3, 1e-2, 1e-1, 1]


# In[8]:


for tol in tols:
    # Rebuild the model with a fresh ODE block so that the solver tolerance
    # (and the NFE counter) are specific to this run.
    feature_layers = [ODEBlock(ODEfunc(64), tol=tol)]
    model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    train_loader, test_loader, train_eval_loader = get_mnist_loaders(
        data_aug, batch_size, test_batch_size)

    with torch.no_grad():
        times = []
        total_correct = 0
        for x, y in test_loader:
            x = x.to(device)
            y = one_hot(y.numpy(), 10)
            target_class = np.argmax(y, axis=1)

            if device.type == 'cuda':
                torch.cuda.synchronize()  # keep previously queued GPU work out of the timing
            start = time.time()
            preds = model(x)
            if device.type == 'cuda':
                torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait for the forward pass to finish
            times.append(time.time() - start)

            predicted_class = np.argmax(preds.cpu().numpy(), axis=1)
            total_correct += np.sum(predicted_class == target_class)

    accuracy = total_correct / len(test_loader.dataset)
    nfe = feature_layers[0].nfe  # accumulated over all test batches for this tolerance

    all_acc.append(accuracy)
    nfes.append(nfe)
    all_times.append(times)

    print('Tol={0}, accuracy={1:0.4f}, mean_time={2:0.2f}, nfe={3}'.format(
        tol, accuracy, np.mean(times), nfe))
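

# A follow-up sketch (not part of the original experiment): plot the collected
# accuracy, NFE, and timing lists against the solver tolerance to make the
# accuracy/cost trade-off explicit. This assumes matplotlib is installed and
# that the lists `all_acc`, `nfes`, `all_times`, and `tols` were filled by the
# loop above.

# In[ ]:


import matplotlib.pyplot as plt

mean_times = [np.mean(t) for t in all_times]

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, ys, label in zip(axes,
                         (all_acc, nfes, mean_times),
                         ('test accuracy', 'NFE (total over test set)', 'mean batch time, s')):
    ax.semilogx(tols, ys, 'o-')  # tolerances span several orders of magnitude, so use a log x-axis
    ax.set_xlabel('solver tolerance')
    ax.set_ylabel(label)
fig.tight_layout()
plt.show()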
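

# A minimal sanity check (a sketch, also not from the original notebook): on a
# single random batch, a looser tolerance should cost the adaptive solver fewer
# function evaluations. The 4x64x6x6 shape is an assumption matching the MNIST
# feature map after the downsampling layers above.

# In[ ]:


dummy = torch.randn(4, 64, 6, 6, device=device)  # hypothetical feature-map-shaped input
with torch.no_grad():
    for tol in (1e-4, 1e-1):
        block = ODEBlock(ODEfunc(64), tol=tol).to(device)
        block(dummy)  # a fresh block starts with nfe == 0
        print('tol={0}: nfe={1}'.format(tol, block.nfe))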