This tutorial describes a workflow for incrementally building pipelines to analyze high-dimensional data in Pyro. This workflow has evolved over a few years of applying Pyro to models with $10^5$ or more latent variables. We build on Gelman et al. (2020)'s concept of Bayesian workflow, and focus on aspects particular to high-dimensional models: approximate inference and numerical stability. While the individual components of the pipeline deserve their own tutorials, this tutorial focuses on incrementally combining those components.
The fastest way to find a good model of your data is to quickly discard many bad models, i.e. to iterate. In statistics we call this iterative workflow Box's loop. An efficient workflow allows us to discard bad models as quickly as possible. Workflow efficiency demands that code changes to upstream components don't break previous coding effort on downstream components. Pyro's approaches to this challenge include strategies for variational approximations (pyro.infer.autoguide) and strategies for transforming model coordinate systems to improve geometry (pyro.infer.reparam).
Consider the problem of sampling from the posterior distribution of a probabilistic model with $10^5$ or more continuous latent variables, but whose data fits entirely in memory. (For larger datasets, consider [amortized variational inference](http://pyro.ai/examples/svi_part_ii.html).) Inference in such high-dimensional models can be challenging even when posteriors are known to be [unimodal](https://en.wikipedia.org/wiki/Unimodality) or even [log-concave](https://arxiv.org/abs/1404.5886), due to correlations among latent variables.
To perform inference in such high-dimensional models in Pyro, we have evolved a [workflow](https://arxiv.org/abs/2011.01808) to incrementally build data analysis pipelines combining variational inference, reparametrization effects, and ad-hoc initialization strategies. Our workflow is summarized as a sequence of steps, where validation after any step might suggest backtracking to change design decisions at a previous step.
The crux of efficient workflow is to ensure changes don't break your pipeline. That is, after you build a number of pipeline stages, validate results, and decide to change one component in the pipeline, you'd like to minimize code changes needed in other components. The remainder of this tutorial describes these steps individually, then describes nuances of interactions among stages, then provides an example.
The running example in this tutorial will be a model [(Obermeyer et al. 2022)](https://www.medrxiv.org/content/10.1101/2021.09.07.21263228v2) of the relative growth rates of different strains of the SARS-CoV-2 virus, based on [open data](https://docs.nextstrain.org/projects/ncov/en/latest/reference/remote_inputs.html) counting different [PANGO lineages](https://cov-lineages.org/) of viral genomic samples collected at different times around the world. There are about 2 million sequences in total.
The model is a high-dimensional regression model with around 1000 coefficients, a multivariate logistic growth function (using a simple [torch.softmax()](https://pytorch.org/docs/stable/generated/torch.nn.functional.softmax.html)) and a [Multinomial](https://pytorch.org/docs/stable/distributions.html#multinomial) likelihood. While the number of coefficients is relatively small, there are about 500,000 local latent variables to estimate, and plate structure in the model should lead to an approximately block diagonal posterior covariance matrix. For an introduction to simple logistic growth models using this same dataset, see the [logistic growth tutorial](logistic-growth.html).
```python
from collections import defaultdict
from pprint import pprint
import functools
import math
import os
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.distributions import constraints
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import (
    AutoDelta,
    AutoNormal,
    AutoMultivariateNormal,
    AutoLowRankMultivariateNormal,
    AutoGuideList,
    init_to_feasible,
)
from pyro.infer.reparam import AutoReparam, LocScaleReparam
from pyro.nn.module import PyroParam
from pyro.optim import ClippedAdam
from pyro.ops.special import sparse_multinomial_likelihood
import matplotlib.pyplot as plt

if torch.cuda.is_available():
    print("Using GPU")
    torch.set_default_tensor_type("torch.cuda.FloatTensor")
else:
    print("Using CPU")
smoke_test = ('CI' in os.environ)
```

```
Using CPU
```
Our running example will use a pre-cleaned dataset. We started with Nextstrain's [ncov](https://docs.nextstrain.org/projects/ncov/en/latest/reference/remote_inputs.html) tool for preprocessing, followed by the Broad Institute's [pyro-cov](https://github.com/broadinstitute/pyro-cov/blob/master/scripts/preprocess_nextstrain.py) tool for aggregation, resulting in a dataset of SARS-CoV-2 lineages observed around the world through time.
```python
from pyro.contrib.examples.nextstrain import load_nextstrain_counts

dataset = load_nextstrain_counts()


def summarize(x, name=""):
    if isinstance(x, dict):
        for k, v in sorted(x.items()):
            summarize(v, name + "." + k if name else k)
    elif isinstance(x, torch.Tensor):
        print(f"{name}: {type(x).__name__} of shape {tuple(x.shape)} on {x.device}")
    elif isinstance(x, list):
        print(f"{name}: {type(x).__name__} of length {len(x)}")
    else:
        print(f"{name}: {type(x).__name__}")


summarize(dataset)
```

```
counts: Tensor of shape (27, 202, 1316) on cpu
features: Tensor of shape (1316, 2634) on cpu
lineages: list of length 1316
locations: list of length 202
mutations: list of length 2634
sparse_counts.index: Tensor of shape (3, 57129) on cpu
sparse_counts.total: Tensor of shape (27, 202) on cpu
sparse_counts.value: Tensor of shape (57129,) on cpu
start_date: datetime
time_step_days: int
```
The first step to using Pyro is creating a generative model, either a Python function or a `pyro.nn.PyroModule`. Start simple. Start with a shallow hierarchy and later add latent variables to share statistical strength. Start with a slice of your data, then add a plate over multiple slices. Start with simple distributions like Normal, LogNormal, Poisson and Multinomial, then consider overdispersed versions like StudentT, Gamma, GammaPoisson/NegativeBinomial, and DirichletMultinomial. Keep your model simple and readable so you can share it and get feedback from domain experts. Use weakly informative priors.
We'll focus on a multivariate logistic growth model of competing SARS-CoV-2 strains, as described in Obermeyer et al. (2022). This model uses a numerically stable `logits` parameter in its multinomial likelihood, rather than a `probs` parameter. Similarly, the upstream variables `init`, `rate`, `rate_loc`, and `coef` are all in log-space. This means, for example, that a zero coefficient has a multiplicative effect of 1.0, and a positive coefficient has a multiplicative effect greater than 1.
Note that we scale `coef` by 1/100 because we want to model a very small number, but the automatic parts of Pyro and PyTorch work best for numbers on the order of 1.0 rather than very small numbers. When we later interpret `coef` in a volcano plot, we'll need to duplicate this scaling factor.
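As a quick sanity check of the log-space convention and the 1/100 scale factor, here is a stdlib-only sketch that mirrors `rate_loc = 0.01 * coef @ features.T` for a single strain, then converts the result to a multiplicative effect with `exp`, as the volcano plot does. The helper name `rate_loc`, the feature vector, and the coefficient values are hypothetical and chosen only for illustration.

```python
import math

SCALE = 0.01  # same 1/100 factor the model applies to coef


def rate_loc(coef, features):
    """Log-space growth-rate offset of one strain: 0.01 * sum_m coef[m] * features[m]."""
    return SCALE * sum(c * f for c, f in zip(coef, features))


# Hypothetical strain carrying mutations 0 and 2 (binary feature vector).
features = [1.0, 0.0, 1.0]

# All-zero coefficients give a multiplicative effect of exactly 1.0.
no_effect = math.exp(rate_loc([0.0, 0.0, 0.0], features))
# Net positive coefficients give an effect above 1; net negative ones, below 1.
boost = math.exp(rate_loc([5.0, -3.0, 2.0], features))  # exp(0.01 * 7)
damp = math.exp(rate_loc([-5.0, 0.0, 0.0], features))   # exp(0.01 * -5)

print(no_effect, boost, damp)
```

Note that the mutation with coefficient -3.0 does not contribute to `boost`, because this strain does not carry it.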
```python
def model(dataset):
    features = dataset["features"]
    counts = dataset["counts"]
    assert features.shape[0] == counts.shape[-1]
    S, M = features.shape
    T, P, S = counts.shape
    time = torch.arange(float(T)) * dataset["time_step_days"] / 5.5
    time -= time.mean()
    strain_plate = pyro.plate("strain", S, dim=-1)
    place_plate = pyro.plate("place", P, dim=-2)
    time_plate = pyro.plate("time", T, dim=-3)

    # Model each region as multivariate logistic growth.
    rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))
    init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))
    with pyro.plate("mutation", M, dim=-1):
        coef = pyro.sample("coef", dist.Laplace(0, 0.5))
    with strain_plate:
        rate_loc = pyro.deterministic("rate_loc", 0.01 * coef @ features.T)
    with place_plate, strain_plate:
        rate = pyro.sample("rate", dist.Normal(rate_loc, rate_scale))
        init = pyro.sample("init", dist.Normal(0, init_scale))
    logits = init + rate * time[:, None, None]

    # Observe sequences via a multinomial likelihood.
    with time_plate, place_plate:
        pyro.sample(
            "obs",
            dist.Multinomial(logits=logits.unsqueeze(-2), validate_args=False),
            obs=counts.unsqueeze(-2),
        )
```
The execution cost of this model is dominated by the multinomial likelihood over a large sparse count matrix.
```python
print("counts has {:d} / {} nonzero elements".format(
    dataset['counts'].count_nonzero(), dataset['counts'].numel()
))
```

```
counts has 57129 / 7177464 nonzero elements
```
To speed up inference (and model iteration!), we'll replace the `pyro.sample(..., Multinomial)` likelihood with an equivalent but much cheaper `pyro.factor` statement, using the helper `pyro.ops.special.sparse_multinomial_likelihood`.
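To see why the sparse version is equivalent, recall that the multinomial log-probability of counts $c$ with total $n$ and normalized log-probabilities $\log p$ is $\log n! - \sum_s \log c_s! + \sum_s c_s \log p_s$, and every term with $c_s = 0$ vanishes. The stdlib sketch below builds a COO-style `(index, value)` representation and per-cell totals from a tiny dense count array with the same `(time, place, strain)` layout as `dataset["counts"]` (the numbers are made up). It then checks that summing over nonzero entries only reproduces the dense log-likelihood. The formula reflects our understanding of what `sparse_multinomial_likelihood(total, logits[t, p, s], value)` computes. The helper functions are illustrative, not Pyro's implementation.

```python
import math


def log_softmax(xs):
    """Numerically stable log-softmax over a list."""
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


def log_factorial(n):
    return math.lgamma(n + 1)


# Hypothetical counts[t][p][s] with shape (T=2, P=2, S=3), mostly zeros.
counts = [
    [[3, 0, 0], [0, 0, 1]],
    [[0, 2, 0], [0, 0, 0]],
]
# Hypothetical unnormalized logits of the same shape.
logits = [
    [[0.5, -1.0, 0.2], [0.0, 0.3, 1.0]],
    [[-0.2, 0.8, 0.1], [0.4, 0.4, 0.4]],
]
T, P, S = len(counts), len(counts[0]), len(counts[0][0])

# Normalize over the strain axis, as the model does with .log_softmax(-1).
norm_logits = [[log_softmax(logits[t][p]) for p in range(P)] for t in range(T)]

# Dense multinomial log-likelihood, summed over every (t, p) cell.
dense = 0.0
for t in range(T):
    for p in range(P):
        c = counts[t][p]
        dense += log_factorial(sum(c))
        dense += sum(-log_factorial(k) + k * lp for k, lp in zip(c, norm_logits[t][p]))

# Sparse representation: indices of nonzero entries, their values, per-cell totals.
index = [(t, p, s) for t in range(T) for p in range(P) for s in range(S) if counts[t][p][s]]
value = [counts[t][p][s] for t, p, s in index]
total = [[sum(counts[t][p]) for p in range(P)] for t in range(T)]

# Sparse log-likelihood: only nonzero counts enter the last two terms.
sparse = sum(log_factorial(n) for row in total for n in row)
sparse -= sum(log_factorial(v) for v in value)
sparse += sum(v * norm_logits[t][p][s] for (t, p, s), v in zip(index, value))

print(len(index), "nonzero of", T * P * S, "entries")
print(dense, sparse)
```

The saving grows with sparsity. In the real dataset only 57129 of 7177464 entries are nonzero.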
```python
def model(dataset, predict=None):
    features = dataset["features"]
    counts = dataset["counts"]
    sparse_counts = dataset["sparse_counts"]
    assert features.shape[0] == counts.shape[-1]
    S, M = features.shape
    T, P, S = counts.shape
    time = torch.arange(float(T)) * dataset["time_step_days"] / 5.5
    time -= time.mean()

    # Model each region as multivariate logistic growth.
    rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))
    init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))
    with pyro.plate("mutation", M, dim=-1):
        coef = pyro.sample("coef", dist.Laplace(0, 0.5))
    with pyro.plate("strain", S, dim=-1):
        rate_loc = pyro.deterministic("rate_loc", 0.01 * coef @ features.T)
        with pyro.plate("place", P, dim=-2):
            rate = pyro.sample("rate", dist.Normal(rate_loc, rate_scale))
            init = pyro.sample("init", dist.Normal(0, init_scale))
    if predict is not None:  # Exit early during evaluation.
        probs = (init + rate * time[predict]).softmax(-1)
        return probs
    logits = (init + rate * time[:, None, None]).log_softmax(-1)

    # Observe sequences via a cheap sparse multinomial likelihood.
    t, p, s = sparse_counts["index"]
    pyro.factor(
        "obs",
        sparse_multinomial_likelihood(
            sparse_counts["total"], logits[t, p, s], sparse_counts["value"]
        )
    )
```
Mean-field Normal inference is cheap and robust, and is a good way to sanity check your posterior point estimate, even if the posterior uncertainty may be implausibly narrow. We recommend starting with an `AutoNormal` guide, and possibly setting `init_scale` to a small value like `init_scale=0.01` or `init_scale=0.001`.
Note that while MAP estimation via `AutoDelta` is even cheaper and more robust than mean-field `AutoNormal`, `AutoDelta` is coordinate-system dependent and is not invariant to reparametrization. Because in our experience most models benefit from some reparametrization, we recommend `AutoNormal` over `AutoDelta`: `AutoNormal` is less sensitive to reparametrization, whereas `AutoDelta` can give incorrect results in some reparametrized models.
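A one-variable example shows why a MAP point estimate depends on the coordinate system. If $\log x \sim \mathrm{Normal}(\mu, \sigma)$, the density of $x$ (a LogNormal) peaks at $e^{\mu - \sigma^2}$. The density of $\log x$, however, peaks at $\mu$, which corresponds to $x = e^{\mu}$. The sketch below finds both modes by brute force on a grid, using arbitrary values $\mu = 0$, $\sigma = 1$. The median, $e^{\mu}$, by contrast is the same in both coordinate systems, because medians are preserved by monotone transforms.

```python
import math

mu, sigma = 0.0, 1.0  # arbitrary parameters: log(x) ~ Normal(mu, sigma)


def lognormal_density(x):
    """Density of x when log(x) ~ Normal(mu, sigma)."""
    z = (math.log(x) - mu) / sigma
    return math.exp(-0.5 * z * z) / (x * sigma * math.sqrt(2 * math.pi))


def normal_density(y):
    """Density of y = log(x) ~ Normal(mu, sigma)."""
    z = (y - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2 * math.pi))


grid = [i / 1000 for i in range(1, 5000)]  # candidate x values in (0, 5)

# MAP in the original x coordinates.
mode_x = max(grid, key=lognormal_density)
# MAP in log coordinates, mapped back to x.
mode_logx = math.exp(max((math.log(x) for x in grid), key=normal_density))

print(mode_x, mode_logx)  # near exp(mu - sigma**2) and exp(mu) respectively
```

The two "MAP estimates" differ by a factor of $e^{\sigma^2}$, even though both describe the same distribution.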
```python
def fit_svi(model, guide, lr=0.01, num_steps=1001, log_every=100, plot=True):
    pyro.clear_param_store()
    pyro.set_rng_seed(20211205)
    if smoke_test:
        num_steps = 2

    # Measure model and guide complexity.
    num_latents = sum(
        site["value"].numel()
        for name, site in poutine.trace(guide).get_trace(dataset).iter_stochastic_nodes()
        if not site["infer"].get("is_auxiliary")
    )
    num_params = sum(p.unconstrained().numel() for p in pyro.get_param_store().values())
    print(f"Found {num_latents} latent variables and {num_params} learnable parameters")

    # Save gradient norms during inference.
    series = defaultdict(list)
    def hook(g, series):
        series.append(torch.linalg.norm(g.reshape(-1), math.inf).item())
    for name, value in pyro.get_param_store().named_parameters():
        value.register_hook(
            functools.partial(hook, series=series[name + " grad"])
        )

    # Train the guide.
    optim = ClippedAdam({"lr": lr, "lrd": 0.1 ** (1 / num_steps)})
    svi = SVI(model, guide, optim, Trace_ELBO())
    num_obs = int(dataset["counts"].count_nonzero())
    for step in range(num_steps):
        loss = svi.step(dataset) / num_obs
        series["loss"].append(loss)
        median = guide.median()  # cheap for autoguides
        for name, value in median.items():
            if value.numel() == 1:
                series[name + " mean"].append(float(value))
        if step % log_every == 0:
            print(f"step {step: >4d} loss = {loss:0.6g}")

    # Plot series to assess convergence.
    if plot:
        plt.figure(figsize=(6, 6))
        for name, Y in series.items():
            if name == "loss":
                plt.plot(Y, "k--", label=name, zorder=0)
            elif name.endswith(" mean"):
                plt.plot(Y, label=name, zorder=-1)
            else:
                plt.plot(Y, label=name, alpha=0.5, lw=1, zorder=-2)
        plt.xlabel("SVI step")
        plt.title("loss, scalar parameters, and gradient norms")
        plt.yscale("log")
        plt.xscale("symlog")
        plt.xlim(0, None)
        plt.legend(loc="best", fontsize=8)
        plt.tight_layout()
```
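The `lrd` argument that `fit_svi` passes to `ClippedAdam` is a multiplicative learning-rate decay, which we assume is applied once per optimizer step. Choosing `lrd = 0.1 ** (1 / num_steps)` means the learning rate has shrunk by exactly a factor of 10 by the end of training, however many steps you choose. The arithmetic below uses `fit_svi`'s defaults, `lr=0.01` and `num_steps=1001`.

```python
lr, num_steps = 0.01, 1001    # defaults of fit_svi
lrd = 0.1 ** (1 / num_steps)  # per-step decay factor passed to ClippedAdam

# Learning rate after k steps, assuming lr is multiplied by lrd once per step.
schedule = [lr * lrd ** k for k in range(num_steps + 1)]

print(f"lrd = {lrd:.6f}")
print(f"start {schedule[0]:.4g}, halfway {schedule[num_steps // 2]:.4g}, end {schedule[-1]:.4g}")
```

This is why the later runs can start with a larger `lr` (0.05 or 0.1): the schedule automatically anneals it toward one tenth of that value.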
```python
%%time
guide = AutoNormal(model, init_scale=0.01)
fit_svi(model, guide)
```

```
Found 538452 latent variables and 1068600 learnable parameters
step    0 loss = 273.123
step  100 loss = 63.2423
step  200 loss = 44.9539
step  300 loss = 34.8813
step  400 loss = 30.4243
step  500 loss = 27.5258
step  600 loss = 25.4543
step  700 loss = 23.9134
step  800 loss = 22.7201
step  900 loss = 21.8574
step 1000 loss = 21.2031
CPU times: user 3min 4s, sys: 2min 48s, total: 5min 52s
Wall time: 1min 47s
```
After each change to the model or inference, you'll validate model outputs, closing Box's loop. In our running example we'll quantitatively evaluate using the mean absolute error (MAE) over the last fully observed time step.
```python
def mae(true_counts, pred_probs):
    """Computes mean absolute error between counts and predicted probabilities."""
    pred_counts = pred_probs * true_counts.sum(-1, True)
    error = (true_counts - pred_counts).abs().sum(-1)
    total = true_counts.sum(-1).clamp(min=1)
    return (error / total).mean().item()


def evaluate(
    model, guide, num_particles=100, location="USA / Massachusetts", time=-2
):
    """Evaluate posterior predictive accuracy at the last fully observed time step."""
    if smoke_test:
        num_particles = 4
    with torch.no_grad(), poutine.mask(mask=False):  # makes computations cheaper
        with pyro.plate("particle", num_particles, dim=-3):  # vectorizes
            guide_trace = poutine.trace(guide).get_trace(dataset)
            probs = poutine.replay(model, guide_trace)(dataset, predict=time)
        probs = probs.squeeze().mean(0)  # average over Monte Carlo samples
        true_counts = dataset["counts"][time]
        # Compute global and local mean absolute error.
        global_mae = mae(true_counts, probs)
        i = dataset["locations"].index(location)
        local_mae = mae(true_counts[i], probs[i])
    return {"MAE (global)": global_mae, f"MAE ({location})": local_mae}


pprint(evaluate(model, guide))
```

```
{'MAE (USA / Massachusetts)': 0.26023179292678833,
 'MAE (global)': 0.22586050629615784}
```
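To get a feel for the scale of these numbers, here is the per-location computation of `mae()` redone in plain Python for two hypothetical locations with three strains each. The function name `mae_lists` and the data are made up. Because every misallocated sequence is counted twice (once where it is missing, once where it was added), the per-location value equals twice the total variation distance between observed and predicted proportions. As written, locations with no sequences at that time step contribute an error of 0, which pulls the global average down.

```python
def mae_lists(true_counts, pred_probs):
    """Plain-Python mirror of mae(): rows are locations, columns are strains."""
    per_location = []
    for counts, probs in zip(true_counts, pred_probs):
        total = sum(counts)
        pred_counts = [p * total for p in probs]
        error = sum(abs(c - q) for c, q in zip(counts, pred_counts))
        per_location.append(error / max(total, 1))  # like .clamp(min=1)
    return sum(per_location) / len(per_location)


# Hypothetical data: 2 locations x 3 strains; the second location is empty.
true_counts = [[6, 2, 0], [0, 0, 0]]
pred_probs = [[0.5, 0.25, 0.25], [0.2, 0.3, 0.5]]

# Location 1: predicted counts [4, 2, 2], error 2 + 0 + 2 = 4, divided by 8 gives 0.5.
# Location 2: total 0, so its error is 0. The mean over locations is 0.25.
print(mae_lists(true_counts, pred_probs))
```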
We'll also qualitatively evaluate using a volcano plot showing the effect size and statistical significance of each mutation's coefficient, and labeling the mutation with the most significant positive effect. We expect:
```python
def plot_volcano(guide, num_particles=100):
    if smoke_test:
        num_particles = 4
    with torch.no_grad(), poutine.mask(mask=False):  # makes computations cheaper
        with pyro.plate("particle", num_particles, dim=-3):  # vectorizes
            trace = poutine.trace(guide).get_trace(dataset)
            trace = poutine.trace(poutine.replay(model, trace)).get_trace(dataset, -1)
            coef = trace.nodes["coef"]["value"].cpu()
    coef = coef.squeeze() * 0.01  # Scale factor as in the model.
    mean = coef.mean(0)
    std = coef.std(0)
    z_score = mean.abs() / std
    effect_size = mean.exp().numpy()
    plt.figure(figsize=(6, 3))
    plt.scatter(effect_size, z_score.numpy(), lw=0, s=5, alpha=0.5, color="darkred")
    plt.yscale("symlog")
    plt.ylim(0, None)
    plt.xlabel("$R_m/R_{wt}$")
    plt.ylabel("z-score")
    i = int((mean / std).max(0).indices)
    plt.text(effect_size[i], z_score[i] * 1.1, dataset["mutations"][i], ha="center", fontsize=8)
    plt.title(f"Volcano plot of {len(mean)} mutations")


plot_volcano(guide)
```
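Each point in the volcano plot summarizes the posterior samples of one scaled coefficient. The x-axis is $\exp$ of the posterior mean (a multiplicative effect on growth rate), and the y-axis is the z-score $|\text{mean}| / \text{std}$. Here is the same summary in plain Python for one mutation, using a handful of made-up posterior samples. `statistics.stdev` uses the $n - 1$ denominator, which we believe matches the default of `torch.Tensor.std`.

```python
import math
import statistics

# Hypothetical posterior samples of one raw coefficient (before the 0.01 scale).
raw_samples = [18.0, 22.0, 25.0, 15.0, 20.0]
samples = [0.01 * c for c in raw_samples]  # same scale factor as in the model

mean = statistics.mean(samples)
std = statistics.stdev(samples)  # sample standard deviation (n - 1 denominator)
z_score = abs(mean) / std        # y-axis: statistical significance
effect_size = math.exp(mean)     # x-axis: multiplicative effect R_m / R_wt

print(f"effect size {effect_size:.4f}, z-score {z_score:.2f}")
```

Mutations near an effect size of 1 with a low z-score form the bulk of the plot. The labeled mutation is the one with the largest signed ratio `mean / std`.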
In high-dimensional models, convergence can be slow and NaNs arise easily, even when sampling from weakly informative priors. We recommend heuristically initializing a point estimate for each latent variable, aiming for a value of the right order of magnitude. Often you can initialize to a simple statistic of the data, e.g. a mean or standard deviation.
Pyro's autoguides provide a number of initialization strategies for initializing the location parameter of many variational families, specified as `init_loc_fn`. You can create a custom initializer by accepting a Pyro sample site dict and generating a sample from `site["name"]` and `site["fn"]` using e.g. `site["fn"].shape()`, `site["fn"].support`, `site["fn"].mean`, or sampling via `site["fn"].sample()`.
```python
def init_loc_fn(site):
    shape = site["fn"].shape()
    if site["name"] == "coef":
        return torch.randn(shape).sub_(0.5).mul(0.01)
    if site["name"] == "init":
        # Heuristically initialize based on data.
        return dataset["counts"].mean(0).add(0.01).log()
    return init_to_feasible(site)  # fallback
```
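The `init` heuristic sets each place's initial logits to the log of each strain's count averaged over time. The 0.01 offset gives strains never observed at a place a finite, very negative logit instead of $-\infty$. The likelihood depends on `init` only through a softmax over strains, so adding a constant to all logits at a place leaves predicted proportions unchanged: only the relative values of the heuristic matter for the fit. The stdlib check below uses made-up counts for a single place.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical counts[t][s] for one place: 3 time steps x 3 strains.
counts = [[10, 0, 0], [8, 2, 0], [6, 4, 0]]
T, S = len(counts), len(counts[0])

# Mirror of dataset["counts"].mean(0).add(0.01).log() restricted to this place.
init = [math.log(sum(counts[t][s] for t in range(T)) / T + 0.01) for s in range(S)]
print(init)  # the never-observed strain gets log(0.01) rather than -inf

# Shifting every logit by the same constant leaves the proportions unchanged.
shifted = [x + 5.0 for x in init]
print(softmax(init))
print(softmax(shifted))
```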
As you evolve a model, you'll add, remove, and rename latent variables. We find it useful to require inits for all latent variables, and to raise an error that reminds you to update the `init_loc_fn` whenever the model changes.
```python
def init_loc_fn(site):
    shape = site["fn"].shape()
    if site["name"].endswith("_scale"):
        return torch.ones(shape)
    if site["name"] == "coef":
        return torch.randn(shape).sub_(0.5).mul(0.01)
    if site["name"] == "rate":
        return torch.zeros(shape)
    if site["name"] == "init":
        return dataset["counts"].mean(0).add(0.01).log()
    raise NotImplementedError(f"TODO initialize latent variable {site['name']}")
```
```python
%%time
guide = AutoNormal(model, init_loc_fn=init_loc_fn, init_scale=0.01)
fit_svi(model, guide, lr=0.02)
pprint(evaluate(model, guide))
plot_volcano(guide)
```

```
Found 538452 latent variables and 1068600 learnable parameters
step    0 loss = 127.475
step  100 loss = 44.9544
step  200 loss = 31.4236
step  300 loss = 24.4205
step  400 loss = 20.6802
step  500 loss = 18.6063
step  600 loss = 17.2365
step  700 loss = 16.5067
step  800 loss = 16.001
step  900 loss = 15.5123
step 1000 loss = 18.8275
{'MAE (USA / Massachusetts)': 0.29367634654045105,
 'MAE (global)': 0.2283070981502533}
CPU times: user 3min 17s, sys: 2min 51s, total: 6min 9s
Wall time: 1min 58s
```
Reparametrizing a model preserves its distribution while changing its geometry. Reparametrizing is simply a change of coordinates. When reparametrizing we aim to warp a model's geometry to remove correlations and to lift inconvenient topological manifolds into simpler higher dimensional flat Euclidean space.
Whereas many probabilistic programming languages require users to rewrite models to change coordinates, Pyro implements a library of about 15 different reparametrization effects including decentering (Gorinova et al. 2020), Haar wavelet transforms, and neural transport (Hoffman et al. 2019), as well as strategies to automatically apply effects and machinery to create custom reparametrization effects. Using these reparametrizers you can separate modeling from inference: first specify a model in a form that is natural to domain experts, then in inference code, reparametrize the model to have geometry that is more amenable to variational inference.
In our SARS-CoV-2 model, the geometry might improve if we change
```diff
- rate = pyro.sample("rate", dist.Normal(rate_loc, rate_scale))
+ rate = pyro.sample("rate", dist.Normal(0, 1)) * rate_scale + rate_loc
```

but that would make the model less interpretable. Instead we can reparametrize the model

```python
reparam_model = poutine.reparam(model, config={"rate": LocScaleReparam()})
```

or even automatically apply a set of recommended reparametrizers

```python
reparam_model = AutoReparam()(model)
```
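Decentering is just a change of variables. As far as we understand `LocScaleReparam`, when `centered` is left unspecified it learns a centeredness $c \in [0, 1]$ and samples a new site `{name}_decentered`, $z \sim \mathrm{Normal}(c \cdot \mathrm{loc}, \mathrm{scale}^c)$. It then maps the sample back via $x = \mathrm{loc} + \mathrm{scale}^{1-c}(z - c \cdot \mathrm{loc})$. The stdlib sketch below implements that transform. It is our paraphrase, not Pyro's code, and the helper names are hypothetical. The check confirms that for every $c$ the result follows the original $\mathrm{Normal}(\mathrm{loc}, \mathrm{scale})$. Setting $c = 0$ gives the fully decentered form in the diff above, and $c = 1$ recovers the original model.

```python
import random
import statistics


def decentered_prior(loc, scale, c):
    """(mean, std) of the auxiliary decentered variable z for centeredness c."""
    return c * loc, scale ** c


def recenter(z, loc, scale, c):
    """Map a decentered sample z back to the original coordinates x."""
    return loc + scale ** (1 - c) * (z - c * loc)


random.seed(0)
loc, scale = 2.0, 0.5  # arbitrary parameters of Normal(loc, scale)
results = {}
for c in [0.0, 0.5, 1.0]:
    z_mean, z_std = decentered_prior(loc, scale, c)
    xs = [recenter(random.gauss(z_mean, z_std), loc, scale, c) for _ in range(20000)]
    results[c] = (statistics.mean(xs), statistics.stdev(xs))
    print(f"c={c}: mean {results[c][0]:.3f}, std {results[c][1]:.3f}")
```

The distribution of $x$ is identical for every $c$. What changes is the geometry the guide sees: with small $c$, the guide's site `z` is nearly independent of `loc` and `scale`.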
Let's try reparametrizing both sites `"rate"` and `"init"`. Note we'll create a fresh `reparam_model` each time we train a guide, since the parameters are stored in that `reparam_model` instance. Take care to use the `reparam_model` in downstream prediction tasks like running `evaluate(reparam_model, guide)`.
```python
%%time
reparam_model = poutine.reparam(
    model, {"rate": LocScaleReparam(), "init": LocScaleReparam()}
)
guide = AutoNormal(reparam_model, init_loc_fn=init_loc_fn, init_scale=0.01)
fit_svi(reparam_model, guide, lr=0.05)
pprint(evaluate(reparam_model, guide))
plot_volcano(guide)
```

```
Found 538452 latent variables and 1068602 learnable parameters
step    0 loss = 127.368
step  100 loss = 20.2831
step  200 loss = 11.0703
step  300 loss = 9.64594
step  400 loss = 9.52988
step  500 loss = 9.09012
step  600 loss = 9.25454
step  700 loss = 8.60661
step  800 loss = 8.9332
step  900 loss = 8.64206
step 1000 loss = 8.56663
{'MAE (USA / Massachusetts)': 0.1336274892091751,
 'MAE (global)': 0.1719919890165329}
CPU times: user 4min 21s, sys: 3min 9s, total: 7min 31s
Wall time: 2min 17s
```
When creating a new model, we recommend starting with mean-field variational inference using an `AutoNormal` guide. This mean-field guide is good at finding the neighborhood of your model's mode, but naively it ignores correlations between latent variables. A first step in capturing correlations is to reparametrize the model as above: using a `LocScaleReparam` or `HaarReparam` (where appropriate) already allows the guide to capture some correlations among latent variables.
The next step towards modeling uncertainty is to customize the variational family by trying other autoguides, building on `EasyGuide`, or creating a custom guide using Pyro primitives. We recommend increasing guide complexity gradually via these steps:

Custom guide functions are built from `pyro.sample`, `pyro.param`, and `pyro.plate`. Given a `partial_guide()` function that covers just a few latent variables, you can `AutoGuideList.append(partial_guide)` just as you append autoguides.

While fully custom guides built from `pyro.sample` primitives offer the most flexible variational family, they are also the most brittle, because each code change to the model or reparametrizer requires changes in the guide. The author recommends avoiding completely low-level guides and instead using `AutoGuide` or `EasyGuide` for at least some parts of the model, thereby speeding up model iteration.
Let's first try a simple `AutoLowRankMultivariateNormal` guide.
```python
%%time
reparam_model = poutine.reparam(
    model, {"rate": LocScaleReparam(), "init": LocScaleReparam()}
)
guide = AutoLowRankMultivariateNormal(
    reparam_model, init_loc_fn=init_loc_fn, init_scale=0.01, rank=100
)
fit_svi(reparam_model, guide, num_steps=10, log_every=1, plot=False)
# don't even bother to evaluate, since this is too slow.
```

```
Found 538452 latent variables and 54498602 learnable parameters
step    0 loss = 128.329
step    1 loss = 126.172
step    2 loss = 124.691
step    3 loss = 123.609
step    4 loss = 123.317
step    5 loss = 121.567
step    6 loss = 120.513
step    7 loss = 121.759
step    8 loss = 120.844
step    9 loss = 121.641
CPU times: user 45.9 s, sys: 38.2 s, total: 1min 24s
Wall time: 29 s
```
Yikes! This is quite slow and sometimes runs out of memory on GPU.
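Why is the low-rank guide so expensive here? A back-of-the-envelope count using the shapes reported by `summarize(dataset)` (202 places, 1316 strains, 2634 mutations) reproduces the parameter counts printed by `fit_svi`. Mean-field `AutoNormal` needs a loc and a scale per latent dimension, $2D$ parameters in total. `AutoLowRankMultivariateNormal(rank=r)` also stores a $D \times r$ covariance factor, giving $(r + 2) D$. The 2 extra parameters in the reparametrized runs presumably come from the two `LocScaleReparam` sites. The arithmetic also suggests that the latent count printed by `fit_svi` includes the index sites of the three plates, since it exceeds $D$ by exactly $202 + 1316 + 2634$.

```python
P, S, M = 202, 1316, 2634  # places, strains, mutations (from summarize(dataset))
num_reparam_params = 2     # assumed: one learnable scalar per LocScaleReparam site

# Latent dimension D: rate_scale, init_scale, coef, and the (P, S)-shaped rate and init.
D = 1 + 1 + M + 2 * P * S

auto_normal = 2 * D                                 # loc + scale per dimension
auto_normal_reparam = auto_normal + num_reparam_params
rank = 100
low_rank_mvn = (rank + 2) * D + num_reparam_params  # loc + scale + D x rank factor

plate_index_sites = P + S + M  # sizes of the "place", "strain", "mutation" plates

print(f"D = {D:,}")
print(f"AutoNormal: {auto_normal:,}, with reparam: {auto_normal_reparam:,}")
print(f"AutoLowRankMultivariateNormal(rank={rank}): {low_rank_mvn:,}")
print(f"D + plate index sites = {D + plate_index_sites:,}")
```

The covariance factor makes the low-rank guide roughly 50 times larger than the mean-field guide. Almost all of that cost comes from the half-million local `rate` and `init` dimensions, which motivates restricting the low-rank part to the few global variables.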
Let's make this cheaper by using `AutoGuideList` to combine an `AutoLowRankMultivariateNormal` guide over the most important variables `rate_scale`, `init_scale`, and `coef` with a simple, cheap `AutoNormal` guide on the rest of the model (the expensive `rate` and `init` variables). The typical pattern is to create two views of the model with `poutine.block`, one exposing the target variables and the other hiding them.
```python
%%time
reparam_model = poutine.reparam(
    model, {"rate": LocScaleReparam(), "init": LocScaleReparam()}
)
guide = AutoGuideList(reparam_model)
mvn_vars = ["coef", "rate_scale", "init_scale"]
guide.append(
    AutoLowRankMultivariateNormal(
        poutine.block(reparam_model, expose=mvn_vars),
        init_loc_fn=init_loc_fn,
        init_scale=0.01,
    )
)
guide.append(
    AutoNormal(
        poutine.block(reparam_model, hide=mvn_vars),
        init_loc_fn=init_loc_fn,
        init_scale=0.01,
    )
)
fit_svi(reparam_model, guide, lr=0.1)
pprint(evaluate(reparam_model, guide))
plot_volcano(guide)
```

```
Found 538452 latent variables and 1202987 learnable parameters
step    0 loss = 832.956
step  100 loss = 11.9687
step  200 loss = 11.1152
step  300 loss = 9.60629
step  400 loss = 10.1724
step  500 loss = 9.18063
step  600 loss = 9.1669
step  700 loss = 9.06247
step  800 loss = 9.38853
step  900 loss = 9.12489
step 1000 loss = 8.93582
{'MAE (USA / Massachusetts)': 0.09685955196619034,
 'MAE (global)': 0.16698431968688965}
CPU times: user 4min 22s, sys: 3min 5s, total: 7min 28s
Wall time: 2min 15s
```
Next let's create a custom guide for part of the model, just the `rate` and `init` parts. Since we'll want to use this with reparametrizers, we'll make the guide use the auxiliary latent variables created by `poutine.reparam`, rather than the original `rate` and `init` variables. Let's see what these variables are named:
```python
for name, site in poutine.trace(reparam_model).get_trace(
    dataset
).iter_stochastic_nodes():
    print(name)
```

```
rate_scale
init_scale
mutation
coef
strain
place
rate_decentered
init_decentered
```
It looks like these new auxiliary variables are called `rate_decentered` and `init_decentered`.
```python
def local_guide(dataset):
    # Create learnable parameters.
    T, P, S = dataset["counts"].shape
    r_loc = pyro.param("rate_decentered_loc", lambda: torch.zeros(P, S))
    i_loc = pyro.param("init_decentered_loc", lambda: torch.zeros(P, S))
    skew = pyro.param("skew", lambda: torch.zeros(P, S))  # allows correlation
    r_scale = pyro.param("rate_decentered_scale", lambda: torch.ones(P, S),
                         constraint=constraints.softplus_positive)
    i_scale = pyro.param("init_decentered_scale", lambda: torch.ones(P, S),
                         constraint=constraints.softplus_positive)

    # Sample local variables inside plates.
    # Note plates are already created by the main guide, so we'll
    # use the existing plates rather than calling pyro.plate(...).
    with guide.plates["place"], guide.plates["strain"]:
        samples = {}
        samples["rate_decentered"] = pyro.sample(
            "rate_decentered", dist.Normal(r_loc, r_scale)
        )
        i_loc = i_loc + skew * samples["rate_decentered"]
        samples["init_decentered"] = pyro.sample(
            "init_decentered", dist.Normal(i_loc, i_scale)
        )
    return samples
```
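The `skew` parameter is what lets this local guide capture a correlation that a mean-field guide cannot. In each `(place, strain)` cell, `init_decentered` is drawn with a mean that shifts linearly with the sampled `rate_decentered`. Writing $r \sim \mathrm{Normal}(\mu_r, \sigma_r)$ and $i \mid r \sim \mathrm{Normal}(\mu_i + k r, \sigma_i)$ with skew $k$, the implied covariance is $k \sigma_r^2$ and the correlation is $k \sigma_r / \sqrt{\sigma_i^2 + k^2 \sigma_r^2}$. Below is a stdlib Monte Carlo check of that formula for one cell, using arbitrary parameter values and hypothetical helper names.

```python
import math
import random
import statistics


def implied_correlation(skew, r_scale, i_scale):
    """Correlation between rate_decentered and init_decentered under the local guide."""
    cov = skew * r_scale ** 2
    i_std = math.sqrt(i_scale ** 2 + skew ** 2 * r_scale ** 2)
    return cov / (r_scale * i_std)


def pearson(xs, ys):
    """Sample Pearson correlation coefficient."""
    mx, my = statistics.fmean(xs), statistics.fmean(ys)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sxx = sum((x - mx) ** 2 for x in xs)
    syy = sum((y - my) ** 2 for y in ys)
    return sxy / math.sqrt(sxx * syy)


# Arbitrary guide parameters for a single (place, strain) cell.
r_loc, r_scale, i_loc, i_scale, skew = 0.0, 1.0, 0.5, 0.5, -0.8

random.seed(1)
rates, inits = [], []
for _ in range(50000):
    r = random.gauss(r_loc, r_scale)
    i = random.gauss(i_loc + skew * r, i_scale)  # same construction as local_guide
    rates.append(r)
    inits.append(i)

empirical = pearson(rates, inits)
print(f"formula {implied_correlation(skew, r_scale, i_scale):.3f}, Monte Carlo {empirical:.3f}")
```

With `skew` initialized to zero, the guide starts out mean-field and learns only as much correlation as the ELBO rewards.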
```python
%%time
reparam_model = poutine.reparam(
    model, {"rate": LocScaleReparam(), "init": LocScaleReparam()}
)
guide = AutoGuideList(reparam_model)
local_vars = ["rate_decentered", "init_decentered"]
guide.append(
    AutoLowRankMultivariateNormal(
        poutine.block(reparam_model, hide=local_vars),
        init_loc_fn=init_loc_fn,
        init_scale=0.01,
    )
)
guide.append(local_guide)
fit_svi(reparam_model, guide, lr=0.1)
pprint(evaluate(reparam_model, guide))
plot_volcano(guide)
```

```
Found 538452 latent variables and 1468870 learnable parameters
step    0 loss = 4804.42
step  100 loss = 31.7409
step  200 loss = 19.8206
step  300 loss = 15.2961
step  400 loss = 13.2222
step  500 loss = 12.1435
step  600 loss = 11.4291
step  700 loss = 10.9722
step  800 loss = 10.6209
step  900 loss = 10.3649
step 1000 loss = 10.1804
{'MAE (USA / Massachusetts)': 0.1159871369600296,
 'MAE (global)': 0.1876191794872284}
CPU times: user 4min 26s, sys: 3min 7s, total: 7min 33s
Wall time: 2min 18s
```
We've seen how to use initialization, reparametrization, autoguides, and custom guides in a Bayesian workflow. For more examples of these pieces of machinery, we recommend exploring the Pyro codebase, e.g. searching for "poutine.reparam" or "init_loc_fn".