!pip install -q numpyro arviz
import math
import os
import warnings
import arviz as az
import matplotlib.pyplot as plt
import pandas as pd
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.infer import Predictive, SVI, Trace_ELBO, log_likelihood
from numpyro.infer.autoguide import AutoLaplaceApproximation
if "SVG" in os.environ:
%config InlineBackend.figure_formats = ["svg"]
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
category.__name__, message
)
az.style.use("arviz-darkgrid")
numpyro.set_platform("cpu")
# Load the rugged-terrain data. `d` is an alias of `rugged`, so columns added
# through `d` also appear on `rugged`.
rugged = pd.read_csv("../data/rugged.csv", sep=";")
d = rugged
# natural log of rgdppc_2000 (NaN stays NaN for rows without a value)
d["log_gdp"] = d["rgdppc_2000"].map(math.log)
# keep only rows that have a GDP value; .copy() detaches the subset from `d`
dd = d.loc[d["rgdppc_2000"].notna()].copy()
# rescale: outcome divided by its mean, ruggedness divided by its maximum
dd["log_gdp_std"] = dd["log_gdp"] / dd["log_gdp"].mean()
dd["rugged_std"] = dd["rugged"] / dd["rugged"].max()
def model(rugged_std, log_gdp_std=None):
    """m8.1 with wide priors: Gaussian regression of log_gdp_std on ruggedness.

    Ruggedness enters centred at 0.215 (presumably the sample mean of
    ``rugged_std`` -- TODO confirm), so ``a`` is the expected outcome at that
    value and ``b`` the slope. ``mu`` is recorded as a deterministic site.
    With ``log_gdp_std=None`` the outcome is sampled rather than observed.
    """
    a = numpyro.sample("a", dist.Normal(1, 1))  # intercept
    b = numpyro.sample("b", dist.Normal(0, 1))  # slope
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    rugged_centred = rugged_std - 0.215
    mu = numpyro.deterministic("mu", a + b * rugged_centred)
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


# Laplace (quadratic) approximation fitted with SVI: 1000 Adam steps, lr 0.1.
m8_1 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_1,
    optim=optim.Adam(0.1),
    loss=Trace_ELBO(),
    rugged_std=dd.rugged_std.values,
    log_gdp_std=dd.log_gdp_std.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p8_1 = svi_result.params
100%|██████████| 1000/1000 [00:01<00:00, 944.10it/s, init loss: 810.4496, avg. loss [951-1000]: -91.4252]
# Prior-only simulation for m8.1: 1000 draws of a, b, sigma with no data.
predictive = Predictive(m8_1.model, num_samples=1000, return_sites=["a", "b", "sigma"])
prior = predictive(random.PRNGKey(7), rugged_std=0)
# set up the axes; dashed lines mark the observed min and max of log_gdp_std
ax = plt.subplot(xlim=(0, 1), ylim=(0.5, 1.5), xlabel="ruggedness", ylabel="log GDP")
for bound in (dd.log_gdp_std.min(), dd.log_gdp_std.max()):
    ax.axhline(bound, ls="--")
# regression lines implied by the first 50 prior draws
rugged_seq = jnp.linspace(-0.1, 1.1, num=30)
prior_mu = Predictive(m8_1.model, prior, return_sites=["mu"])
mu = prior_mu(random.PRNGKey(7), rugged_std=rugged_seq)["mu"]
for line in mu[:50]:
    ax.plot(rugged_seq, line, "k", alpha=0.3)
# share of prior slopes with |b| > 0.6
jnp.sum(jnp.abs(prior["b"]) > 0.6) / len(prior["b"])
DeviceArray(0.564, dtype=float32)
def model(rugged_std, log_gdp_std=None):
    """m8.1 with tighter priors on the intercept (sd 0.1) and slope (sd 0.3).

    Same likelihood as the wide-prior version: log_gdp_std ~ Normal(mu, sigma)
    with ``mu = a + b * (rugged_std - 0.215)`` recorded as a deterministic site.
    """
    a = numpyro.sample("a", dist.Normal(1, 0.1))
    b = numpyro.sample("b", dist.Normal(0, 0.3))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    linear_pred = a + b * (rugged_std - 0.215)
    mu = numpyro.deterministic("mu", linear_pred)
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


# Refit the Laplace approximation under the new priors (replaces m8_1 / p8_1).
m8_1 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_1,
    optim=optim.Adam(0.1),
    loss=Trace_ELBO(),
    log_gdp_std=dd.log_gdp_std.values,
    rugged_std=dd.rugged_std.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p8_1 = svi_result.params
100%|██████████| 1000/1000 [00:01<00:00, 950.36it/s, init loss: 852.5239, avg. loss [951-1000]: -94.8615]
post = m8_1.sample_posterior(random.PRNGKey(1), p8_1, sample_shape=(1000,))
# Summarise only the parameters; "mu" is the per-observation deterministic site.
print_summary(
    {name: draws for name, draws in post.items() if name != "mu"},
    prob=0.89,
    group_by_chain=False,
)
                mean       std    median      5.5%     94.5%     n_eff     r_hat
         a      1.00      0.01      1.00      0.98      1.02    931.50      1.00
         b      0.00      0.06      0.00     -0.08      0.10   1111.63      1.00
     sigma      0.14      0.01      0.14      0.13      0.15    949.29      1.00
# Group index: 0 for African nations (cont_africa == 1), 1 for all others.
dd["cid"] = jnp.where(dd.cont_africa.values != 1, 1, 0)


def model(cid, rugged_std, log_gdp_std=None):
    """m8.2: one intercept per ``cid`` group, a single shared ruggedness slope."""
    a = numpyro.sample("a", dist.Normal(1, 0.1).expand([2]))  # one intercept per group
    b = numpyro.sample("b", dist.Normal(0, 0.3))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    group_intercept = a[cid]  # pick each observation's intercept by its group
    mu = numpyro.deterministic("mu", group_intercept + b * (rugged_std - 0.215))
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


m8_2 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_2,
    optim=optim.Adam(0.1),
    loss=Trace_ELBO(),
    cid=dd.cid.values,
    rugged_std=dd.rugged_std.values,
    log_gdp_std=dd.log_gdp_std.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p8_2 = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1061.70it/s, init loss: 1785.1527, avg. loss [951-1000]: -127.6097]
# Pointwise log-likelihoods from 1000 posterior draws per model, wrapped for
# ArviZ; the leading length-1 axis is the single "chain".
post = m8_1.sample_posterior(random.PRNGKey(2), p8_1, sample_shape=(1000,))
logprob = log_likelihood(
    m8_1.model,
    post,
    rugged_std=dd.rugged_std.values,
    log_gdp_std=dd.log_gdp_std.values,
)
az8_1 = az.from_dict(
    {}, log_likelihood={site: jnp.expand_dims(ll, 0) for site, ll in logprob.items()}
)
post = m8_2.sample_posterior(random.PRNGKey(2), p8_2, sample_shape=(1000,))
logprob = log_likelihood(
    m8_2.model,
    post,
    cid=dd.cid.values,
    rugged_std=dd.rugged_std.values,
    log_gdp_std=dd.log_gdp_std.values,
)
az8_2 = az.from_dict(
    {}, log_likelihood={site: jnp.expand_dims(ll, 0) for site, ll in logprob.items()}
)
az.compare({"m8.1": az8_1, "m8.2": az8_2}, ic="waic", scale="deviance")
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. See http://arxiv.org/abs/1507.04544 for details
| | rank | waic | p_waic | d_waic | weight | se | dse | warning | waic_scale |
---|---|---|---|---|---|---|---|---|---|
m8.2 | 0 | -252.36 | 4.15389 | 0 | 0.9998 | 13.2547 | 0 | True | deviance |
m8.1 | 1 | -188.818 | 2.65329 | 63.5418 | 0.000200227 | 14.8249 | 14.9592 | False | deviance |
post = m8_2.sample_posterior(random.PRNGKey(1), p8_2, sample_shape=(1000,))
# Parameters only; drop the per-observation deterministic "mu".
print_summary(
    {name: draws for name, draws in post.items() if name != "mu"},
    prob=0.89,
    group_by_chain=False,
)
                mean       std    median      5.5%     94.5%     n_eff     r_hat
      a[0]      0.88      0.02      0.88      0.86      0.90   1049.96      1.00
      a[1]      1.05      0.01      1.05      1.03      1.07    824.00      1.00
         b     -0.05      0.05     -0.05     -0.13      0.02    999.08      1.00
     sigma      0.11      0.01      0.11      0.10      0.12    961.35      1.00
post = m8_2.sample_posterior(random.PRNGKey(1), p8_2, sample_shape=(1000,))
# Posterior contrast of the group intercepts: cid 0 (Africa) minus cid 1.
a_draws = post["a"]
diff_a1_a2 = a_draws[:, 0] - a_draws[:, 1]
jnp.percentile(diff_a1_a2, q=jnp.array([5.5, 94.5]))  # 89% interval
DeviceArray([-0.19981882, -0.13967244], dtype=float32)
# Counterfactual mean predictions from m8.2 for each group over a ruggedness grid.
# The grid runs -0.1..1.1 (as in the prior plot) so all 30 points sit around the
# rescaled data range, whose maximum is 1.
rugged_seq = jnp.linspace(start=-0.1, stop=1.1, num=30)
# Sample the m8.2 posterior here (same key as the summary cell) so this cell does
# not depend on, or mutate, `post` from another cell and can be re-run safely.
post = m8_2.sample_posterior(random.PRNGKey(1), p8_2, sample_shape=(1000,))
# Drop the stored per-observation `mu` so Predictive recomputes it on the grid.
post.pop("mu")
predictive = Predictive(m8_2.model, post, return_sites=["mu"])
# compute mu over samples, fixing cid=1 (not Africa)
mu_NotAfrica = predictive(random.PRNGKey(2), cid=1, rugged_std=rugged_seq)["mu"]
# compute mu over samples, fixing cid=0 (Africa)
mu_Africa = predictive(random.PRNGKey(2), cid=0, rugged_std=rugged_seq)["mu"]
# summarize to means and 97% percentile intervals (1.5% .. 98.5%)
ci_q = jnp.array([1.5, 98.5])
mu_NotAfrica_mu = jnp.mean(mu_NotAfrica, 0)
mu_NotAfrica_ci = jnp.percentile(mu_NotAfrica, q=ci_q, axis=0)
mu_Africa_mu = jnp.mean(mu_Africa, 0)
mu_Africa_ci = jnp.percentile(mu_Africa, q=ci_q, axis=0)
def model(cid, rugged_std, log_gdp_std=None):
    """m8.3: intercept AND ruggedness slope both vary by ``cid`` group (interaction)."""
    a = numpyro.sample("a", dist.Normal(1, 0.1).expand([2]))  # per-group intercept
    b = numpyro.sample("b", dist.Normal(0, 0.3).expand([2]))  # per-group slope
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    a_i, b_i = a[cid], b[cid]
    mu = numpyro.deterministic("mu", a_i + b_i * (rugged_std - 0.215))
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


m8_3 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_3,
    optim=optim.Adam(0.1),
    loss=Trace_ELBO(),
    cid=dd.cid.values,
    rugged_std=dd.rugged_std.values,
    log_gdp_std=dd.log_gdp_std.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p8_3 = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1080.73it/s, init loss: 1670.7773, avg. loss [951-1000]: -132.0968]
post = m8_3.sample_posterior(random.PRNGKey(1), p8_3, sample_shape=(1000,))
# Parameters only; drop the per-observation deterministic "mu".
print_summary(
    {name: draws for name, draws in post.items() if name != "mu"},
    prob=0.89,
    group_by_chain=False,
)
                mean       std    median      5.5%     94.5%     n_eff     r_hat
      a[0]      0.89      0.02      0.89      0.86      0.91   1009.20      1.00
      a[1]      1.05      0.01      1.05      1.04      1.07    755.33      1.00
      b[0]      0.13      0.07      0.13      0.01      0.24   1045.06      1.00
      b[1]     -0.15      0.06     -0.14     -0.23     -0.05   1003.36      1.00
     sigma      0.11      0.01      0.11      0.10      0.12    810.01      1.00
def _loglik_idata(guide, params, **data):
    """Wrap a fitted guide's pointwise log-likelihoods as ArviZ InferenceData.

    Draws 1000 posterior samples from ``guide`` with a fixed key, evaluates
    ``log_likelihood`` of ``guide.model`` on ``data`` (keyword arguments passed
    to the model) and adds a leading length-1 axis as the single chain.
    """
    samples = guide.sample_posterior(random.PRNGKey(2), params, sample_shape=(1000,))
    logprob = log_likelihood(guide.model, samples, **data)
    return az.from_dict({}, log_likelihood={k: v[None] for k, v in logprob.items()})


# One InferenceData per model. Building all three through the same helper keeps
# each model's result bound to its own name (m8.2 -> az8_2, m8.3 -> az8_3).
az8_1 = _loglik_idata(
    m8_1, p8_1, rugged_std=dd.rugged_std.values, log_gdp_std=dd.log_gdp_std.values
)
az8_2 = _loglik_idata(
    m8_2,
    p8_2,
    rugged_std=dd.rugged_std.values,
    cid=dd.cid.values,
    log_gdp_std=dd.log_gdp_std.values,
)
az8_3 = _loglik_idata(
    m8_3,
    p8_3,
    rugged_std=dd.rugged_std.values,
    cid=dd.cid.values,
    log_gdp_std=dd.log_gdp_std.values,
)
az.compare({"m8.1": az8_1, "m8.2": az8_2, "m8.3": az8_3}, ic="waic", scale="deviance")
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. See http://arxiv.org/abs/1507.04544 for details
| | rank | waic | p_waic | d_waic | weight | se | dse | warning | waic_scale |
---|---|---|---|---|---|---|---|---|---|
m8.3 | 0 | -259.176 | 5.10348 | 0 | 0.824888 | 13.4328 | 0 | True | deviance |
m8.2 | 1 | -252.36 | 4.15389 | 6.81647 | 0.175111 | 14.6901 | 6.67691 | True | deviance |
m8.1 | 2 | -188.818 | 2.65329 | 70.3582 | 4.90447e-08 | 14.6588 | 15.3423 | False | deviance |
# Pointwise WAIC values (deviance scale) for each observation under m8.3.
waic_list = az.waic(az8_3, scale="deviance", pointwise=True)["waic_i"].values
UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail. See http://arxiv.org/abs/1507.04544 for details
# plot non-Africa - cid=1, with posterior mean and 97% interval from the
# interaction model m8.3
d_A0 = dd[dd["cid"] == 1]
az.plot_pair(d_A0[["rugged_std", "log_gdp_std"]].to_dict(orient="list"))
plt.gca().set(
    xlim=(-0.01, 1.01),
    xlabel="ruggedness (standardized)",
    ylabel="log GDP (as proportion of mean)",
)
# Predict from m8.3 explicitly, on an explicit grid. Do not rely on whatever
# `predictive` / `rugged_seq` were last bound to in earlier cells.
rugged_seq = jnp.linspace(start=-0.1, stop=1.1, num=30)
post8_3 = m8_3.sample_posterior(random.PRNGKey(1), p8_3, sample_shape=(1000,))
post8_3.pop("mu")  # recompute mu on the grid instead of reusing the stored values
mu = Predictive(m8_3.model, post8_3, return_sites=["mu"])(
    random.PRNGKey(2), cid=1, rugged_std=rugged_seq
)["mu"]
mu_mean = jnp.mean(mu, 0)
mu_ci = jnp.percentile(mu, q=jnp.array([1.5, 98.5]), axis=0)  # 97% interval
plt.plot(rugged_seq, mu_mean, "k")
plt.fill_between(rugged_seq, mu_ci[0], mu_ci[1], color="k", alpha=0.2)
plt.title("Non-African nations")
plt.show()
# Difference in expected outcome between cid 0 (Africa) and cid 1 under m8.3,
# evaluated on a ruggedness grid slightly wider than the rescaled data.
rugged_seq = jnp.linspace(start=-0.2, stop=1.2, num=30)
post = m8_3.sample_posterior(random.PRNGKey(1), p8_3, sample_shape=(1000,))
del post["mu"]  # let Predictive recompute mu on the grid
predictive = Predictive(m8_3.model, post, return_sites=["mu"])
muA, muN = (
    predictive(random.PRNGKey(2), cid=group, rugged_std=rugged_seq)["mu"]
    for group in (0, 1)
)
delta = muA - muN
# Load the tulips data (semicolon-separated); `d` and `tulips` name the same frame.
d = tulips = pd.read_csv("../data/tulips.csv", sep=";")
tulips.info()
tulips.head()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 27 entries, 0 to 26 Data columns (total 4 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 bed 27 non-null object 1 water 27 non-null int64 2 shade 27 non-null int64 3 blooms 27 non-null float64 dtypes: float64(1), int64(2), object(1) memory usage: 992.0+ bytes
| | bed | water | shade | blooms |
---|---|---|---|---|
0 | a | 1 | 1 | 0.00 |
1 | a | 1 | 2 | 0.00 |
2 | a | 1 | 3 | 111.04 |
3 | a | 2 | 1 | 183.47 |
4 | a | 2 | 2 | 59.16 |
# Outcome scaled so its maximum is 1; both predictors centred on their means.
d["blooms_std"] = d.blooms.div(d.blooms.max())
d["water_cent"] = d.water.sub(d.water.mean())
d["shade_cent"] = d.shade.sub(d.shade.mean())
# Share of Normal(0.5, 1) intercept draws that fall outside [0, 1].
a = dist.Normal(0.5, 1).sample(random.PRNGKey(0), (10_000,))
jnp.mean((a < 0) | (a > 1))
DeviceArray(0.6182, dtype=float32)
# Same check with the tighter Normal(0.5, 0.25) prior.
a = dist.Normal(0.5, 0.25).sample(random.PRNGKey(0), (10_000,))
jnp.mean((a < 0) | (a > 1))
DeviceArray(0.0471, dtype=float32)
def model(water_cent, shade_cent, blooms_std=None):
    """m8.4: additive regression of blooms_std on centred water and shade.

    ``mu`` is recorded as a deterministic site; with ``blooms_std=None`` the
    outcome is sampled rather than observed.
    """
    a = numpyro.sample("a", dist.Normal(0.5, 0.25))
    bw = numpyro.sample("bw", dist.Normal(0, 0.25))
    bs = numpyro.sample("bs", dist.Normal(0, 0.25))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    linear_pred = a + bw * water_cent + bs * shade_cent
    mu = numpyro.deterministic("mu", linear_pred)
    numpyro.sample("blooms_std", dist.Normal(mu, sigma), obs=blooms_std)


# Laplace approximation via SVI: 1000 Adam steps with learning rate 1.
m8_4 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_4,
    optim=optim.Adam(1),
    loss=Trace_ELBO(),
    water_cent=d.water_cent.values,
    shade_cent=d.shade_cent.values,
    blooms_std=d.blooms_std.values,
)
svi_result = svi.run(random.PRNGKey(0), num_steps=1000)
p8_4 = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1306.10it/s, init loss: 753.9799, avg. loss [951-1000]: -9.9633]
def model(water_cent, shade_cent, blooms_std=None):
    """m8.5: blooms_std regression with a water x shade interaction term.

    Like m8.4, ``mu`` is recorded with ``numpyro.deterministic`` so posterior
    samples contain it and ``Predictive(..., return_sites=["mu"])`` can return
    it. This lets the m8.4 triptych code be reused unchanged for m8.5. With
    ``blooms_std=None`` the outcome is sampled rather than observed.
    """
    a = numpyro.sample("a", dist.Normal(0.5, 0.25))
    bw = numpyro.sample("bw", dist.Normal(0, 0.25))
    bs = numpyro.sample("bs", dist.Normal(0, 0.25))
    bws = numpyro.sample("bws", dist.Normal(0, 0.25))  # interaction slope
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic(
        "mu", a + bw * water_cent + bs * shade_cent + bws * water_cent * shade_cent
    )
    numpyro.sample("blooms_std", dist.Normal(mu, sigma), obs=blooms_std)


m8_5 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_5,
    optim.Adam(1),
    Trace_ELBO(),
    shade_cent=d.shade_cent.values,
    water_cent=d.water_cent.values,
    blooms_std=d.blooms_std.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p8_5 = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1205.86it/s, init loss: 133.8938, avg. loss [951-1000]: -16.3747]
# Triptych: blooms vs. water at each centred shade level (-1, 0, 1), overlaid
# with 20 posterior mean lines. To plot the interaction model, set
# (tri_guide, tri_params) = (m8_5, p8_5); its model must record a
# deterministic "mu" site.
tri_guide, tri_params = m8_4, p8_4
water_seq = jnp.arange(-1, 2)
# The posterior draws do not depend on the shade level, so sample them and
# build the Predictive once, outside the loop.
post = tri_guide.sample_posterior(random.PRNGKey(1), tri_params, sample_shape=(1000,))
post.pop("mu")  # recompute mu on the water grid instead of reusing stored values
predict_mu = Predictive(tri_guide.model, post, return_sites=["mu"])
_, axes = plt.subplots(1, 3, figsize=(9, 3), sharey=True)  # 3 plots in 1 row
for ax, s in zip(axes, range(-1, 2)):
    idx = d.shade_cent == s
    ax.scatter(d.water_cent[idx], d.blooms_std[idx])
    ax.set(xlim=(-1.1, 1.1), ylim=(-0.1, 1.1), xlabel="water", ylabel="blooms")
    mu = predict_mu(random.PRNGKey(2), shade_cent=s, water_cent=water_seq)["mu"]
    for i in range(20):
        ax.plot(water_seq, mu[i], "k", alpha=0.3)
# Draw 1000 samples of m8.5's parameters from their priors (no observed data).
prior_sites = ["a", "bw", "bs", "bws", "sigma"]
predictive = Predictive(m8_5.model, return_sites=prior_sites, num_samples=1000)
prior = predictive(random.PRNGKey(7), shade_cent=0, water_cent=0)
# Load the nettle data; `d` and `nettle` name the same frame.
d = nettle = pd.read_csv("../data/nettle.csv", sep=";")
# num.lang divided by k.pop (presumably population in thousands -- confirm)
nettle["lang.per.cap"] = nettle["num.lang"].div(nettle["k.pop"])