import os
import argparse
import logging
import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
Устанавливаем библиотеку с сайта авторов: https://github.com/rtqichen/torchdiffeq
from torchdiffeq import odeint
Создаём классы ODEfunc, реализующий обучаемый модуль $f(\cdot)$, из уравнения $\frac{dz}{dt} = f(z(t), t, \theta)$. В нашем случае, это будет простая трёхслойная полносвязанная нейросеть. Первый слой увеличивает размерность пространства, второй -- содержит основное число параметров, третий -- проецирует скрытое представление обратно в пространство малой размерности.
При этом, в слои добавлена зависимость от $t$, как это требуется в исходной функции.
class ODEfunc(nn.Module):
def __init__(self, dim, hidden_dim):
super(ODEfunc, self).__init__()
self.first = nn.Linear(dim, hidden_dim)
self.second = nn.Linear(hidden_dim + 1, hidden_dim)
self.third = nn.Linear(hidden_dim + 1, dim)
def forward(self, t, x):
out = self.first(x)
times = torch.ones_like(x) * t
cat_inp = torch.cat((out, times), dim=1)
out = self.second(cat_inp)
out = F.relu(out)
times = torch.ones_like(x) * t
out = F.relu(out)
cat_inp = torch.cat((out, times), dim=1)
out = self.third(cat_inp)
return out
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
Создаём сеть и считаем количество параметров:
network = ODEBlock(ODEfunc(dim=1, hidden_dim=500))
count_parameters(network)
252502
Обучаем архитектуру. Возьмём синтетические данные - случайные числа от 0 до 2, будем предсказывать их квадрат. Заметим, что
optimizer = torch.optim.Adam(params=network.parameters(), lr=0.000001)
for i in range(20000):
# генерируем данные
batch = np.random.sample(size=(400, 1)) * 2
values = batch ** 2
batch = torch.tensor(batch, dtype=torch.float32)
values = torch.tensor(values, dtype=torch.float32)
# считаем значения и функцию потерь
predictions = network(batch, tol=1e-3)
loss = F.mse_loss(input=predictions, target=values)
# считаем градиент ошибок по папарметрам и делаем шаг в направлении антиградиента
loss.backward()
optimizer.step()
if i % 100 == 0:
print("MSE Loss (iter {}): {:.3f}".format(i, float(loss)))
x = np.linspace(0, 2, 40)
out = network(torch.tensor(x.reshape(-1, 1), dtype=torch.float32))
import matplotlib.pyplot as plt
%matplotlib inline
plt.style.use('ggplot')
plt.figure(figsize=(6, 6))
plt.plot(x, out.detach().numpy(), label='NeuralODE')
plt.plot(x, x**2, label='golden')
plt.legend()
plt.show()