import matplotlib.pyplot as plt
import pyro
import torch
import torch.nn as nn
import pyro.distributions as dist
from pyro.contrib.easyguide import easy_guide, EasyGuide
from pyro.nn import PyroModule, PyroSample, PyroParam
from pyro.distributions import constraints
import numpy as np
from tqdm import tqdm
# Fix the random seeds up front so repeated runs of the script start from the
# same RNG state.
pyro.set_rng_seed(42)
torch.manual_seed(42)
# Bare expression: a notebook echoes the installed Pyro version here.
pyro.__version__
# Output: '1.6.0'
# Synthetic regression data: y = 5 * x + 0.1 plus Gaussian noise (std 0.5),
# sampled on 100 evenly spaced points in [0, 10].
x_data = np.linspace(0, 10, 100)
ep = 0.5 * np.random.randn(len(x_data))
y_data = 0.1 + 5 * x_data + ep

# Inputs become a float32 column of shape (100, 1); targets stay 1-D float32.
x_data = torch.tensor(x_data[:, None], dtype=torch.float32)
y_data = torch.tensor(y_data, dtype=torch.float32)
print(x_data.shape, y_data.shape)
# Output: torch.Size([100, 1]) torch.Size([100])
class BayesianRegression(PyroModule):
    """Linear regression whose weight, bias and noise scale are Pyro sample sites.

    The wrapped ``nn.Linear`` has its ``weight`` and ``bias`` replaced by
    ``PyroSample`` priors (Normal(0, 1) and Normal(0, 10) respectively), and
    ``forward`` draws the observation noise ``sigma`` from Uniform(0, 10).
    """

    def __init__(self, in_features, out_features):
        """Create the linear layer and attach priors shaped like its parameters."""
        super().__init__()
        weight_prior = dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2)
        bias_prior = dist.Normal(0., 10.).expand([out_features]).to_event(1)
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(weight_prior)
        self.linear.bias = PyroSample(bias_prior)

    def forward(self, x, full_size, y=None):
        """Score one mini-batch and return its predicted means.

        Args:
            x: Batch of inputs; its leading dimension is used as the batch size.
            full_size: Size of the complete dataset the batch was taken from.
            y: Observed targets for the batch; passed as ``obs`` to the
                ``"obs"`` sample site.

        Returns:
            The output of the linear layer with its last dimension squeezed.
        """
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x).squeeze(-1)
        batch_size = x.shape[0]
        # The caller already supplies a mini-batch, so the plate is only told the
        # full dataset size and the batch size; the indices the plate yields are
        # not used to slice any data here.
        # NOTE(review): pyro.plate also accepts ``subsample=`` for pre-batched
        # data -- confirm against the plate documentation whether that is preferable.
        with pyro.plate("data", size=full_size, subsample_size=batch_size):
            pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean
# Single-input, single-output regression.
in_features = out_features = 1
# Keyword arguments unpacked into torch.optim.Adam by train().
adam_params = dict(lr=0.0005, betas=(0.90, 0.999))
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs ``x[i]`` with ``y[i]``."""

    def __init__(self, x, y):
        """Keep references to the inputs ``x`` and the targets ``y``."""
        super().__init__()
        self.x, self.y = x, y

    def __len__(self):
        """Return the number of examples, i.e. the leading dimension of ``x``."""
        n_examples = self.x.shape[0]
        return n_examples

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair stored at ``index``."""
        features = self.x[index]
        target = self.y[index]
        return features, target
# Wrap the tensors so a DataLoader can draw shuffled mini-batches from them.
dataset = Dataset(x=x_data, y=y_data)
def train(model, guide, X, Y, dataset, adam_params, n_epochs=5000, batch_size=32):
    """Fit ``guide`` to ``model`` by mini-batch SVI driven by a torch Adam optimizer.

    The Pyro parameter store is cleared and the torch/Pyro seeds are reset to 42
    on every call. The guide is then run once on the full data so that its
    parameters exist and can be handed to the optimizer.

    Args:
        model: Callable invoked as ``model(x, full_size, y)``.
        guide: Callable with the same signature as ``model``.
        X: Full input tensor; used to initialise the guide parameters and to
            supply the full dataset size (``X.shape[0]``) for every batch.
        Y: Full target tensor; used only for parameter initialisation.
        dataset: Dataset yielding ``(x, y)`` pairs, batched by a ``DataLoader``.
        adam_params: Keyword arguments unpacked into ``torch.optim.Adam``.
        n_epochs: Number of passes over ``dataset``.
        batch_size: Mini-batch size passed to the ``DataLoader``.

    Returns:
        None. The per-epoch mean loss curve is drawn with ``plt.plot``.
    """
    pyro.clear_param_store()
    torch.manual_seed(42)
    pyro.set_rng_seed(42)

    # Run the guide once while recording only param sites; block() keeps this
    # warm-up call hidden from any enclosing handlers.
    with pyro.poutine.block(), pyro.poutine.trace(param_only=True) as param_capture:
        guide(x=X, full_size=X.shape[0], y=Y)
    # Only parameters touched by the guide are optimised.
    # NOTE(review): parameters created solely by the model would be missed here.
    params = [pyro.param(name).unconstrained() for name in param_capture.trace]

    optimizer = torch.optim.Adam(params, **adam_params)
    loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
    full_size = X.shape[0]
    # The loader is loop-invariant: build it once; each epoch's iteration over
    # it still reshuffles because shuffle=True.
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    losses = []
    for _ in tqdm(range(n_epochs)):
        epoch_loss = []
        for x, y in loader:
            loss = loss_fn(model, guide, x, full_size, y)
            epoch_loss.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Unweighted mean of the batch losses recorded during this epoch.
        losses.append(sum(epoch_loss) / len(epoch_loss))
    plt.plot(losses)
class RegressionGuideAsClass(EasyGuide):
    """EasyGuide that samples all matched sites jointly from independent Normals."""

    def __init__(self, model):
        """Bind the guide to ``model``."""
        super().__init__(model)

    def guide(self, x, full_size, y=None):
        """Sample the group of all sites matching ``.*`` from a learned Normal.

        The arguments mirror the model's signature; they are not read here.
        """
        group = self.group(match=".*")
        event_shape = group.event_shape
        # Lazy initialisers: pyro.param only needs an initial value the first
        # time a name is registered, so building a fresh tensor (and consuming
        # RNG state) on every guide call would be wasted work.
        loc = pyro.param("loc", lambda: torch.randn(event_shape))
        scale = pyro.param(
            "scale",
            lambda: torch.ones(event_shape) * 0.01,
            constraint=constraints.positive,
        )
        group.sample("joint", dist.Normal(loc=loc, scale=scale).to_event(1))
# Variant 1: guide defined by subclassing EasyGuide.
base_regression_model = BayesianRegression(
    in_features=in_features,
    out_features=out_features,
)
regression_guide_as_class = RegressionGuideAsClass(model=base_regression_model)
train(
    model=base_regression_model,
    guide=regression_guide_as_class,
    X=x_data,
    Y=y_data,
    dataset=dataset,
    adam_params=adam_params,
)
# Output: 100%|██████████| 5000/5000 [00:53<00:00, 93.05it/s]
# Variant 2: guide defined with the @easy_guide decorator.
# The model is created *before* the decorator runs so that the guide is bound
# to the same instance that is trained below.
base_regression_model = BayesianRegression(in_features, out_features)


@easy_guide(base_regression_model)
def regression_guide_with_decorator(self, x, full_size, y=None):
    """Sample the group of all sites matching ``.*`` from a learned Normal.

    The arguments mirror the model's signature; they are not read here.
    """
    group = self.group(match=".*")
    event_shape = group.event_shape
    # Lazy initialisers: pyro.param only needs an initial value the first time
    # a name is registered, so building a fresh tensor (and consuming RNG
    # state) on every guide call would be wasted work.
    loc = pyro.param("loc", lambda: torch.randn(event_shape))
    scale = pyro.param(
        "scale",
        lambda: torch.ones(event_shape) * 0.01,
        constraint=constraints.positive,
    )
    group.sample("joint", dist.Normal(loc=loc, scale=scale).to_event(1))


train(base_regression_model, regression_guide_with_decorator, x_data, y_data, dataset, adam_params)
# Output: 100%|██████████| 5000/5000 [00:54<00:00, 91.90it/s]
# Variant 3: let Pyro construct the guide automatically with AutoNormal.
base_regression_model = BayesianRegression(
    in_features=in_features,
    out_features=out_features,
)
auto_guide = pyro.infer.autoguide.AutoNormal(base_regression_model)
train(
    model=base_regression_model,
    guide=auto_guide,
    X=x_data,
    Y=y_data,
    dataset=dataset,
    adam_params=adam_params,
)
# Output: 100%|██████████| 5000/5000 [01:12<00:00, 69.17it/s]