#!/usr/bin/env python
# coding: utf-8

# In[1]:

import matplotlib.pyplot as plt
import pyro
import torch
import torch.nn as nn
import pyro.distributions as dist
from pyro.contrib.easyguide import easy_guide, EasyGuide
from pyro.nn import PyroModule, PyroSample
from pyro.distributions import constraints
import numpy as np
from tqdm import tqdm

torch.manual_seed(42)
pyro.set_rng_seed(42)
pyro.__version__

# In[2]:

# Synthetic data: y = 5x + 0.1 with Gaussian noise (std 0.5).
x_data = np.linspace(0, 10, 100)
ep = 0.5 * np.random.randn(x_data.shape[0])
y_data = 5 * x_data + 0.1 + ep
x_data = x_data[:, None]

x_data = torch.tensor(x_data).type(torch.float32)
y_data = torch.tensor(y_data).type(torch.float32)
print(x_data.shape, y_data.shape)

# In[3]:

class BayesianRegression(PyroModule):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))

    def forward(self, x, full_size, y=None):
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x).squeeze(-1)
        # x and y arrive as a DataLoader batch, so x.shape[0] is passed as
        # subsample_size: the plate then rescales the batch log-likelihood
        # by full_size / batch_size. The random indices the plate draws are
        # ignored, since the subsampled data is already in hand.
        with pyro.plate("data", size=full_size, subsample_size=x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean

# In[4]:

in_features = 1
out_features = 1
adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}

# ## Base Setup for custom loop

# In[5]:

class Dataset(torch.utils.data.Dataset):
    def __init__(self, x, y):
        super().__init__()
        self.x = x
        self.y = y

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, index):
        return self.x[index], self.y[index]

dataset = Dataset(x_data, y_data)

# In[6]:

def train(model, guide, X, Y, dataset, adam_params, n_epochs=5000):
    # X and Y are used only to initialize the guide's parameters.
    pyro.clear_param_store()
    torch.manual_seed(42)
    pyro.set_rng_seed(42)

    # Run the guide once to initialize its parameters, capturing them with a
    # param-only trace; block() keeps this call out of any enclosing handlers.
    with pyro.poutine.block(), pyro.poutine.trace(param_only=True) as param_capture:
        guide(x=X, full_size=X.shape[0], y=Y)
    params = [site["value"].unconstrained() for site in param_capture.trace.nodes.values()]

    # Optimize the differentiable ELBO with a plain torch optimizer.
    optimizer = torch.optim.Adam(params, **adam_params)
    loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
    losses = []
    for epoch in tqdm(range(n_epochs)):
        epoch_loss = []
        for batch in torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True):
            x, y = batch
            loss = loss_fn(model, guide, x, X.shape[0], y)
            epoch_loss.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        losses.append(sum(epoch_loss) / len(epoch_loss))
    plt.plot(losses)
    plt.xlabel("epoch")
    plt.ylabel("mean ELBO loss")

# ## Using EasyGuide class in custom loop

# In[7]:

class RegressionGuideAsClass(EasyGuide):
    def guide(self, x, full_size, y=None):
        group = self.group(match=".*")
        loc = pyro.param("loc", torch.randn(group.event_shape))
        scale = pyro.param("scale", torch.ones(group.event_shape) * 0.01,
                           constraint=constraints.positive)
        group.sample("joint", dist.Normal(loc=loc, scale=scale).to_event(1))

# In[8]:

base_regression_model = BayesianRegression(in_features, out_features)
regression_guide_as_class = RegressionGuideAsClass(base_regression_model)
train(base_regression_model, regression_guide_as_class, x_data, y_data, dataset,
      adam_params)
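
# A quick sanity check on the first fit (a sketch added for illustration, not
# part of the original loop): draw posterior samples from the trained guide
# with pyro.infer.Predictive and compare their means against the generating
# values (slope 5, intercept 0.1, sigma 0.5). The site names "linear.weight"
# and "linear.bias" follow PyroModule's naming of the sampled nn.Linear
# parameters.

# In[ ]:

predictive = pyro.infer.Predictive(
    base_regression_model,
    guide=regression_guide_as_class,
    num_samples=500,
    return_sites=["linear.weight", "linear.bias", "sigma"],
)
posterior_samples = predictive(x_data, x_data.shape[0])
for site in ["linear.weight", "linear.bias", "sigma"]:
    print(site, posterior_samples[site].mean().item())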

# ## Using easy_guide decorator in custom loop

# In[9]:

@easy_guide(base_regression_model)
def regression_guide_with_decorator(self, x, full_size, y=None):
    group = self.group(match=".*")
    loc = pyro.param("loc", torch.randn(group.event_shape))
    scale = pyro.param("scale", torch.ones(group.event_shape) * 0.01,
                       constraint=constraints.positive)
    group.sample("joint", dist.Normal(loc=loc, scale=scale).to_event(1))

# In[10]:

# The decorator bound this guide to the model instance created above, so that
# same instance is reused here rather than constructing a fresh one the guide
# would never see.
train(base_regression_model, regression_guide_with_decorator, x_data, y_data,
      dataset, adam_params)

# ## Using AutoNormal

# In[11]:

base_regression_model = BayesianRegression(in_features, out_features)
auto_guide = pyro.infer.autoguide.AutoNormal(base_regression_model)
train(base_regression_model, auto_guide, x_data, y_data, dataset, adam_params)
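
# All three guides approximate the same posterior, so their summaries should
# roughly agree. As a closing check (a sketch relying on AutoNormal's median()
# helper), print the posterior medians from the last run; they should land
# near slope 5, intercept 0.1, and sigma 0.5.

# In[ ]:

for name, value in auto_guide.median(x_data, x_data.shape[0]).items():
    print(name, value.detach().numpy())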