This notebook contains the code for an experiment that investigates whether a neural ODE model can explicitly control the trade-off between numerical accuracy and computational cost.

In [1]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torchdiffeq import odeint_adjoint as odeint
In [2]:
from utils import norm, Flatten, get_mnist_loaders, one_hot, ConcatConv2d
In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
In [4]:
downsampling_layers = [
    nn.Conv2d(1, 64, 3, 1),
    norm(64),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, 4, 2, 1),
    norm(64),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, 4, 2, 1),
]
fc_layers = [
    norm(64),
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d((1, 1)),
    Flatten(),
    nn.Linear(64, 10),
]
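
As a quick sanity check (an illustration, not part of the original notebook), a dummy MNIST-sized batch can be traced through the downsampling stack to confirm the feature-map shape the ODE block will receive; the convolutions above take 28x28 down to 6x6:

with torch.no_grad():
    dummy = torch.randn(1, 1, 28, 28)                    # fake MNIST batch
    features = nn.Sequential(*downsampling_layers)(dummy)
    print(features.shape)                                # torch.Size([1, 64, 6, 6])
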
In [5]:
class ODEfunc(nn.Module):

    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0  # number of function evaluations (NFE) made by the solver

    def forward(self, t, x):
        self.nfe += 1  # each call of the dynamics f(t, x) counts as one evaluation
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm3(out)
        return out
    
class ODEBlock(nn.Module):

    def __init__(self, odefunc, tol=1e-3, method=None):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()
        self.tol = tol  # used as both rtol and atol of the solver
        self.method = method  # None falls back to torchdiffeq's default solver (dopri5)

    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        out = odeint(self.odefunc, x, self.integration_time, rtol=self.tol, 
                     atol=self.tol, method=self.method)
        return out[1]

    @property
    def nfe(self):
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value
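
A minimal usage sketch (illustrative only, not from the original run): the block maps its input to a tensor of the same shape, while the nfe counter records how many times the solver evaluated the dynamics:

block = ODEBlock(ODEfunc(64), tol=1e-3)
block.nfe = 0                      # reset the evaluation counter
z = torch.randn(2, 64, 6, 6)       # shape produced by the downsampling layers
out = block(z)
print(out.shape, block.nfe)        # same shape as z; nfe grows as tol shrinks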

Let us load the weights of a pre-trained ODE-Net model (during training, the solver's maximum allowed absolute error was $tol = 10^{-3}$):

In [6]:
checkpoint = torch.load('ODEnet_mnist.pth', map_location=device)  # map_location keeps the checkpoint loadable on CPU-only machines
batch_size = test_batch_size = 1000
data_aug = False
In [7]:
all_acc = []
all_times = []
nfes = []
tols = [1e-4, 1e-3, 1e-2, 1e-1, 1]
In [8]:
for tol in tols:
    feature_layers = [ODEBlock(ODEfunc(64), tol=tol)]
    model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    train_loader, test_loader, train_eval_loader = get_mnist_loaders(data_aug, batch_size, test_batch_size)  # only test_loader is used below

    with torch.no_grad():
        times = []
        total_correct = 0
        for x, y in test_loader:
            x = x.to(device)
            y = one_hot(y.numpy(), 10)

            target_class = np.argmax(y, axis=1)
            start = time.time()
            preds = model(x)
            times.append(time.time() - start)
            predicted_class = np.argmax(preds.cpu().numpy(), axis=1)
            total_correct += np.sum(predicted_class == target_class)
        accuracy = total_correct / len(test_loader.dataset)
        nfe = feature_layers[0].nfe  # accumulated over all test batches; the counter is never reset
        all_acc.append(accuracy)
        nfes.append(nfe)
        all_times.append(times)
    print('Tol={0}, accuracy={1}, mean_time={2:0.2f}, nfe={3}'.format(tol, accuracy, np.mean(times), nfe))
Tol=0.0001, accuracy=0.9961, mean_time=4.60, nfe=320
Tol=0.001, accuracy=0.9961, mean_time=3.76, nfe=260
Tol=0.01, accuracy=0.996, mean_time=2.93, nfe=200
Tol=0.1, accuracy=0.9961, mean_time=2.13, nfe=140
Tol=1, accuracy=0.9956, mean_time=2.13, nfe=140
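
The run above shows that accuracy stays essentially flat (0.9956-0.9961) while the number of function evaluations drops from 320 to 140, so the tolerance knob trades compute for almost no loss in quality. A quick way to visualize this trade-off from the lists collected above (a sketch, assuming matplotlib is installed):

import matplotlib.pyplot as plt

fig, ax1 = plt.subplots()
ax1.set_xscale('log')
ax1.plot(tols, nfes, 'o-', color='tab:blue')
ax1.set_xlabel('solver tolerance')
ax1.set_ylabel('NFE (total over the test set)', color='tab:blue')

ax2 = ax1.twinx()                  # second y-axis for accuracy
ax2.plot(tols, all_acc, 's--', color='tab:red')
ax2.set_ylabel('test accuracy', color='tab:red')

plt.title('Accuracy vs. compute trade-off of ODE-Net')
plt.show()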