!pip install -q numpyro arviz
import os
import warnings
import arviz as az
import matplotlib.pyplot as plt
import pandas as pd
import jax.numpy as jnp
from jax import lax, random, vmap
from jax.scipy.special import logsumexp
import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.infer import Predictive, SVI, Trace_ELBO, init_to_value, log_likelihood
from numpyro.infer.autoguide import AutoLaplaceApproximation
if "SVG" in os.environ:
%config InlineBackend.figure_formats = ["svg"]
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
category.__name__, message
)
az.style.use("arviz-darkgrid")
numpyro.set_platform("cpu")
sppnames = [
"afarensis",
"africanus",
"habilis",
"boisei",
"rudolfensis",
"ergaster",
"sapiens",
]
brainvolcc = jnp.array([438, 452, 612, 521, 752, 871, 1350])
masskg = jnp.array([37.0, 35.5, 34.5, 41.5, 55.5, 61.0, 53.5])
d = pd.DataFrame({"species": sppnames, "brain": brainvolcc, "mass": masskg})
d["mass_std"] = (d.mass - d.mass.mean()) / d.mass.std()
d["brain_std"] = d.brain / d.brain.max()
def model(mass_std, brain_std=None):
a = numpyro.sample("a", dist.Normal(0.5, 1))
b = numpyro.sample("b", dist.Normal(0, 10))
log_sigma = numpyro.sample("log_sigma", dist.Normal(0, 1))
mu = numpyro.deterministic("mu", a + b * mass_std)
numpyro.sample("brain_std", dist.Normal(mu, jnp.exp(log_sigma)), obs=brain_std)
m7_1 = AutoLaplaceApproximation(model)
svi = SVI(
model,
m7_1,
optim.Adam(0.3),
Trace_ELBO(),
mass_std=d.mass_std.values,
brain_std=d.brain_std.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p7_1 = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1593.52it/s, init loss: 115.9437, avg. loss [951-1000]: 3.6486]
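As a quick check of the fit (this inspection step is ours, not part of the original output), the imported print_summary helper can summarize draws from the Laplace approximation:

post = m7_1.sample_posterior(random.PRNGKey(1), p7_1, sample_shape=(1000,))
# exclude the per-observation deterministic site "mu" from the summary
print_summary({k: v for k, v in post.items() if k != "mu"}, 0.89, False)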
def model(mass_std, brain_std):
intercept = numpyro.sample("intercept", dist.Normal(0, 10))
b_mass_std = numpyro.sample("b_mass_std", dist.Normal(0, 10))
sigma = numpyro.sample("sigma", dist.HalfCauchy(2))
mu = intercept + b_mass_std * mass_std
numpyro.sample("brain_std", dist.Normal(mu, sigma), obs=brain_std)
m7_1_OLS = AutoLaplaceApproximation(model)
svi = SVI(
model,
m7_1_OLS,
optim=optim.Adam(0.01),
loss=Trace_ELBO(),
mass_std=d.mass_std.values,
brain_std=d.brain_std.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p7_1_OLS = svi_result.params
post = m7_1_OLS.sample_posterior(random.PRNGKey(1), p7_1_OLS, sample_shape=(1000,))
100%|██████████| 1000/1000 [00:00<00:00, 1700.78it/s, init loss: 118.9001, avg. loss [951-1000]: 6.6802]
post = m7_1.sample_posterior(random.PRNGKey(12), p7_1, sample_shape=(1000,))
s = Predictive(m7_1.model, post)(random.PRNGKey(2), d.mass_std.values)
r = jnp.mean(s["brain_std"], 0) - d.brain_std.values
resid_var = jnp.var(r, ddof=1)
outcome_var = jnp.var(d.brain_std.values, ddof=1)
1 - resid_var / outcome_var
DeviceArray(0.45580828, dtype=float32)
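This is the classic variance-explained measure, R² = 1 − var(residuals) / var(outcome), computed here from the posterior-mean predictions.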
def R2_is_bad(quap_fit):
quap, params = quap_fit
post = quap.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))
s = Predictive(quap.model, post)(random.PRNGKey(2), d.mass_std.values)
r = jnp.mean(s["brain_std"], 0) - d.brain_std.values
return 1 - jnp.var(r, ddof=1) / jnp.var(d.brain_std.values, ddof=1)
def model(mass_std, brain_std=None):
a = numpyro.sample("a", dist.Normal(0.5, 1))
b = numpyro.sample("b", dist.Normal(0, 10).expand([2]))
log_sigma = numpyro.sample("log_sigma", dist.Normal(0, 1))
mu = numpyro.deterministic("mu", a + b[0] * mass_std + b[1] * mass_std**2)
numpyro.sample("brain_std", dist.Normal(mu, jnp.exp(log_sigma)), obs=brain_std)
m7_2 = AutoLaplaceApproximation(
model, init_loc_fn=init_to_value(values={"b": jnp.repeat(0.0, 2)})
)
svi = SVI(
model,
m7_2,
optim.Adam(0.3),
Trace_ELBO(),
mass_std=d.mass_std.values,
brain_std=d.brain_std.values,
)
svi_result = svi.run(random.PRNGKey(0), 2000)
p7_2 = svi_result.params
100%|██████████| 2000/2000 [00:00<00:00, 3029.44it/s, init loss: 19.0172, avg. loss [1901-2000]: 6.6510]
def model(mass_std, brain_std=None):
a = numpyro.sample("a", dist.Normal(0.5, 1))
b = numpyro.sample("b", dist.Normal(0, 10).expand([3]))
log_sigma = numpyro.sample("log_sigma", dist.Normal(0, 1))
mu = numpyro.deterministic(
"mu", a + b[0] * mass_std + b[1] * mass_std**2 + b[2] * mass_std**3
)
numpyro.sample("brain_std", dist.Normal(mu, jnp.exp(log_sigma)), obs=brain_std)
m7_3 = AutoLaplaceApproximation(
model, init_loc_fn=init_to_value(values={"b": jnp.repeat(0.0, 3)})
)
svi = SVI(
model,
m7_3,
optim.Adam(0.01),
Trace_ELBO(),
mass_std=d.mass_std.values,
brain_std=d.brain_std.values,
)
svi_result = svi.run(random.PRNGKey(0), 2000)
p7_3 = svi_result.params
def model(mass_std, brain_std=None):
a = numpyro.sample("a", dist.Normal(0.5, 1))
b = numpyro.sample("b", dist.Normal(0, 10).expand([4]))
log_sigma = numpyro.sample("log_sigma", dist.Normal(0, 1))
mu = numpyro.deterministic(
"mu", a + jnp.sum(b * jnp.power(mass_std[..., None], jnp.arange(1, 5)), -1)
)
numpyro.sample("brain_std", dist.Normal(mu, jnp.exp(log_sigma)), obs=brain_std)
m7_4 = AutoLaplaceApproximation(
model, init_loc_fn=init_to_value(values={"b": jnp.repeat(0.0, 4)})
)
svi = SVI(
model,
m7_4,
optim.Adam(0.01),
Trace_ELBO(),
mass_std=d.mass_std.values,
brain_std=d.brain_std.values,
)
svi_result = svi.run(random.PRNGKey(0), 2000)
p7_4 = svi_result.params
def model(mass_std, brain_std=None):
a = numpyro.sample("a", dist.Normal(0.5, 1))
b = numpyro.sample("b", dist.Normal(0, 10).expand([5]))
log_sigma = numpyro.sample("log_sigma", dist.Normal(0, 1))
mu = numpyro.deterministic(
"mu", a + jnp.sum(b * jnp.power(mass_std[..., None], jnp.arange(1, 6)), -1)
)
numpyro.sample("brain_std", dist.Normal(mu, jnp.exp(log_sigma)), obs=brain_std)
m7_5 = AutoLaplaceApproximation(
model, init_loc_fn=init_to_value(values={"b": jnp.repeat(0.0, 5)})
)
svi = SVI(
model,
m7_5,
optim.Adam(0.01),
Trace_ELBO(),
mass_std=d.mass_std.values,
brain_std=d.brain_std.values,
)
svi_result = svi.run(random.PRNGKey(0), 2000)
p7_5 = svi_result.params
100%|██████████| 2000/2000 [00:00<00:00, 2905.53it/s, init loss: 22.2387, avg. loss [1901-2000]: 8.8866]
100%|██████████| 2000/2000 [00:00<00:00, 2911.48it/s, init loss: 25.4603, avg. loss [1901-2000]: 10.8298]
100%|██████████| 2000/2000 [00:00<00:00, 2650.21it/s, init loss: 28.6818, avg. loss [1901-2000]: 8.2608]
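The polynomial models m7_2 through m7_6 differ only in the degree of the polynomial. A small factory function (a sketch of ours, not part of the original) makes the shared pattern explicit:

def make_poly_model(degree):
    # hypothetical helper: returns the degree-`degree` polynomial model
    def model(mass_std, brain_std=None):
        a = numpyro.sample("a", dist.Normal(0.5, 1))
        b = numpyro.sample("b", dist.Normal(0, 10).expand([degree]))
        log_sigma = numpyro.sample("log_sigma", dist.Normal(0, 1))
        powers = jnp.arange(1, degree + 1)
        mu = numpyro.deterministic(
            "mu", a + jnp.sum(b * jnp.power(mass_std[..., None], powers), -1)
        )
        numpyro.sample("brain_std", dist.Normal(mu, jnp.exp(log_sigma)), obs=brain_std)
    return model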
def model(mass_std, brain_std=None):
a = numpyro.sample("a", dist.Normal(0.5, 1))
b = numpyro.sample("b", dist.Normal(0, 10).expand([6]))
log_sigma = numpyro.sample("log_sigma", dist.Normal(0, 1))
mu = numpyro.deterministic(
"mu", a + jnp.sum(b * jnp.power(mass_std[..., None], jnp.arange(1, 7)), -1)
)
numpyro.sample("brain_std", dist.Normal(mu, jnp.exp(log_sigma)), obs=brain_std)
m7_6 = AutoLaplaceApproximation(
model, init_loc_fn=init_to_value(values={"b": jnp.repeat(0.0, 6)})
)
svi = SVI(
model,
m7_6,
optim.Adam(0.003),
Trace_ELBO(),
mass_std=d.mass_std.values,
brain_std=d.brain_std.values,
)
svi_result = svi.run(random.PRNGKey(0), 5000)
p7_6 = svi_result.params
100%|██████████| 5000/5000 [00:00<00:00, 5220.55it/s, init loss: 31.9033, avg. loss [4751-5000]: 9.2877]
post = m7_1.sample_posterior(random.PRNGKey(1), p7_1, sample_shape=(1000,))
post.pop("mu")
mass_seq = jnp.linspace(d.mass_std.min(), d.mass_std.max(), num=100)
l = Predictive(m7_1.model, post, return_sites=["mu"])(
random.PRNGKey(2), mass_std=mass_seq
)["mu"]
mu = jnp.mean(l, 0)
ci = jnp.percentile(l, jnp.array([5.5, 94.5]), 0)
az.plot_pair(d[["mass_std", "brain_std"]].to_dict("list"))
plt.plot(mass_seq, mu, "k")
plt.fill_between(mass_seq, ci[0], ci[1], color="k", alpha=0.2)
plt.title("m7.1: R^2 = {:0.2f}".format(R2_is_bad((m7_1, p7_1)).item()))
plt.show()
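The same posterior-prediction plot can be produced for the higher-degree fits; a reusable sketch (the helper name is ours, mirroring the code just above):

def plot_fit(quap, params, title):
    # sample the Laplace posterior and predict mu over the mass grid
    post = quap.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))
    post.pop("mu", None)
    l = Predictive(quap.model, post, return_sites=["mu"])(
        random.PRNGKey(2), mass_std=mass_seq
    )["mu"]
    mu = jnp.mean(l, 0)
    ci = jnp.percentile(l, jnp.array([5.5, 94.5]), 0)
    az.plot_pair(d[["mass_std", "brain_std"]].to_dict("list"))
    plt.plot(mass_seq, mu, "k")
    plt.fill_between(mass_seq, ci[0], ci[1], color="k", alpha=0.2)
    plt.title(title)
    plt.show()

plot_fit(m7_2, p7_2, "m7.2")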
# drop one case (row i) from the data, as used for leave-one-out style checks
i = 1
d_minus_i = d.drop(i)
p = jnp.array([0.3, 0.7])
-jnp.sum(p * jnp.log(p))
DeviceArray(0.61086434, dtype=float32)
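This is the information entropy H(p) = −Σᵢ pᵢ log pᵢ; for p = (0.3, 0.7) it works out to −(0.3 log 0.3 + 0.7 log 0.7) ≈ 0.61 nats.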
def lppd_fn(seed, quap, params, num_samples=1000):
    post = quap.sample_posterior(seed, params, sample_shape=(num_samples,))
    logprob = log_likelihood(quap.model, post, d.mass_std.values, d.brain_std.values)
    logprob = logprob["brain_std"]
    return logsumexp(logprob, 0) - jnp.log(logprob.shape[0])
lppd_fn(random.PRNGKey(1), m7_1, p7_1, int(1e4))
DeviceArray([ 0.6177206 , 0.6550026 , 0.54556274, 0.6307907 , 0.4702301 , 0.43731594, -0.8560524 ], dtype=float32)
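The block below recomputes the same pointwise lppd by hand, making explicit that each element is log(1/S · Σₛ p(yᵢ | θₛ)), evaluated stably via logsumexp over the S posterior samples.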
post = m7_1.sample_posterior(random.PRNGKey(1), p7_1, sample_shape=(int(1e4),))
logprob = log_likelihood(m7_1.model, post, d.mass_std.values, d.brain_std.values)
logprob = logprob["brain_std"]
n = logprob.shape[1]
ns = logprob.shape[0]
f = lambda i: logsumexp(logprob[:, i]) - jnp.log(ns)
lppd = vmap(f)(jnp.arange(n))
lppd
DeviceArray([ 0.6177206 , 0.6550026 , 0.54556274, 0.6307907 , 0.4702301 , 0.43731594, -0.8560524 ], dtype=float32)
[
jnp.sum(lppd_fn(random.PRNGKey(1), m[0], m[1])).item()
for m in (
(m7_1, p7_1),
(m7_2, p7_2),
(m7_3, p7_3),
(m7_4, p7_4),
(m7_5, p7_5),
(m7_6, p7_6),
)
]
UserWarning: Hessian of log posterior at the MAP point is singular. Posterior samples from AutoLaplaceApproxmiation will be constant (equal to the MAP point). Please consider using an AutoNormal guide.
[2.500570297241211, 2.5938844680786133, 3.6698102951049805, 5.338682174682617, 14.092883110046387, 19.871124267578125]
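As expected, the total lppd on the training sample only grows with polynomial degree, which is exactly why an in-sample score cannot be used for model selection.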
def model(mm, y, b_sigma):
    # the intercept is a point-estimated parameter; the remaining
    # coefficients get Normal(0, b_sigma) priors
    a = numpyro.param("a", jnp.array([0.0]))
    Bvec = a
    k = mm.shape[1]
    if k > 1:
        b = numpyro.sample("b", dist.Normal(0, b_sigma).expand([k - 1]))
        Bvec = jnp.concatenate([Bvec, b])
    mu = jnp.matmul(mm, Bvec)
    numpyro.sample("y", dist.Normal(mu, 1), obs=y)
def sim_train_test(i, N=20, k=3, rho=[0.15, -0.4], b_sigma=100):
    # build the correlation matrix between the outcome and the predictors
    n_dim = max(k, 3)
    Rho = jnp.identity(n_dim)
    Rho = Rho.at[1 : len(rho) + 1, 0].set(jnp.array(rho))
    Rho = Rho.at[0, 1 : len(rho) + 1].set(jnp.array(rho))
    # simulate a training sample and its design matrix (intercept + k-1 predictors)
    X_train = dist.MultivariateNormal(jnp.zeros(n_dim), Rho).sample(
        random.fold_in(random.PRNGKey(0), i), (N,)
    )
    mm_train = jnp.ones((N, 1))
    if k > 1:
        mm_train = jnp.concatenate([mm_train, X_train[:, 1:k]], axis=1)
    # with k == 1 there are no latent sample sites, so an empty guide suffices
    if k > 1:
        m = AutoLaplaceApproximation(
            model, init_loc_fn=init_to_value(values={"b": jnp.zeros(k - 1)})
        )
    else:
        m = lambda mm, y, b_sigma: None
    svi = SVI(
        model,
        m,
        optim.Adam(0.3),
        Trace_ELBO(),
        mm=mm_train,
        y=X_train[:, 0],
        b_sigma=b_sigma,
    )
    svi_result = svi.run(random.fold_in(random.PRNGKey(1), i), 1000, progress_bar=False)
    params = svi_result.params
    coefs = params["a"]
    if k > 1:
        coefs = jnp.concatenate([coefs, m.median(params)["b"]])
    # in-sample deviance at the fitted coefficients (unit-sd Normal likelihood)
    logprob = dist.Normal(jnp.matmul(mm_train, coefs), 1).log_prob(X_train[:, 0])
    dev_train = (-2) * jnp.sum(logprob)
    # simulate a fresh test sample and compute out-of-sample deviance
    X_test = dist.MultivariateNormal(jnp.zeros(n_dim), Rho).sample(
        random.fold_in(random.PRNGKey(2), i), (N,)
    )
    mm_test = jnp.ones((N, 1))
    if k > 1:
        mm_test = jnp.concatenate([mm_test, X_test[:, 1:k]], axis=1)
    logprob = dist.Normal(jnp.matmul(mm_test, coefs), 1).log_prob(X_test[:, 0])
    dev_test = (-2) * jnp.sum(logprob)
    return jnp.stack([dev_train, dev_test])
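A single call sanity-checks the simulator before the expensive sweep (the check itself is ours, not part of the original):

# one simulated (train, test) deviance pair for N=20, k=3
print(sim_train_test(0, N=20, k=3))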
def dev_fn(N, k):
    print(k)
    # lax.map runs the 10,000 simulations sequentially, keeping memory use low
    r = lax.map(lambda i: sim_train_test(i, N, k), jnp.arange(int(1e4)))
    return jnp.concatenate([jnp.mean(r, 0), jnp.std(r, 0)])
N = 20
kseq = range(1, 6)
dev = jnp.stack([dev_fn(N, k) for k in kseq], axis=1)
1
2
3
4
5
def dev_fn(N, k):
    print(k)
    # alternative: vmap vectorizes the simulations, which is faster than
    # lax.map but requires considerably more memory
    r = vmap(lambda i: sim_train_test(i, N, k))(jnp.arange(int(1e4)))
    return jnp.concatenate([jnp.mean(r, 0), jnp.std(r, 0)])
plt.subplot(
ylim=(jnp.min(dev[0]).item() - 5, jnp.max(dev[0]).item() + 12),
xlim=(0.9, 5.2),
xlabel="number of parameters",
ylabel="deviance",
)
plt.title("N = {}".format(N))
plt.scatter(jnp.arange(1, 6), dev[0], s=80, color="b")
plt.scatter(jnp.arange(1.1, 6), dev[1], s=80, color="k")
pts_int = (dev[0] - dev[2], dev[0] + dev[2])
pts_out = (dev[1] - dev[3], dev[1] + dev[3])
plt.vlines(jnp.arange(1, 6), pts_int[0], pts_int[1], color="b")
plt.vlines(jnp.arange(1.1, 6), pts_out[0], pts_out[1], color="k")
plt.annotate(
"in", (2, dev[0][1]), xytext=(-25, -5), textcoords="offset pixels", color="b"
)
plt.annotate("out", (2.1, dev[1][1]), xytext=(10, -5), textcoords="offset pixels")
plt.annotate(
"+1SD",
(2.1, pts_out[1][1]),
xytext=(10, -5),
textcoords="offset pixels",
fontsize=12,
)
plt.annotate(
"-1SD",
(2.1, pts_out[0][1]),
xytext=(10, -5),
textcoords="offset pixels",
fontsize=12,
)
plt.show()
cars = pd.read_csv("../data/cars.csv", sep=",")
def model(speed, cars_dist):
a = numpyro.sample("a", dist.Normal(0, 100))
b = numpyro.sample("b", dist.Normal(0, 10))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = a + b * speed
numpyro.sample("dist", dist.Normal(mu, sigma), obs=cars_dist)
m = AutoLaplaceApproximation(model)
svi = SVI(
model,
m,
optim.Adam(1),
Trace_ELBO(),
speed=cars.speed.values,
cars_dist=cars.dist.values,
)
svi_result = svi.run(random.PRNGKey(0), 5000)
params = svi_result.params
post = m.sample_posterior(random.PRNGKey(94), params, sample_shape=(1000,))
100%|██████████| 5000/5000 [00:00<00:00, 5122.07it/s, init loss: 210883.8125, avg. loss [4751-5000]: 227.2560]
n_samples = 1000
def logprob_fn(s):
mu = post["a"][s] + post["b"][s] * cars.speed.values
return dist.Normal(mu, post["sigma"][s]).log_prob(cars.dist.values)
logprob = vmap(logprob_fn, out_axes=1)(jnp.arange(n_samples))
n_cases = cars.shape[0]
lppd = logsumexp(logprob, 1) - jnp.log(n_samples)
pWAIC = jnp.var(logprob, 1)
-2 * (jnp.sum(lppd) - jnp.sum(pWAIC))
DeviceArray(422.99088, dtype=float32)
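This matches the definition WAIC = −2 (lppd − p_WAIC), where the penalty p_WAIC = Σᵢ varₛ log p(yᵢ | θₛ) sums the posterior variance of the pointwise log-densities. The next lines compute the pointwise WAIC vector and its standard error sqrt(n_cases · var(waic_i)).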
waic_vec = -2 * (lppd - pWAIC)
jnp.sqrt(n_cases * jnp.var(waic_vec))
DeviceArray(17.235405, dtype=float32)
with numpyro.handlers.seed(rng_seed=71):
# number of plants
N = 100
# simulate initial heights
h0 = numpyro.sample("h0", dist.Normal(10, 2).expand([N]))
# assign treatments and simulate fungus and growth
treatment = jnp.repeat(jnp.arange(2), repeats=N // 2)
fungus = numpyro.sample(
"fungus", dist.Binomial(total_count=1, probs=(0.5 - treatment * 0.4))
)
h1 = h0 + numpyro.sample("diff", dist.Normal(5 - 3 * fungus))
# compose a clean data frame
d = pd.DataFrame({"h0": h0, "h1": h1, "treatment": treatment, "fungus": fungus})
def model(h0, h1):
p = numpyro.sample("p", dist.LogNormal(0, 0.25))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = h0 * p
numpyro.sample("h1", dist.Normal(mu, sigma), obs=h1)
m6_6 = AutoLaplaceApproximation(model)
svi = SVI(model, m6_6, optim.Adam(0.1), Trace_ELBO(), h0=d.h0.values, h1=d.h1.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_6 = svi_result.params
def model(treatment, fungus, h0, h1):
a = numpyro.sample("a", dist.LogNormal(0, 0.2))
bt = numpyro.sample("bt", dist.Normal(0, 0.5))
bf = numpyro.sample("bf", dist.Normal(0, 0.5))
sigma = numpyro.sample("sigma", dist.Exponential(1))
p = a + bt * treatment + bf * fungus
mu = h0 * p
numpyro.sample("h1", dist.Normal(mu, sigma), obs=h1)
m6_7 = AutoLaplaceApproximation(model)
svi = SVI(
model,
m6_7,
optim.Adam(0.3),
Trace_ELBO(),
treatment=d.treatment.values,
fungus=d.fungus.values,
h0=d.h0.values,
h1=d.h1.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_7 = svi_result.params
def model(treatment, h0, h1):
a = numpyro.sample("a", dist.LogNormal(0, 0.2))
bt = numpyro.sample("bt", dist.Normal(0, 0.5))
sigma = numpyro.sample("sigma", dist.Exponential(1))
p = a + bt * treatment
mu = h0 * p
numpyro.sample("h1", dist.Normal(mu, sigma), obs=h1)
m6_8 = AutoLaplaceApproximation(model)
svi = SVI(
model,
m6_8,
optim.Adam(1),
Trace_ELBO(),
treatment=d.treatment.values,
h0=d.h0.values,
h1=d.h1.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p6_8 = svi_result.params
post = m6_7.sample_posterior(random.PRNGKey(11), p6_7, sample_shape=(1000,))
logprob = log_likelihood(
m6_7.model,
post,
treatment=d.treatment.values,
fungus=d.fungus.values,
h0=d.h0.values,
h1=d.h1.values,
)
az6_7 = az.from_dict({}, log_likelihood={"h1": logprob["h1"][None, ...]})
az.waic(az6_7, scale="deviance")
100%|██████████| 1000/1000 [00:00<00:00, 1252.99it/s, init loss: 279.8950, avg. loss [951-1000]: 204.3793]
100%|██████████| 1000/1000 [00:01<00:00, 821.21it/s, init loss: 151456.4062, avg. loss [951-1000]: 168.2294]
100%|██████████| 1000/1000 [00:01<00:00, 981.14it/s, init loss: 87469.1172, avg. loss [951-1000]: 198.4736]
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. See http://arxiv.org/abs/1507.04544 for details
Computed from 1000 by 100 log-likelihood matrix

               Estimate     SE
deviance_waic    337.37  11.86
p_waic             3.36      -

There has been a warning during the calculation. Please check the results.
post = m6_6.sample_posterior(random.PRNGKey(77), p6_6, sample_shape=(1000,))
logprob = log_likelihood(m6_6.model, post, h0=d.h0.values, h1=d.h1.values)
az6_6 = az.from_dict({}, log_likelihood={"h1": logprob["h1"][None, ...]})
post = m6_7.sample_posterior(random.PRNGKey(77), p6_7, sample_shape=(1000,))
logprob = log_likelihood(
m6_7.model,
post,
treatment=d.treatment.values,
fungus=d.fungus.values,
h0=d.h0.values,
h1=d.h1.values,
)
az6_7 = az.from_dict({}, log_likelihood={"h1": logprob["h1"][None, ...]})
post = m6_8.sample_posterior(random.PRNGKey(77), p6_8, sample_shape=(1000,))
logprob = log_likelihood(
m6_8.model, post, treatment=d.treatment.values, h0=d.h0.values, h1=d.h1.values
)
az6_8 = az.from_dict({}, log_likelihood={"h1": logprob["h1"][None, ...]})
az.compare({"m6.6": az6_6, "m6.7": az6_7, "m6.8": az6_8}, ic="waic", scale="deviance")
UserWarning: The default method used to estimate the weights for each model, has changed from BB-pseudo-BMA to stacking
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. See http://arxiv.org/abs/1507.04544 for details
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. See http://arxiv.org/abs/1507.04544 for details
|      | rank | waic       | p_waic   | d_waic    | weight       | se        | dse       | warning | waic_scale |
|------|------|------------|----------|-----------|--------------|-----------|-----------|---------|------------|
| m6.7 | 0    | 337.244430 | 3.308165 | 0.000000  | 1.000000e+00 | 11.844569 | 0.000000  | True    | deviance   |
| m6.8 | 1    | 399.758470 | 3.089429 | 62.514039 | 0.000000e+00 | 14.941817 | 13.865603 | True    | deviance   |
| m6.6 | 2    | 409.200795 | 1.712087 | 71.956364 | 7.406298e-12 | 12.402004 | 12.789497 | False   | deviance   |
post = m6_7.sample_posterior(random.PRNGKey(91), p6_7, sample_shape=(1000,))
logprob = log_likelihood(
m6_7.model,
post,
treatment=d.treatment.values,
fungus=d.fungus.values,
h0=d.h0.values,
h1=d.h1.values,
)
az6_7 = az.from_dict({}, log_likelihood={"h1": logprob["h1"][None, ...]})
waic_m6_7 = az.waic(az6_7, pointwise=True, scale="deviance")
post = m6_8.sample_posterior(random.PRNGKey(91), p6_8, sample_shape=(1000,))
logprob = log_likelihood(
m6_8.model, post, treatment=d.treatment.values, h0=d.h0.values, h1=d.h1.values
)
az6_8 = az.from_dict({}, log_likelihood={"h1": logprob["h1"][None, ...]})
waic_m6_8 = az.waic(az6_8, pointwise=True, scale="deviance")
n = waic_m6_7.n_data_points
diff_m6_7_m6_8 = waic_m6_7.waic_i.values - waic_m6_8.waic_i.values
jnp.sqrt(n * jnp.var(diff_m6_7_m6_8))
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. See http://arxiv.org/abs/1507.04544 for details
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. See http://arxiv.org/abs/1507.04544 for details
DeviceArray(13.789175, dtype=float32)
40.0 + jnp.array([-1, 1]) * 10.4 * 2.6
DeviceArray([12.960003, 67.03999 ], dtype=float32, weak_type=True)
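This is a z-style 99% interval for a WAIC difference: dWAIC ± 2.6 × dSE. The 40.0 and 10.4 are illustrative rounded values, not the difference computed just above.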
compare = az.compare(
{"m6.6": az6_6, "m6.7": az6_7, "m6.8": az6_8}, ic="waic", scale="deviance"
)
az.plot_compare(compare)
plt.show()
UserWarning: The default method used to estimate the weights for each model, has changed from BB-pseudo-BMA to stacking
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. See http://arxiv.org/abs/1507.04544 for details
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. See http://arxiv.org/abs/1507.04544 for details
post = m6_6.sample_posterior(random.PRNGKey(92), p6_6, sample_shape=(1000,))
logprob = log_likelihood(m6_6.model, post, h0=d.h0.values, h1=d.h1.values)
az6_6 = az.from_dict({}, log_likelihood={"h1": logprob["h1"][None, ...]})
waic_m6_6 = az.waic(az6_6, pointwise=True, scale="deviance")
diff_m6_6_m6_8 = waic_m6_6.waic_i.values - waic_m6_8.waic_i.values
jnp.sqrt(n * jnp.var(diff_m6_6_m6_8))
DeviceArray(7.524193, dtype=float32)
post = m6_6.sample_posterior(random.PRNGKey(93), p6_6, sample_shape=(1000,))
logprob = log_likelihood(m6_6.model, post, h0=d.h0.values, h1=d.h1.values)
az6_6 = az.from_dict({}, log_likelihood={"h1": logprob["h1"][None, ...]})
waic_m6_6 = az.waic(az6_6, pointwise=True, scale="deviance")
post = m6_7.sample_posterior(random.PRNGKey(93), p6_7, sample_shape=(1000,))
logprob = log_likelihood(
m6_7.model,
post,
treatment=d.treatment.values,
fungus=d.fungus.values,
h0=d.h0.values,
h1=d.h1.values,
)
az6_7 = az.from_dict({}, log_likelihood={"h1": logprob["h1"][None, ...]})
waic_m6_7 = az.waic(az6_7, pointwise=True, scale="deviance")
post = m6_8.sample_posterior(random.PRNGKey(93), p6_8, sample_shape=(1000,))
logprob = log_likelihood(
m6_8.model, post, treatment=d.treatment.values, h0=d.h0.values, h1=d.h1.values
)
az6_8 = az.from_dict({}, log_likelihood={"h1": logprob["h1"][None, ...]})
waic_m6_8 = az.waic(az6_8, pointwise=True, scale="deviance")
dSE = lambda waic1, waic2: jnp.sqrt(
n * jnp.var(waic1.waic_i.values - waic2.waic_i.values)
)
data = {"m6.6": waic_m6_6, "m6.7": waic_m6_7, "m6.8": waic_m6_8}
pd.DataFrame(
{
row: {col: dSE(row_val, col_val) for col, col_val in data.items()}
for row, row_val in data.items()
}
)
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. See http://arxiv.org/abs/1507.04544 for details
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. See http://arxiv.org/abs/1507.04544 for details
|      | m6.6       | m6.7       | m6.8      |
|------|------------|------------|-----------|
| m6.6 | 0.0        | 12.7080345 | 7.5581884 |
| m6.7 | 12.7080345 | 0.0        | 13.690906 |
| m6.8 | 7.5581884  | 13.690906  | 0.0       |
WaffleDivorce = pd.read_csv("../data/WaffleDivorce.csv", sep=";")
d = WaffleDivorce
d["A"] = d.MedianAgeMarriage.pipe(lambda x: (x - x.mean()) / x.std())
d["D"] = d.Divorce.pipe(lambda x: (x - x.mean()) / x.std())
d["M"] = d.Marriage.pipe(lambda x: (x - x.mean()) / x.std())
def model(A, D=None):
a = numpyro.sample("a", dist.Normal(0, 0.2))
bA = numpyro.sample("bA", dist.Normal(0, 0.5))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = numpyro.deterministic("mu", a + bA * A)
numpyro.sample("D", dist.Normal(mu, sigma), obs=D)
m5_1 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_1, optim.Adam(1), Trace_ELBO(), A=d.A.values, D=d.D.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_1 = svi_result.params
def model(M, D=None):
a = numpyro.sample("a", dist.Normal(0, 0.2))
bM = numpyro.sample("bM", dist.Normal(0, 0.5))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = a + bM * M
numpyro.sample("D", dist.Normal(mu, sigma), obs=D)
m5_2 = AutoLaplaceApproximation(model)
svi = SVI(model, m5_2, optim.Adam(1), Trace_ELBO(), M=d.M.values, D=d.D.values)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_2 = svi_result.params
def model(M, A, D=None):
a = numpyro.sample("a", dist.Normal(0, 0.2))
bM = numpyro.sample("bM", dist.Normal(0, 0.5))
bA = numpyro.sample("bA", dist.Normal(0, 0.5))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = numpyro.deterministic("mu", a + bM * M + bA * A)
numpyro.sample("D", dist.Normal(mu, sigma), obs=D)
m5_3 = AutoLaplaceApproximation(model)
svi = SVI(
model, m5_3, optim.Adam(1), Trace_ELBO(), M=d.M.values, A=d.A.values, D=d.D.values
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_3 = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1334.71it/s, init loss: 2138.6682, avg. loss [951-1000]: 60.6515]
100%|██████████| 1000/1000 [00:00<00:00, 1365.18it/s, init loss: 962.7464, avg. loss [951-1000]: 67.4809]
100%|██████████| 1000/1000 [00:00<00:00, 1106.15it/s, init loss: 3201.7393, avg. loss [951-1000]: 60.7879]
post = m5_1.sample_posterior(random.PRNGKey(24071847), p5_1, sample_shape=(1000,))
logprob = log_likelihood(m5_1.model, post, A=d.A.values, D=d.D.values)["D"]
az5_1 = az.from_dict(
posterior={k: v[None, ...] for k, v in post.items()},
log_likelihood={"D": logprob[None, ...]},
)
post = m5_2.sample_posterior(random.PRNGKey(24071847), p5_2, sample_shape=(1000,))
logprob = log_likelihood(m5_2.model, post, M=d.M.values, D=d.D.values)["D"]
az5_2 = az.from_dict(
posterior={k: v[None, ...] for k, v in post.items()},
log_likelihood={"D": logprob[None, ...]},
)
post = m5_3.sample_posterior(random.PRNGKey(24071847), p5_3, sample_shape=(1000,))
logprob = log_likelihood(m5_3.model, post, A=d.A.values, M=d.M.values, D=d.D.values)[
"D"
]
az5_3 = az.from_dict(
posterior={k: v[None, ...] for k, v in post.items()},
log_likelihood={"D": logprob[None, ...]},
)
az.compare({"m5.1": az5_1, "m5.2": az5_2, "m5.3": az5_3}, ic="waic", scale="deviance")
UserWarning: The default method used to estimate the weights for each model, has changed from BB-pseudo-BMA to stacking
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. See http://arxiv.org/abs/1507.04544 for details
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. See http://arxiv.org/abs/1507.04544 for details
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. See http://arxiv.org/abs/1507.04544 for details
|      | rank | waic       | p_waic   | d_waic    | weight       | se        | dse      | warning | waic_scale |
|------|------|------------|----------|-----------|--------------|-----------|----------|---------|------------|
| m5.1 | 0    | 126.515847 | 4.108174 | 0.000000  | 8.879604e-01 | 13.580529 | 0.000000 | True    | deviance   |
| m5.3 | 1    | 128.609601 | 5.490680 | 2.093754  | 5.354823e-15 | 14.173850 | 1.047467 | True    | deviance   |
| m5.2 | 2    | 139.775924 | 3.282653 | 13.260077 | 1.120396e-01 | 10.458673 | 9.845431 | True    | deviance   |
PSIS_m5_3 = az.loo(az5_3, pointwise=True, scale="deviance")
WAIC_m5_3 = az.waic(az5_3, pointwise=True, scale="deviance")
penalty = az5_3.log_likelihood.stack(sample=("chain", "draw")).var(dim="sample")
plt.plot(PSIS_m5_3.pareto_k.values, penalty.D.values, "o", mfc="none")
plt.gca().set(xlabel="PSIS Pareto k", ylabel="WAIC penalty")
plt.show()
UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.7 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. See http://arxiv.org/abs/1507.04544 for details
def model(M, A, D=None):
a = numpyro.sample("a", dist.Normal(0, 0.2))
bM = numpyro.sample("bM", dist.Normal(0, 0.5))
bA = numpyro.sample("bA", dist.Normal(0, 0.5))
sigma = numpyro.sample("sigma", dist.Exponential(1))
mu = a + bM * M + bA * A
numpyro.sample("D", dist.StudentT(2, mu, sigma), obs=D)
m5_3t = AutoLaplaceApproximation(model)
svi = SVI(
model,
m5_3t,
optim.Adam(0.3),
Trace_ELBO(),
M=d.M.values,
A=d.A.values,
D=d.D.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p5_3t = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1067.98it/s, init loss: 194.5655, avg. loss [951-1000]: 63.3271]
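To see whether the Student-t likelihood actually tames the influential states, the PSIS diagnostics can be recomputed for m5_3t (a sketch under our naming, mirroring the earlier code):

post = m5_3t.sample_posterior(random.PRNGKey(24071847), p5_3t, sample_shape=(1000,))
logprob = log_likelihood(m5_3t.model, post, M=d.M.values, A=d.A.values, D=d.D.values)["D"]
az5_3t = az.from_dict(
    posterior={k: v[None, ...] for k, v in post.items()},
    log_likelihood={"D": logprob[None, ...]},
)
PSIS_m5_3t = az.loo(az5_3t, pointwise=True, scale="deviance")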