#!/usr/bin/env python
# coding: utf-8

# # Chapter 8. Conditional Manatees

# In[ ]:


get_ipython().system('pip install -q numpyro arviz')


# In[1]:


import math
import os
import warnings

import arviz as az
import matplotlib.pyplot as plt
import pandas as pd

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.infer import Predictive, SVI, Trace_ELBO, log_likelihood
from numpyro.infer.autoguide import AutoLaplaceApproximation

if "SVG" in os.environ:
    get_ipython().run_line_magic('config', 'InlineBackend.figure_formats = ["svg"]')
warnings.formatwarning = lambda message, category, *args, **kwargs: "{}: {}\n".format(
    category.__name__, message
)
az.style.use("arviz-darkgrid")
numpyro.set_platform("cpu")


# ### Code 8.1

# In[2]:


rugged = pd.read_csv("../data/rugged.csv", sep=";")
d = rugged

# make log version of outcome
d["log_gdp"] = d["rgdppc_2000"].apply(math.log)

# extract countries with GDP data
dd = d[d["rgdppc_2000"].notnull()].copy()

# rescale variables
dd["log_gdp_std"] = dd.log_gdp / dd.log_gdp.mean()
dd["rugged_std"] = dd.rugged / dd.rugged.max()


# ### Code 8.2

# In[3]:


def model(rugged_std, log_gdp_std=None):
    a = numpyro.sample("a", dist.Normal(1, 1))
    b = numpyro.sample("b", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + b * (rugged_std - 0.215))
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


m8_1 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_1,
    optim.Adam(0.1),
    Trace_ELBO(),
    rugged_std=dd.rugged_std.values,
    log_gdp_std=dd.log_gdp_std.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p8_1 = svi_result.params


# ### Code 8.3

# In[4]:


predictive = Predictive(m8_1.model, num_samples=1000, return_sites=["a", "b", "sigma"])
prior = predictive(random.PRNGKey(7), rugged_std=0)

# set up the plot dimensions
plt.subplot(xlim=(0, 1), ylim=(0.5, 1.5), xlabel="ruggedness", ylabel="log GDP")
plt.gca().axhline(dd.log_gdp_std.min(), ls="--")
plt.gca().axhline(dd.log_gdp_std.max(), ls="--")

# draw 50 lines from the prior
rugged_seq = jnp.linspace(-0.1, 1.1, num=30)
mu = Predictive(m8_1.model, prior, return_sites=["mu"])(
    random.PRNGKey(7), rugged_std=rugged_seq
)["mu"]
for i in range(50):
    plt.plot(rugged_seq, mu[i], "k", alpha=0.3)


# ### Code 8.4

# In[5]:


jnp.sum(jnp.abs(prior["b"]) > 0.6) / prior["b"].shape[0]


# ### Code 8.5

# In[6]:


def model(rugged_std, log_gdp_std=None):
    a = numpyro.sample("a", dist.Normal(1, 0.1))
    b = numpyro.sample("b", dist.Normal(0, 0.3))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + b * (rugged_std - 0.215))
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


m8_1 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_1,
    optim.Adam(0.1),
    Trace_ELBO(),
    rugged_std=dd.rugged_std.values,
    log_gdp_std=dd.log_gdp_std.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p8_1 = svi_result.params


# ### Code 8.6

# In[7]:


post = m8_1.sample_posterior(random.PRNGKey(1), p8_1, sample_shape=(1000,))
print_summary({k: v for k, v in post.items() if k != "mu"}, 0.89, False)


# ### Code 8.7

# In[8]:


# make variable to index Africa (0) or not (1)
dd["cid"] = jnp.where(dd.cont_africa.values == 1, 0, 1)
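# Not part of the original notebook: a quick sanity check of the index coding
# above -- `cid` should be 0 for African nations and 1 for everything else, so
# the cross-tab with `cont_africa` should only have counts off the diagonal.

# In[ ]:


pd.crosstab(dd["cont_africa"], dd["cid"])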
# ### Code 8.8

# In[9]:


def model(cid, rugged_std, log_gdp_std=None):
    a = numpyro.sample("a", dist.Normal(1, 0.1).expand([2]))
    b = numpyro.sample("b", dist.Normal(0, 0.3))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a[cid] + b * (rugged_std - 0.215))
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


m8_2 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_2,
    optim.Adam(0.1),
    Trace_ELBO(),
    cid=dd.cid.values,
    rugged_std=dd.rugged_std.values,
    log_gdp_std=dd.log_gdp_std.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p8_2 = svi_result.params


# ### Code 8.9

# In[10]:


post = m8_1.sample_posterior(random.PRNGKey(2), p8_1, sample_shape=(1000,))
logprob = log_likelihood(
    m8_1.model, post, rugged_std=dd.rugged_std.values, log_gdp_std=dd.log_gdp_std.values
)
az8_1 = az.from_dict({}, log_likelihood={k: v[None] for k, v in logprob.items()})
post = m8_2.sample_posterior(random.PRNGKey(2), p8_2, sample_shape=(1000,))
logprob = log_likelihood(
    m8_2.model,
    post,
    rugged_std=dd.rugged_std.values,
    cid=dd.cid.values,
    log_gdp_std=dd.log_gdp_std.values,
)
az8_2 = az.from_dict({}, log_likelihood={k: v[None] for k, v in logprob.items()})
az.compare({"m8.1": az8_1, "m8.2": az8_2}, ic="waic", scale="deviance")


# ### Code 8.10

# In[11]:


post = m8_2.sample_posterior(random.PRNGKey(1), p8_2, sample_shape=(1000,))
print_summary({k: v for k, v in post.items() if k != "mu"}, 0.89, False)


# ### Code 8.11

# In[12]:


post = m8_2.sample_posterior(random.PRNGKey(1), p8_2, sample_shape=(1000,))
diff_a1_a2 = post["a"][:, 0] - post["a"][:, 1]
jnp.percentile(diff_a1_a2, q=jnp.array([5.5, 94.5]))


# ### Code 8.12

# In[13]:


rugged_seq = jnp.linspace(start=-0.1, stop=1.1, num=30)
# compute mu over samples, fixing cid=1
post.pop("mu")
predictive = Predictive(m8_2.model, post, return_sites=["mu"])
mu_NotAfrica = predictive(random.PRNGKey(2), cid=1, rugged_std=rugged_seq)["mu"]
# compute mu over samples, fixing cid=0
mu_Africa = predictive(random.PRNGKey(2), cid=0, rugged_std=rugged_seq)["mu"]

# summarize to means and intervals
mu_NotAfrica_mu = jnp.mean(mu_NotAfrica, 0)
mu_NotAfrica_ci = jnp.percentile(mu_NotAfrica, q=jnp.array([1.5, 98.5]), axis=0)
mu_Africa_mu = jnp.mean(mu_Africa, 0)
mu_Africa_ci = jnp.percentile(mu_Africa, q=jnp.array([1.5, 98.5]), axis=0)


# ### Code 8.13

# In[14]:


def model(cid, rugged_std, log_gdp_std=None):
    a = numpyro.sample("a", dist.Normal(1, 0.1).expand([2]))
    b = numpyro.sample("b", dist.Normal(0, 0.3).expand([2]))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a[cid] + b[cid] * (rugged_std - 0.215))
    numpyro.sample("log_gdp_std", dist.Normal(mu, sigma), obs=log_gdp_std)


m8_3 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_3,
    optim.Adam(0.1),
    Trace_ELBO(),
    cid=dd.cid.values,
    rugged_std=dd.rugged_std.values,
    log_gdp_std=dd.log_gdp_std.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p8_3 = svi_result.params


# ### Code 8.14

# In[15]:


post = m8_3.sample_posterior(random.PRNGKey(1), p8_3, sample_shape=(1000,))
print_summary({k: v for k, v in post.items() if k != "mu"}, 0.89, False)
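# Not in the original notebook: mirroring the intercept contrast of Code 8.11,
# the slope contrast b[0] - b[1] summarizes how much the ruggedness slope
# differs between African (cid=0) and non-African (cid=1) nations under m8_3.
# Assumes `post` still holds the m8_3 samples drawn in the previous cell.

# In[ ]:


diff_b1_b2 = post["b"][:, 0] - post["b"][:, 1]
jnp.percentile(diff_b1_b2, q=jnp.array([5.5, 94.5]))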
# ### Code 8.15

# In[16]:


post = m8_1.sample_posterior(random.PRNGKey(2), p8_1, sample_shape=(1000,))
logprob = log_likelihood(
    m8_1.model, post, rugged_std=dd.rugged_std.values, log_gdp_std=dd.log_gdp_std.values
)
az8_1 = az.from_dict({}, log_likelihood={k: v[None] for k, v in logprob.items()})
post = m8_2.sample_posterior(random.PRNGKey(2), p8_2, sample_shape=(1000,))
logprob = log_likelihood(
    m8_2.model,
    post,
    rugged_std=dd.rugged_std.values,
    cid=dd.cid.values,
    log_gdp_std=dd.log_gdp_std.values,
)
az8_2 = az.from_dict({}, log_likelihood={k: v[None] for k, v in logprob.items()})
post = m8_3.sample_posterior(random.PRNGKey(2), p8_3, sample_shape=(1000,))
logprob = log_likelihood(
    m8_3.model,
    post,
    rugged_std=dd.rugged_std.values,
    cid=dd.cid.values,
    log_gdp_std=dd.log_gdp_std.values,
)
az8_3 = az.from_dict({}, log_likelihood={k: v[None] for k, v in logprob.items()})
az.compare({"m8.1": az8_1, "m8.2": az8_2, "m8.3": az8_3}, ic="waic", scale="deviance")


# ### Code 8.16

# In[17]:


waic_list = az.waic(az8_3, pointwise=True, scale="deviance").waic_i.values


# ### Code 8.17

# In[18]:


# plot non-Africa - cid=1
d_A0 = dd[dd["cid"] == 1]
az.plot_pair(d_A0[["rugged_std", "log_gdp_std"]].to_dict(orient="list"))
plt.gca().set(
    xlim=(-0.01, 1.01),
    xlabel="ruggedness (standardized)",
    ylabel="log GDP (as proportion of mean)",
)
mu = predictive(random.PRNGKey(2), cid=1, rugged_std=rugged_seq)["mu"]
mu_mean = jnp.mean(mu, 0)
mu_ci = jnp.percentile(mu, q=jnp.array([1.5, 98.5]), axis=0)
plt.plot(rugged_seq, mu_mean, "k")
plt.fill_between(rugged_seq, mu_ci[0], mu_ci[1], color="k", alpha=0.2)
plt.title("Non-African nations")
plt.show()


# ### Code 8.18

# In[19]:


rugged_seq = jnp.linspace(start=-0.2, stop=1.2, num=30)
post = m8_3.sample_posterior(random.PRNGKey(1), p8_3, sample_shape=(1000,))
post.pop("mu")
predictive = Predictive(m8_3.model, post, return_sites=["mu"])
muA = predictive(random.PRNGKey(2), cid=0, rugged_std=rugged_seq)["mu"]
muN = predictive(random.PRNGKey(2), cid=1, rugged_std=rugged_seq)["mu"]
delta = muA - muN


# ### Code 8.19

# In[20]:


tulips = pd.read_csv("../data/tulips.csv", sep=";")
d = tulips
d.info()
d.head()


# ### Code 8.20

# In[21]:


d["blooms_std"] = d.blooms / d.blooms.max()
d["water_cent"] = d.water - d.water.mean()
d["shade_cent"] = d.shade - d.shade.mean()


# ### Code 8.21

# In[22]:


a = dist.Normal(0.5, 1).sample(random.PRNGKey(0), (int(1e4),))
jnp.sum((a < 0) | (a > 1)) / a.shape[0]


# ### Code 8.22

# In[23]:


a = dist.Normal(0.5, 0.25).sample(random.PRNGKey(0), (int(1e4),))
jnp.sum((a < 0) | (a > 1)) / a.shape[0]


# ### Code 8.23

# In[24]:


def model(water_cent, shade_cent, blooms_std=None):
    a = numpyro.sample("a", dist.Normal(0.5, 0.25))
    bw = numpyro.sample("bw", dist.Normal(0, 0.25))
    bs = numpyro.sample("bs", dist.Normal(0, 0.25))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bw * water_cent + bs * shade_cent)
    numpyro.sample("blooms_std", dist.Normal(mu, sigma), obs=blooms_std)


m8_4 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_4,
    optim.Adam(1),
    Trace_ELBO(),
    shade_cent=d.shade_cent.values,
    water_cent=d.water_cent.values,
    blooms_std=d.blooms_std.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p8_4 = svi_result.params


# ### Code 8.24

# In[25]:


def model(water_cent, shade_cent, blooms_std=None):
    a = numpyro.sample("a", dist.Normal(0.5, 0.25))
    bw = numpyro.sample("bw", dist.Normal(0, 0.25))
    bs = numpyro.sample("bs", dist.Normal(0, 0.25))
    bws = numpyro.sample("bws", dist.Normal(0, 0.25))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + bw * water_cent + bs * shade_cent + bws * water_cent * shade_cent
    numpyro.sample("blooms_std", dist.Normal(mu, sigma), obs=blooms_std)


m8_5 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m8_5,
    optim.Adam(1),
    Trace_ELBO(),
    shade_cent=d.shade_cent.values,
    water_cent=d.water_cent.values,
    blooms_std=d.blooms_std.values,
)
svi_result = svi.run(random.PRNGKey(0), 1000)
p8_5 = svi_result.params
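# Not part of the original notebook: the same WAIC comparison used above for
# the ruggedness models can be applied to the two tulip models, as a rough
# check on how much the interaction term in m8_5 improves expected predictive
# fit. This follows the exact pattern of Codes 8.9 and 8.15.

# In[ ]:


post = m8_4.sample_posterior(random.PRNGKey(1), p8_4, sample_shape=(1000,))
logprob = log_likelihood(
    m8_4.model,
    post,
    water_cent=d.water_cent.values,
    shade_cent=d.shade_cent.values,
    blooms_std=d.blooms_std.values,
)
az8_4 = az.from_dict({}, log_likelihood={k: v[None] for k, v in logprob.items()})
post = m8_5.sample_posterior(random.PRNGKey(1), p8_5, sample_shape=(1000,))
logprob = log_likelihood(
    m8_5.model,
    post,
    water_cent=d.water_cent.values,
    shade_cent=d.shade_cent.values,
    blooms_std=d.blooms_std.values,
)
az8_5 = az.from_dict({}, log_likelihood={k: v[None] for k, v in logprob.items()})
az.compare({"m8.4": az8_4, "m8.5": az8_5}, ic="waic", scale="deviance")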
# ### Code 8.25

# In[26]:


_, axes = plt.subplots(1, 3, figsize=(9, 3), sharey=True)  # 3 plots in 1 row
for ax, s in zip(axes, range(-1, 2)):
    idx = d.shade_cent == s
    ax.scatter(d.water_cent[idx], d.blooms_std[idx])
    ax.set(xlim=(-1.1, 1.1), ylim=(-0.1, 1.1), xlabel="water", ylabel="blooms")
    post = m8_4.sample_posterior(random.PRNGKey(1), p8_4, sample_shape=(1000,))
    post.pop("mu")
    mu = Predictive(m8_4.model, post, return_sites=["mu"])(
        random.PRNGKey(2), shade_cent=s, water_cent=jnp.arange(-1, 2)
    )["mu"]
    for i in range(20):
        ax.plot(range(-1, 2), mu[i], "k", alpha=0.3)


# ### Code 8.26

# In[27]:


predictive = Predictive(
    m8_5.model, num_samples=1000, return_sites=["a", "bw", "bs", "bws", "sigma"]
)
prior = predictive(random.PRNGKey(7), water_cent=0, shade_cent=0)


# ### Code 8.27

# In[28]:


nettle = pd.read_csv("../data/nettle.csv", sep=";")
d = nettle
d["lang.per.cap"] = d["num.lang"] / d["k.pop"]
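# Not part of the original notebook: the chapter's practice problems work with
# language diversity on the log scale, so a typical next step is to log and
# standardize the new column. The column names `log_lang` and `lang_std` below
# are illustrative, not from the source.

# In[ ]:


d["log_lang"] = d["lang.per.cap"].apply(math.log)
d["lang_std"] = (d.log_lang - d.log_lang.mean()) / d.log_lang.std()
d[["lang.per.cap", "log_lang", "lang_std"]].head()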