!pip install -q numpyro arviz
import os
import warnings
import arviz as az
import matplotlib.pyplot as plt
import pandas as pd
from scipy.interpolate import BSpline
from scipy.stats import gaussian_kde
import jax.numpy as jnp
from jax import random, vmap
import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import hpdi, print_summary
from numpyro.infer import Predictive, SVI, Trace_ELBO, init_to_value
from numpyro.infer.autoguide import AutoLaplaceApproximation
if "SVG" in os.environ:
%config InlineBackend.figure_formats = ["svg"]
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
category.__name__, message
)
az.style.use("arviz-darkgrid")
numpyro.set_platform("cpu")
# 1000 walkers each take 16 uniform steps in [-1, 1]; keep the summed position.
pos = dist.Uniform(-1, 1).sample(random.PRNGKey(0), (1000, 16)).sum(axis=-1)
# Product of 12 growth factors, each 1 plus a Uniform(0, 0.1) draw.
(1 + dist.Uniform(0, 0.1).sample(random.PRNGKey(0), (12,))).prod()
DeviceArray(1.7294353, dtype=float32)
# Multiply 12 small growth factors for each of 1000 replicates, then overlay a
# Normal density that has the same mean and standard deviation.
growth = (1 + dist.Uniform(0, 0.1).sample(random.PRNGKey(0), (1000, 12))).prod(axis=-1)
az.plot_density({"growth": growth}, hdi_prob=1)
x = jnp.sort(growth)
matched_normal = dist.Normal(jnp.mean(x), jnp.std(x))
plt.plot(x, jnp.exp(matched_normal.log_prob(x)), "--")
plt.show()

# Same products with larger and smaller per-step effects, and on the log scale.
big = (1 + dist.Uniform(0, 0.5).sample(random.PRNGKey(0), (1000, 12))).prod(axis=-1)
small = (1 + dist.Uniform(0, 0.01).sample(random.PRNGKey(0), (1000, 12))).prod(axis=-1)
log_big = jnp.log(
    (1 + dist.Uniform(0, 0.5).sample(random.PRNGKey(0), (1000, 12))).prod(axis=-1)
)
# Grid approximation: binomial likelihood of w successes in n trials times a flat prior.
w, n = 6, 9
p_grid = jnp.linspace(0, 1, 100)
prob_binom = jnp.exp(dist.Binomial(n, p_grid).log_prob(w))
flat_prior = jnp.exp(dist.Uniform(0, 1).log_prob(p_grid))
posterior = prob_binom * flat_prior
# Normalize so the grid values sum to one.
posterior = posterior / posterior.sum()
# Read the semicolon-delimited Howell1 table, then show its structure and first rows.
d = Howell1 = pd.read_csv("../data/Howell1.csv", sep=";")
d.info()
d.head()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 544 entries, 0 to 543 Data columns (total 4 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 height 544 non-null float64 1 weight 544 non-null float64 2 age 544 non-null float64 3 male 544 non-null int64 dtypes: float64(3), int64(1) memory usage: 17.1 KB
height | weight | age | male | |
---|---|---|---|---|
0 | 151.765 | 47.825606 | 63.0 | 1 |
1 | 139.700 | 36.485807 | 63.0 | 0 |
2 | 136.525 | 31.864838 | 65.0 | 0 |
3 | 156.845 | 53.041914 | 41.0 | 1 |
4 | 145.415 | 41.276872 | 51.0 | 0 |
# Summarize every column, passed as a {column name: values} mapping, using a 0.89
# probability mass for the interval columns. The trailing False is print_summary's
# third positional argument (presumably group_by_chain - confirm against numpyro docs).
print_summary(dict(zip(d.columns, d.T.values)), 0.89, False)
mean std median 5.5% 94.5% n_eff r_hat age 29.34 20.75 27.00 0.00 57.00 186.38 1.03 height 138.26 27.60 148.59 90.81 170.18 218.68 1.06 male 0.47 0.50 0.00 0.00 1.00 670.75 1.00 weight 35.61 14.72 40.06 11.37 55.71 305.62 1.05
# Display the height column.
d["height"]
0 151.765 1 139.700 2 136.525 3 156.845 4 145.415 ... 539 145.415 540 162.560 541 156.210 542 71.120 543 158.750 Name: height, Length: 544, dtype: float64
# Keep only rows with age >= 18.
d2 = d.loc[d.age >= 18]

# Density of the Normal(178, 20) prior over a grid of heights.
x = jnp.linspace(100, 250, 101)
mu_prior_density = jnp.exp(dist.Normal(178, 20).log_prob(x))
plt.plot(x, mu_prior_density)
plt.show()

# Density of the Uniform(0, 50) prior; the grid extends beyond [0, 50], and
# validate_args=True is what triggers the out-of-support warning.
x = jnp.linspace(-10, 60, 101)
sigma_prior_density = jnp.exp(dist.Uniform(0, 50, validate_args=True).log_prob(x))
plt.plot(x, sigma_prior_density)
plt.show()
UserWarning: Out-of-support values provided to log prob method. The value argument should be within the support.
# Prior predictive heights: draw mu and sigma from their priors, then one height per pair.
sample_mu = dist.Normal(178, 20).sample(random.PRNGKey(0), (10_000,))
sample_sigma = dist.Uniform(0, 50).sample(random.PRNGKey(1), (10_000,))
prior_h = dist.Normal(sample_mu, sample_sigma).sample(random.PRNGKey(2))
az.plot_kde(prior_h)
plt.show()

# Repeat with a much wider prior on mu (sd 100), reusing the sigma draws above.
sample_mu = dist.Normal(178, 100).sample(random.PRNGKey(0), (10_000,))
prior_h = dist.Normal(sample_mu, sample_sigma).sample(random.PRNGKey(2))
az.plot_kde(prior_h)
plt.show()
# Grid approximation of the (mu, sigma) posterior on a 100 x 100 grid.
mu_list = jnp.linspace(150, 160, 100)
sigma_list = jnp.linspace(7, 9, 100)
mesh = jnp.meshgrid(mu_list, sigma_list)
post = {"mu": mesh[0].reshape(-1), "sigma": mesh[1].reshape(-1)}


def _grid_loglik(mu, sigma):
    """Summed Normal(mu, sigma) log-likelihood of d2's heights at one grid point."""
    return jnp.sum(dist.Normal(mu, sigma).log_prob(d2.height.values))


post["LL"] = vmap(_grid_loglik)(post["mu"], post["sigma"])
# Unnormalized log-posterior = log-likelihood + log-prior of mu + log-prior of sigma.
logprob_mu = dist.Normal(178, 20).log_prob(post["mu"])
logprob_sigma = dist.Uniform(0, 50).log_prob(post["sigma"])
post["prob"] = post["LL"] + logprob_mu + logprob_sigma
# Shift by the maximum before exponentiating so the largest value becomes exactly 1.
post["prob"] = jnp.exp(post["prob"] - jnp.max(post["prob"]))
# Show the relative posterior as contour lines, then as a heat map over the same grid.
grid_shape = (100, 100)
plt.contour(
    post["mu"].reshape(grid_shape),
    post["sigma"].reshape(grid_shape),
    post["prob"].reshape(grid_shape),
)
plt.show()
plt.imshow(
    post["prob"].reshape(grid_shape),
    origin="lower",
    extent=(150, 160, 7, 9),
    aspect="auto",
)
plt.show()
# Draw grid rows in proportion to their posterior probability and read off mu and sigma.
prob = post["prob"] / jnp.sum(post["prob"])
sample_rows = dist.Categorical(probs=prob).sample(random.PRNGKey(0), (10_000,))
sample_mu = post["mu"][sample_rows]
sample_sigma = post["sigma"][sample_rows]
plt.scatter(sample_mu, sample_sigma, s=64, alpha=0.1, edgecolor="none")
plt.show()

# Marginal densities, followed by the 0.89 HPDI of each parameter.
for draws in (sample_mu, sample_sigma):
    az.plot_kde(draws)
    plt.show()
for draws in (sample_mu, sample_sigma):
    print(hpdi(draws, 0.89))
[153.93939 155.15152] [7.3232327 8.252525 ]
# Repeat the grid approximation using only 20 heights. The subsample is drawn with a
# fixed random_state so that rerunning this cell reproduces the same posterior.
d3 = d2.height.sample(n=20, random_state=0)
mu_list = jnp.linspace(start=150, stop=170, num=200)
sigma_list = jnp.linspace(start=4, stop=20, num=200)
mesh = jnp.meshgrid(mu_list, sigma_list)
post2 = {"mu": mesh[0].reshape(-1), "sigma": mesh[1].reshape(-1)}
# Summed Normal log-likelihood of the 20 heights at every (mu, sigma) grid point.
post2["LL"] = vmap(
    lambda mu, sigma: jnp.sum(dist.Normal(mu, sigma).log_prob(d3.values))
)(post2["mu"], post2["sigma"])
# Add the log-priors, then rescale so the largest relative probability is exactly 1.
logprob_mu = dist.Normal(178, 20).log_prob(post2["mu"])
logprob_sigma = dist.Uniform(0, 50).log_prob(post2["sigma"])
post2["prob"] = post2["LL"] + logprob_mu + logprob_sigma
post2["prob"] = jnp.exp(post2["prob"] - jnp.max(post2["prob"]))
# Sample grid rows in proportion to posterior probability.
prob = post2["prob"] / jnp.sum(post2["prob"])
sample2_rows = dist.Categorical(probs=prob).sample(random.PRNGKey(0), (int(1e4),))
sample2_mu = post2["mu"][sample2_rows]
sample2_sigma = post2["sigma"][sample2_rows]
# Joint scatter of the sampled (mu, sigma) pairs.
plt.scatter(sample2_mu, sample2_sigma, s=64, alpha=0.1, edgecolor="none")
plt.show()

# Density of sigma with a Normal of matching mean and sd drawn on top (dashed).
az.plot_kde(sample2_sigma)
x = jnp.sort(sample2_sigma)
matched_normal = dist.Normal(jnp.mean(x), jnp.std(x))
plt.plot(x, jnp.exp(matched_normal.log_prob(x)), "--")
plt.show()
# Reload the data and keep only rows with age >= 18.
d = Howell1 = pd.read_csv("../data/Howell1.csv", sep=";")
d2 = d.loc[d["age"] >= 18]
def flist(height):
    """Gaussian model of height with priors on its mean and standard deviation.

    Sample sites: ``mu`` ~ Normal(178, 20), ``sigma`` ~ Uniform(0, 50), and
    ``height`` ~ Normal(mu, sigma) with ``height`` passed as the observation.
    """
    mu_prior = dist.Normal(178, 20)
    sigma_prior = dist.Uniform(0, 50)
    mu = numpyro.sample("mu", mu_prior)
    sigma = numpyro.sample("sigma", sigma_prior)
    likelihood = dist.Normal(mu, sigma)
    numpyro.sample("height", likelihood, obs=height)
# Fit a Laplace-approximation guide for flist: 2000 SVI steps with Adam.
m4_1 = AutoLaplaceApproximation(flist)
svi = SVI(
    model=flist,
    guide=m4_1,
    optim=optim.Adam(1),
    loss=Trace_ELBO(),
    height=d2.height.values,
)
svi_result = svi.run(random.PRNGKey(0), 2000)
p4_1 = svi_result.params
100%|██████████| 2000/2000 [00:00<00:00, 2242.78it/s, init loss: 4000.1155, avg. loss [1901-2000]: 1226.0389]
# Draw 1000 samples from the fitted guide and print their summary table.
samples = m4_1.sample_posterior(random.PRNGKey(1), p4_1, sample_shape=(1000,))
print_summary(samples, prob=0.89, group_by_chain=False)
mean std median 5.5% 94.5% n_eff r_hat mu 154.60 0.40 154.60 154.00 155.28 995.06 1.00 sigma 7.76 0.30 7.76 7.33 8.26 1007.15 1.00
# Refit, starting the optimizer at the sample mean and sd of the heights.
start = dict(mu=d2.height.mean(), sigma=d2.height.std())
m4_1 = AutoLaplaceApproximation(flist, init_loc_fn=init_to_value(values=start))
svi = SVI(
    model=flist,
    guide=m4_1,
    optim=optim.Adam(0.1),
    loss=Trace_ELBO(),
    height=d2.height.values,
)
svi_result = svi.run(random.PRNGKey(0), 2000)
p4_1 = svi_result.params
100%|██████████| 2000/2000 [00:00<00:00, 2192.68it/s, init loss: 1226.0386, avg. loss [1901-2000]: 1226.0389]
def model(height):
    """Gaussian height model with a very narrow prior on the mean.

    Sample sites: ``mu`` ~ Normal(178, 0.1), ``sigma`` ~ Uniform(0, 50), and
    ``height`` ~ Normal(mu, sigma) with ``height`` passed as the observation.
    """
    mu_prior = dist.Normal(178, 0.1)
    sigma_prior = dist.Uniform(0, 50)
    mu = numpyro.sample("mu", mu_prior)
    sigma = numpyro.sample("sigma", sigma_prior)
    likelihood = dist.Normal(mu, sigma)
    numpyro.sample("height", likelihood, obs=height)
# Fit the narrow-prior model, then summarize 1000 posterior samples.
m4_2 = AutoLaplaceApproximation(model)
svi = SVI(
    model=model,
    guide=m4_2,
    optim=optim.Adam(1),
    loss=Trace_ELBO(),
    height=d2.height.values,
)
svi_result = svi.run(random.PRNGKey(0), 2000)
p4_2 = svi_result.params
samples = m4_2.sample_posterior(random.PRNGKey(1), p4_2, sample_shape=(1000,))
print_summary(samples, prob=0.89, group_by_chain=False)
100%|██████████| 2000/2000 [00:00<00:00, 2169.41it/s, init loss: 1584193.6250, avg. loss [1901-2000]: 1626.5828]
mean std median 5.5% 94.5% n_eff r_hat mu 177.86 0.10 177.86 177.72 178.03 995.05 1.00 sigma 24.57 0.94 24.60 23.01 25.96 1012.88 1.00
# Empirical covariance between the latent sites, estimated from 1000 posterior samples.
samples = m4_1.sample_posterior(random.PRNGKey(1), p4_1, sample_shape=(1000,))
stacked_samples = jnp.stack(list(samples.values()), axis=0)
vcov = jnp.cov(stacked_samples)
vcov
DeviceArray([[0.16249049, 0.00181413], [0.00181413, 0.08733188]], dtype=float32)
# Print the variances (diagonal of vcov), then vcov rescaled to a correlation matrix.
variances = jnp.diagonal(vcov)
print(variances)
print(vcov / jnp.sqrt(jnp.outer(variances, variances)))
[0.16249049 0.08733188] [[1. 0.0152289] [0.0152289 1. ]]
# Draw 10,000 posterior samples and show the first six values of each latent site.
post = m4_1.sample_posterior(random.PRNGKey(1), p4_1, sample_shape=(10_000,))
{latent: list(draws[:6]) for latent, draws in post.items()}
{'mu': [154.24832, 154.48946, 154.98318, 154.21646, 155.49542, 154.83102], 'sigma': [7.55951, 7.3061066, 7.280058, 7.810999, 7.905513, 7.9781823]}
# Summary table of the posterior samples (0.89 interval, no grouping by chain).
print_summary(post, prob=0.89, group_by_chain=False)
mean std median 5.5% 94.5% n_eff r_hat mu 154.61 0.41 154.61 153.94 155.25 9926.98 1.00 sigma 7.75 0.29 7.74 7.28 8.22 9502.46 1.00
# Draw from a multivariate Normal built from the sample mean vector and covariance.
samples_flat = jnp.stack(list(post.values()))
mu = jnp.mean(samples_flat, axis=1)
sigma = jnp.cov(samples_flat)
post = dist.MultivariateNormal(mu, sigma).sample(random.PRNGKey(0), (10_000,))

# Scatter of weight against height for the d2 rows.
az.plot_pair(d2[["weight", "height"]].to_dict(orient="list"))
plt.show()
# Prior predictive simulation: draw 100 intercept/slope pairs from the priors.
with numpyro.handlers.seed(rng_seed=2971):
    N = 100  # 100 lines
    a = numpyro.sample("a", dist.Normal(178, 20).expand([N]))
    b = numpyro.sample("b", dist.Normal(0, 10).expand([N]))

plt.subplot(
    xlim=(d2.weight.min(), d2.weight.max()),
    ylim=(-100, 400),
    xlabel="weight",
    ylabel="height",
)
# Horizontal reference lines at height 0 and height 272.
plt.axhline(y=0, c="k", ls="--")
plt.axhline(y=272, c="k", ls="-", lw=0.5)
plt.title("b ~ Normal(0, 10)")
xbar = d2.weight.mean()
x = jnp.linspace(d2.weight.min(), d2.weight.max(), 101)
# One line per prior draw, centred on the mean weight.
for intercept, slope in zip(a, b):
    plt.plot(x, intercept + slope * (x - xbar), "k", alpha=0.2)
plt.show()
# Density of 10,000 draws from LogNormal(0, 1).
b = dist.LogNormal(0, 1).sample(random.PRNGKey(0), (10_000,))
az.plot_kde(b)
plt.show()
# Prior predictive draws with a log-normal slope. The intercept prior is
# Normal(178, 20), the same prior the earlier Normal-slope simulation and the
# regression model use; the scale of 28 previously used here was inconsistent.
with numpyro.handlers.seed(rng_seed=2971):
    N = 100  # 100 lines
    a = numpyro.sample("a", dist.Normal(178, 20).expand([N]))
    b = numpyro.sample("b", dist.LogNormal(0, 1).expand([N]))
# load data again, since it's a long way back
d = Howell1 = pd.read_csv("../data/Howell1.csv", sep=";")
d2 = d.loc[d["age"] >= 18]

# define the average weight, x-bar
xbar = d2["weight"].mean()
# fit model
def model(weight, height=None):
    """Linear regression of height on mean-centred weight.

    Args:
        weight: predictor values; centred using the module-level ``xbar``.
        height: observed heights passed as ``obs`` to the ``height`` site.
            Defaults to ``None`` so the model can be called for prediction
            without an explicit placeholder argument.

    Sample sites: ``a`` ~ Normal(178, 20), ``b`` ~ LogNormal(0, 1),
    ``sigma`` ~ Uniform(0, 50), and ``height`` ~ Normal(mu, sigma).
    ``mu`` = a + b * (weight - xbar) is recorded as a deterministic site.
    """
    a = numpyro.sample("a", dist.Normal(178, 20))
    b = numpyro.sample("b", dist.LogNormal(0, 1))
    sigma = numpyro.sample("sigma", dist.Uniform(0, 50))
    mu = numpyro.deterministic("mu", a + b * (weight - xbar))
    numpyro.sample("height", dist.Normal(mu, sigma), obs=height)
# Fit the weight regression: Laplace-approximation guide, 2000 SVI steps with Adam.
m4_3 = AutoLaplaceApproximation(model)
fit_data = dict(weight=d2.weight.values, height=d2.height.values)
svi = SVI(model, m4_3, optim.Adam(1), Trace_ELBO(), **fit_data)
svi_result = svi.run(random.PRNGKey(0), 2000)
p4_3 = svi_result.params
100%|██████████| 2000/2000 [00:01<00:00, 1910.96it/s, init loss: 40631.5430, avg. loss [1901-2000]: 1078.9297]
def model(weight, height=None):
    """Linear regression of height on centred weight with the slope on the log scale.

    ``log_b`` ~ Normal(0, 1) is sampled and exponentiated to obtain the slope.
    ``mu`` = a + exp(log_b) * (weight - xbar) is recorded as a deterministic
    site, where ``xbar`` is read from module scope.
    """
    a = numpyro.sample("a", dist.Normal(178, 20))
    log_b = numpyro.sample("log_b", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Uniform(0, 50))
    slope = jnp.exp(log_b)
    centred_weight = weight - xbar
    mu = numpyro.deterministic("mu", a + slope * centred_weight)
    likelihood = dist.Normal(mu, sigma)
    numpyro.sample("height", likelihood, obs=height)
# Fit the log-slope variant with the same optimizer settings and data.
m4_3b = AutoLaplaceApproximation(model)
fit_data = dict(weight=d2.weight.values, height=d2.height.values)
svi = SVI(model, m4_3b, optim.Adam(1), Trace_ELBO(), **fit_data)
svi_result = svi.run(random.PRNGKey(0), 2000)
p4_3b = svi_result.params
100%|██████████| 2000/2000 [00:01<00:00, 1860.03it/s, init loss: 40631.5430, avg. loss [1901-2000]: 1078.9297]
# Summarize the sampled sites of m4_3; the "mu" entry is dropped before printing.
samples = m4_3.sample_posterior(random.PRNGKey(1), p4_3, sample_shape=(1000,))
del samples["mu"]
print_summary(samples, prob=0.89, group_by_chain=False)
mean std median 5.5% 94.5% n_eff r_hat a 154.62 0.27 154.63 154.16 155.03 931.50 1.00 b 0.91 0.04 0.90 0.84 0.97 1083.74 1.00 sigma 5.08 0.19 5.08 4.79 5.41 949.65 1.00
# Covariance between the sampled sites, displayed rounded to three decimals.
stacked_samples = jnp.stack(list(samples.values()), axis=0)
vcov = jnp.cov(stacked_samples)
jnp.round(vcov, 3)
DeviceArray([[0.075, 0. , 0.001], [0. , 0.002, 0. ], [0.001, 0. , 0.038]], dtype=float32)
# Scatter the data, then overlay the line given by the posterior-mean intercept and slope.
az.plot_pair(d2[["weight", "height"]].to_dict(orient="list"))
post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, sample_shape=(1000,))
a_map = post["a"].mean()
b_map = post["b"].mean()
x = jnp.linspace(d2.weight.min(), d2.weight.max(), 101)
fitted_line = a_map + b_map * (x - xbar)
plt.plot(x, fitted_line, "k")
plt.show()
# Show the first five flattened values of every site in the posterior samples.
post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, sample_shape=(1000,))
{latent: list(draws.reshape(-1)[:5]) for latent, draws in post.items()}
{'a': [154.36615, 154.78511, 154.73534, 154.53842, 154.53549], 'b': [0.974645, 0.8900048, 0.81902206, 0.8334104, 1.011918], 'mu': [157.12938, 146.0771, 141.5733, 162.21344, 150.74669], 'sigma': [4.9764595, 4.94353, 5.2826037, 4.877722, 4.89487]}
# Work with only the first N rows of d2.
N = 10
dN = d2.iloc[:N]
def model(weight, height):
    """Linear regression of height on weight, centred on the mean of ``weight`` itself.

    Sample sites: ``a`` ~ Normal(178, 20), ``b`` ~ LogNormal(0, 1),
    ``sigma`` ~ Uniform(0, 50), and ``height`` ~ Normal(mu, sigma) observed
    at ``height``. Here ``mu`` is a plain local value, not a recorded site.
    """
    a = numpyro.sample("a", dist.Normal(178, 20))
    b = numpyro.sample("b", dist.LogNormal(0, 1))
    sigma = numpyro.sample("sigma", dist.Uniform(0, 50))
    centred_weight = weight - jnp.mean(weight)
    mu = a + b * centred_weight
    likelihood = dist.Normal(mu, sigma)
    numpyro.sample("height", likelihood, obs=height)
# Fit the regression on the dN subset with 1000 SVI steps.
mN = AutoLaplaceApproximation(model)
fit_data = dict(weight=dN.weight.values, height=dN.height.values)
svi = SVI(model, mN, optim.Adam(1), Trace_ELBO(), **fit_data)
svi_result = svi.run(random.PRNGKey(0), 1000)
pN = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1231.83it/s, init loss: 1233.4979, avg. loss [951-1000]: 37.0481]
# extract 20 samples from the posterior
post = mN.sample_posterior(random.PRNGKey(1), pN, sample_shape=(20,))

# display raw data and sample size
ax = az.plot_pair(dN[["weight", "height"]].to_dict(orient="list"))
ax.set(
    xlim=(d2.weight.min(), d2.weight.max()),
    ylim=(d2.height.min(), d2.height.max()),
    title="N = {}".format(N),
)

# plot the lines, with transparency
x = jnp.linspace(d2.weight.min(), d2.weight.max(), 101)
xbar_N = dN.weight.mean()  # centre on the mean weight of the subset
for intercept, slope in zip(post["a"], post["b"]):
    plt.plot(x, intercept + slope * (x - xbar_N), "k", alpha=0.3)
# Posterior of mu at weight = 50: density plot, then the 5.5% and 94.5% percentiles.
post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, sample_shape=(1000,))
del post["mu"]
mu_at_50 = post["a"] + post["b"] * (50 - xbar)
az.plot_kde(mu_at_50, label="mu|weight=50")
plt.show()
jnp.percentile(mu_at_50, q=jnp.array([5.5, 94.5]))
DeviceArray([158.5957 , 159.71445], dtype=float32)
# Compute the "mu" site for every posterior sample at each observed weight.
predict_mu = Predictive(m4_3.model, post, return_sites=["mu"])
mu = predict_mu(random.PRNGKey(2), d2.weight.values, d2.height.values)["mu"]
mu.shape, list(mu[:5, 0])
((1000, 352), [157.12938, 157.30838, 157.05736, 156.90125, 157.4044])
# define sequence of weights to compute predictions for
# these values will be on the horizontal axis
weight_seq = jnp.arange(25, 71)

# use predictive to compute mu
# for each sample from posterior
# and for each weight in weight_seq
predict_mu = Predictive(m4_3.model, post, return_sites=["mu"])
mu = predict_mu(random.PRNGKey(2), weight_seq, None)["mu"]
mu.shape, list(mu[:5, 0])
((1000, 46), [134.88252, 136.99348, 138.36269, 137.87814, 134.30676])
weight_height = d2[["weight", "height"]].to_dict(orient="list")

# use scatter_kwargs={"alpha": 0} to hide raw data
az.plot_pair(weight_height, scatter_kwargs={"alpha": 0})

# loop over samples and plot each mu value
for mu_draw in mu[:100]:
    plt.plot(weight_seq, mu_draw, "o", c="royalblue", alpha=0.1)

# summarize the distribution of mu
mu_mean = jnp.mean(mu, 0)
mu_PI = jnp.percentile(mu, q=jnp.array([5.5, 94.5]), axis=0)

# plot raw data
# fading out points to make line and interval more visible
az.plot_pair(weight_height, scatter_kwargs={"alpha": 0.5})

# plot the MAP line, aka the mean mu for each weight
plt.plot(weight_seq, mu_mean, "k")

# plot a shaded region for 89% PI
lower, upper = mu_PI[0], mu_PI[1]
plt.fill_between(weight_seq, lower, upper, color="k", alpha=0.2)
plt.show()
post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, sample_shape=(1000,))


def mu_link(weight):
    """Return a + b * (weight - xbar) evaluated with every posterior draw of a and b."""
    return post["a"] + post["b"] * (weight - xbar)


weight_seq = jnp.arange(25, 71)
# vmap maps over the weights; the transpose puts posterior draws on the leading axis.
mu = vmap(mu_link)(weight_seq).T
mu_mean = jnp.mean(mu, 0)
mu_HPDI = hpdi(mu, prob=0.89, axis=0)
# Drop the stored "mu" entry, then simulate the "height" site at each weight in weight_seq.
del post["mu"]
predict_height = Predictive(m4_3.model, post, return_sites=["height"])
sim_height = predict_height(random.PRNGKey(2), weight_seq, None)["height"]
sim_height.shape, list(sim_height[:5, 0])
((1000, 46), [135.85771, 137.52162, 133.89777, 138.14607, 131.1664])
height_PI = jnp.percentile(sim_height, q=jnp.array([5.5, 94.5]), axis=0)

# plot raw data
weight_height = d2[["weight", "height"]].to_dict(orient="list")
az.plot_pair(weight_height, scatter_kwargs={"alpha": 0.5})

# draw MAP line
plt.plot(weight_seq, mu_mean, "k")

# draw HPDI region for line
mu_lower, mu_upper = mu_HPDI[0], mu_HPDI[1]
plt.fill_between(weight_seq, mu_lower, mu_upper, color="k", alpha=0.2)

# draw PI region for simulated heights
height_lower, height_upper = height_PI[0], height_PI[1]
plt.fill_between(weight_seq, height_lower, height_upper, color="k", alpha=0.15)
plt.show()
# Recompute the simulated-height interval using 10,000 posterior samples.
post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, sample_shape=(10_000,))
del post["mu"]
predict_height = Predictive(m4_3.model, post, return_sites=["height"])
sim_height = predict_height(random.PRNGKey(2), weight_seq, None)["height"]
height_PI = jnp.percentile(sim_height, q=jnp.array([5.5, 94.5]), axis=0)
# Simulate heights directly from the posterior draws, without Predictive.
post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, sample_shape=(1000,))
weight_seq = jnp.arange(25, 71)


def _simulate_at(i, weight):
    """Draw one height per posterior sample at ``weight``, seeded with PRNGKey(i)."""
    location = post["a"] + post["b"] * (weight - xbar)
    return dist.Normal(location, post["sigma"]).sample(random.PRNGKey(i))


weight_index = jnp.arange(len(weight_seq))
sim_height = vmap(_simulate_at)(weight_index, weight_seq).T
height_PI = jnp.percentile(sim_height, q=jnp.array([5.5, 94.5]), axis=0)
# Reload the full Howell1 table and show its structure and first rows.
d = Howell1 = pd.read_csv("../data/Howell1.csv", sep=";")
d.info()
d.head()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 544 entries, 0 to 543 Data columns (total 4 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 height 544 non-null float64 1 weight 544 non-null float64 2 age 544 non-null float64 3 male 544 non-null int64 dtypes: float64(3), int64(1) memory usage: 17.1 KB
height | weight | age | male | |
---|---|---|---|---|
0 | 151.765 | 47.825606 | 63.0 | 1 |
1 | 139.700 | 36.485807 | 63.0 | 0 |
2 | 136.525 | 31.864838 | 65.0 | 0 |
3 | 156.845 | 53.041914 | 41.0 | 1 |
4 | 145.415 | 41.276872 | 51.0 | 0 |
# Standardize weight (subtract the mean, divide by the sd) and add its square.
d["weight_s"] = (d["weight"] - d["weight"].mean()) / d["weight"].std()
d["weight_s2"] = d["weight_s"] ** 2
def model(weight_s, weight_s2, height=None):
    """Quadratic regression of height on two weight predictors.

    Sample sites: ``a`` ~ Normal(178, 20), ``b1`` ~ LogNormal(0, 1),
    ``b2`` ~ Normal(0, 1), ``sigma`` ~ Uniform(0, 50), and
    ``height`` ~ Normal(mu, sigma) with ``height`` as the observation.
    ``mu`` = a + b1 * weight_s + b2 * weight_s2 is recorded as a deterministic site.
    """
    a = numpyro.sample("a", dist.Normal(178, 20))
    b1 = numpyro.sample("b1", dist.LogNormal(0, 1))
    b2 = numpyro.sample("b2", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Uniform(0, 50))
    mean_height = a + b1 * weight_s + b2 * weight_s2
    mu = numpyro.deterministic("mu", mean_height)
    likelihood = dist.Normal(mu, sigma)
    numpyro.sample("height", likelihood, obs=height)
# Fit the quadratic model on all rows of d: 3000 SVI steps, Adam with step size 0.3.
m4_5 = AutoLaplaceApproximation(model)
fit_data = dict(
    weight_s=d.weight_s.values,
    weight_s2=d.weight_s2.values,
    height=d.height.values,
)
svi = SVI(model, m4_5, optim.Adam(0.3), Trace_ELBO(), **fit_data)
svi_result = svi.run(random.PRNGKey(0), 3000)
p4_5 = svi_result.params
100%|██████████| 3000/3000 [00:01<00:00, 2115.79it/s, init loss: 68267.6406, avg. loss [2851-3000]: 1770.2694]
# Summarize every sampled site except the "mu" entry.
samples = m4_5.sample_posterior(random.PRNGKey(1), p4_5, sample_shape=(1000,))
sites_without_mu = {site: draws for site, draws in samples.items() if site != "mu"}
print_summary(sites_without_mu, prob=0.89, group_by_chain=False)
mean std median 5.5% 94.5% n_eff r_hat a 146.05 0.36 146.03 145.47 146.58 1049.96 1.00 b1 21.75 0.30 21.75 21.25 22.18 886.88 1.00 b2 -7.79 0.28 -7.79 -8.21 -7.32 1083.62 1.00 sigma 5.78 0.17 5.78 5.49 6.02 973.22 1.00
# Prediction grid on the standardized weight scale, with its square as second predictor.
weight_seq = jnp.linspace(-2.2, 2, 30)
pred_dat = dict(weight_s=weight_seq, weight_s2=weight_seq**2)

post = m4_5.sample_posterior(random.PRNGKey(1), p4_5, sample_shape=(1000,))
del post["mu"]
predictive = Predictive(m4_5.model, post)
percentile_bounds = jnp.array([5.5, 94.5])

# Mean and percentile interval of mu along the grid.
mu = predictive(random.PRNGKey(2), **pred_dat)["mu"]
mu_mean = jnp.mean(mu, 0)
mu_PI = jnp.percentile(mu, q=percentile_bounds, axis=0)

# Percentile interval of simulated heights along the grid.
sim_height = predictive(random.PRNGKey(3), **pred_dat)["height"]
height_PI = jnp.percentile(sim_height, q=percentile_bounds, axis=0)

# Data, mean curve, and both shaded intervals.
az.plot_pair(
    d[["weight_s", "height"]].to_dict(orient="list"), scatter_kwargs={"alpha": 0.5}
)
plt.plot(weight_seq, mu_mean, "k")
plt.fill_between(weight_seq, mu_PI[0], mu_PI[1], color="k", alpha=0.2)
plt.fill_between(weight_seq, height_PI[0], height_PI[1], color="k", alpha=0.15)
plt.show()
# Add the cube of the standardized weight.
d["weight_s3"] = d["weight_s"] ** 3
def model(weight_s, weight_s2, weight_s3, height):
    """Cubic regression of height on three weight predictors.

    Sample sites: ``a`` ~ Normal(178, 20), ``b1`` ~ LogNormal(0, 1),
    ``b2`` and ``b3`` ~ Normal(0, 1), ``sigma`` ~ Uniform(0, 50), and
    ``height`` ~ Normal(mu, sigma) observed at ``height``. ``mu`` is a plain
    local value here, not a recorded site.
    """
    a = numpyro.sample("a", dist.Normal(178, 20))
    b1 = numpyro.sample("b1", dist.LogNormal(0, 1))
    b2 = numpyro.sample("b2", dist.Normal(0, 1))
    b3 = numpyro.sample("b3", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Uniform(0, 50))
    mu = a + b1 * weight_s + b2 * weight_s2 + b3 * weight_s3
    likelihood = dist.Normal(mu, sigma)
    numpyro.sample("height", likelihood, obs=height)
# Fit the cubic model on all rows of d: 1000 SVI steps, Adam with step size 0.3.
m4_6 = AutoLaplaceApproximation(model)
fit_data = dict(
    weight_s=d.weight_s.values,
    weight_s2=d.weight_s2.values,
    weight_s3=d.weight_s3.values,
    height=d.height.values,
)
svi = SVI(model, m4_6, optim.Adam(0.3), Trace_ELBO(), **fit_data)
svi_result = svi.run(random.PRNGKey(0), 1000)
p4_6 = svi_result.params
100%|██████████| 1000/1000 [00:00<00:00, 1055.90it/s, init loss: 5544.0449, avg. loss [951-1000]: 1668.3276]
# Plot against standardized weight, but label the x ticks on the original weight scale.
ax = az.plot_pair(
    d[["weight_s", "height"]].to_dict(orient="list"), scatter_kwargs={"alpha": 0.5}
)
ax.set(xlabel="weight", ylabel="height", xticks=[])
fig = plt.gcf()
at = jnp.array([-2, -1, 0, 1, 2])
# Undo the standardization: multiply by the sd and add back the mean.
labels = at * d.weight.std() + d.weight.mean()
tick_text = [round(label, 1) for label in labels]
ax.set_xticks(at)
ax.set_xticklabels(tick_text)
fig
# Load the cherry blossom table and print one summary per column, ignoring missing values.
d = cherry_blossoms = pd.read_csv("../data/cherry_blossoms.csv", sep=";")
for column in ("year", "doy", "temp", "temp_upper", "temp_lower"):
    print_summary({column: d[column].dropna().values}, 0.89, False)
mean std median 5.5% 94.5% n_eff r_hat year 1408.00 350.88 1408.00 801.00 1882.00 2.51 2.65 mean std median 5.5% 94.5% n_eff r_hat doy 104.54 6.41 105.00 93.00 113.00 111.98 1.00 mean std median 5.5% 94.5% n_eff r_hat temp 6.14 0.66 6.10 5.03 7.13 22.07 1.02 mean std median 5.5% 94.5% n_eff r_hat temp_upper 7.19 0.99 7.04 5.66 8.54 10.11 1.24 mean std median 5.5% 94.5% n_eff r_hat temp_lower 5.10 0.85 5.14 3.79 6.37 21.90 1.11
d2 = d.dropna(subset=["doy"])  # complete cases on doy
num_knots = 15
# Place the knots at evenly spaced quantiles of year.
knot_probs = jnp.linspace(0, 1, num=num_knots)
knot_list = jnp.quantile(d2.year.values.astype(float), q=knot_probs)
# Repeat each boundary knot three more times (edge padding) for the degree-3 spline.
knots = jnp.pad(knot_list, (3, 3), mode="edge")
# Identity coefficients make each output column one basis function evaluated at year.
basis = BSpline(knots, jnp.identity(num_knots + 2), k=3)
B = basis(d2.year.values)
# Draw every basis function (one column of B) against year.
plt.subplot(
    xlim=(d2.year.min(), d2.year.max()),
    ylim=(0, 1),
    xlabel="year",
    ylabel="basis value",
)
for basis_column in B.T:
    plt.plot(d2.year, basis_column, "k", alpha=0.5)
def model(B, D):
    """Spline regression: intercept plus a weighted sum of the columns of ``B``.

    Sample sites: ``a`` ~ Normal(100, 10), ``w`` ~ Normal(0, 10) expanded to
    ``B.shape[1:]``, ``sigma`` ~ Exponential(1), and ``D`` ~ Normal(mu, sigma)
    observed at ``D``. ``mu`` = a + B @ w is recorded as a deterministic site.
    """
    a = numpyro.sample("a", dist.Normal(100, 10))
    weight_prior = dist.Normal(0, 10).expand(B.shape[1:])
    w = numpyro.sample("w", weight_prior)
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + B @ w)
    likelihood = dist.Normal(mu, sigma)
    numpyro.sample("D", likelihood, obs=D)
# Fit the spline model with all basis weights initialized to zero; 20000 SVI steps.
start = dict(w=jnp.zeros(B.shape[1]))
m4_7 = AutoLaplaceApproximation(model, init_loc_fn=init_to_value(values=start))
svi = SVI(
    model=model,
    guide=m4_7,
    optim=optim.Adam(1),
    loss=Trace_ELBO(),
    B=B,
    D=d2.doy.values,
)
svi_result = svi.run(random.PRNGKey(0), 20000)
p4_7 = svi_result.params
100%|██████████| 20000/20000 [00:05<00:00, 3960.97it/s, init loss: 190557520.0000, avg. loss [19001-20000]: 2704.7454]
# Scale each basis function by the posterior mean of its weight and plot it.
post = m4_7.sample_posterior(random.PRNGKey(1), p4_7, sample_shape=(1000,))
w = jnp.mean(post["w"], 0)
plt.subplot(
    xlim=(d2.year.min(), d2.year.max()),
    ylim=(-6, 6),
    xlabel="year",
    ylabel="basis * weight",
)
for weight_i, basis_column in zip(w, B.T):
    plt.plot(d2.year, weight_i * basis_column, "k", alpha=0.5)
# Interval of mu between the 1.5 and 98.5 percentiles, shaded on top of the raw data.
mu = post["mu"]
mu_PI = jnp.percentile(mu, q=jnp.array([1.5, 98.5]), axis=0)
year_doy = d2[["year", "doy"]].astype(float).to_dict(orient="list")
az.plot_pair(year_doy, scatter_kwargs={"c": "royalblue", "alpha": 0.3, "s": 10})
lower, upper = mu_PI[0], mu_PI[1]
plt.fill_between(d2.year, lower, upper, color="k", alpha=0.5)
plt.show()
def model(B, D):
    """Spline regression with the weighted basis sum written as an explicit reduction.

    Sample sites: ``a`` ~ Normal(100, 10), ``w`` ~ Normal(0, 10) expanded to
    ``B.shape[1:]``, ``sigma`` ~ Exponential(1), and ``D`` ~ Normal(mu, sigma)
    observed at ``D``. ``mu`` = a + sum(B * w) over the last axis is recorded
    as a deterministic site.
    """
    a = numpyro.sample("a", dist.Normal(100, 10))
    weight_prior = dist.Normal(0, 10).expand(B.shape[1:])
    w = numpyro.sample("w", weight_prior)
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    weighted_basis_sum = jnp.sum(B * w, axis=-1)
    mu = numpyro.deterministic("mu", a + weighted_basis_sum)
    likelihood = dist.Normal(mu, sigma)
    numpyro.sample("D", likelihood, obs=D)
# Fit the alternative spline formulation with the same initialization and optimizer.
start = dict(w=jnp.zeros(B.shape[1]))
m4_7alt = AutoLaplaceApproximation(model, init_loc_fn=init_to_value(values=start))
svi = SVI(
    model=model,
    guide=m4_7alt,
    optim=optim.Adam(1),
    loss=Trace_ELBO(),
    B=B,
    D=d2.doy.values,
)
svi_result = svi.run(random.PRNGKey(0), 20000)
p4_7alt = svi_result.params
100%|██████████| 20000/20000 [00:05<00:00, 3976.15it/s, init loss: 190557520.0000, avg. loss [19001-20000]: 2704.7454]