!pip install -q numpyro arviz
import math
import os
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import jax.numpy as jnp
from jax import ops, random, vmap
from jax.scipy.special import expit
import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import print_summary
from numpyro.distributions import constraints
from numpyro.infer import MCMC, NUTS, init_to_value
# Notebook display setup: request SVG inline figures when the SVG environment
# variable is present (the %config line is an IPython magic, not plain Python).
if "SVG" in os.environ:
%config InlineBackend.figure_formats = ["svg"]
az.style.use("arviz-darkgrid")
# Select the CPU platform and set the host device count to 4; the MCMC runs
# in this file mostly request num_chains=4.
numpyro.set_platform("cpu")
numpyro.set_host_device_count(4)
# simulate a pancake and return randomly ordered sides
def sim_pancake(seed):
    """Simulate one pancake and return its two sides in random order.

    Two PRNG keys are derived from ``seed``: ``2 * seed`` picks one of three
    pancakes (equal logits) and ``2 * seed + 1`` shuffles the two sides.
    Sides are coded 1 or 0; the three pancakes are (1, 1), (1, 0) and (0, 0).
    """
    pick_key = random.PRNGKey(2 * seed)
    shuffle_key = random.PRNGKey(2 * seed + 1)
    which = dist.Categorical(logits=jnp.ones(3)).sample(pick_key)
    # Column j of this 2x3 table holds the two sides of pancake j.
    side_table = jnp.array([1, 1, 1, 0, 0, 0]).reshape(3, 2).T
    return random.permutation(shuffle_key, side_table[:, which])
# sim 10,000 pancakes
# One column per simulated pancake; row 0 is the side facing up, row 1 the
# side facing down.
pancakes = vmap(sim_pancake, out_axes=1)(jnp.arange(10000))
up, down = pancakes[0], pancakes[1]
# among pancakes showing 1 on top (1/1 or 1/0), the share that are 1/1 (BB)
num_11_10 = (up == 1).sum()
num_11 = jnp.logical_and(up == 1, down == 1).sum()
num_11 / num_11_10
DeviceArray(0.6641716, dtype=float32)
# Load the WaffleDivorce data (semicolon-separated); `d` is the working alias.
WaffleDivorce = pd.read_csv("../data/WaffleDivorce.csv", sep=";")
d = WaffleDivorce

# points
ax = az.plot_pair(
    d[["MedianAgeMarriage", "Divorce"]].to_dict(orient="list"),
    scatter_kwargs=dict(s=15, facecolors="none"),
)
ax.set(ylim=(4, 15), xlabel="Median age marriage", ylabel="Divorce rate")

# standard errors: one vertical segment of +/- 1 SE per row. Iterating the
# column values positionally avoids relying on the frame having labels 0..n-1.
for x, divorce, se in zip(d["MedianAgeMarriage"], d["Divorce"], d["Divorce SE"]):
    plt.plot([x, x], [divorce - se, divorce + se], "k")
# Model inputs: outcome and predictors are centred and divided by their own
# standard deviation; the divorce standard error is divided by the standard
# deviation of Divorce so it is on the same scale as D_obs.
dlist = {
    "D_obs": ((d.Divorce - d.Divorce.mean()) / d.Divorce.std()).values,
    "D_sd": d["Divorce SE"].values / d.Divorce.std(),
    "M": ((d.Marriage - d.Marriage.mean()) / d.Marriage.std()).values,
    "A": (
        (d.MedianAgeMarriage - d.MedianAgeMarriage.mean())
        / d.MedianAgeMarriage.std()
    ).values,
    "N": d.shape[0],
}
def model(A, M, D_sd, D_obs, N):
    """Regression of divorce rate with measurement error on the outcome.

    The latent ``D_true`` is Normal around ``a + bA * A + bM * M`` with scale
    ``sigma``; the observed ``D_obs`` is Normal around ``D_true`` with the
    supplied scale ``D_sd``. ``N`` is accepted but not read in the body.
    """
    intercept = numpyro.sample("a", dist.Normal(0, 0.2))
    slope_age = numpyro.sample("bA", dist.Normal(0, 0.5))
    slope_marriage = numpyro.sample("bM", dist.Normal(0, 0.5))
    noise_scale = numpyro.sample("sigma", dist.Exponential(1))
    expected = intercept + slope_age * A + slope_marriage * M
    latent_divorce = numpyro.sample("D_true", dist.Normal(expected, noise_scale))
    numpyro.sample("D_obs", dist.Normal(latent_divorce, D_sd), obs=D_obs)


m15_1 = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
m15_1.run(random.PRNGKey(0), **dlist)
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
# Print the posterior summary of m15_1, passing 0.89 as the interval probability.
m15_1.print_summary(0.89)
mean std median 5.5% 94.5% n_eff r_hat D_true[0] 1.16 0.36 1.17 0.61 1.75 3190.26 1.00 D_true[1] 0.70 0.52 0.69 -0.14 1.50 3050.85 1.00 D_true[2] 0.42 0.32 0.42 -0.05 0.97 4423.54 1.00 D_true[3] 1.42 0.47 1.41 0.62 2.13 4693.23 1.00 D_true[4] -0.90 0.13 -0.90 -1.09 -0.68 4757.64 1.00 D_true[5] 0.65 0.38 0.64 0.04 1.27 3480.80 1.00 D_true[6] -1.36 0.35 -1.36 -1.89 -0.77 3693.32 1.00 D_true[7] -0.32 0.48 -0.32 -1.08 0.45 3808.76 1.00 D_true[8] -1.88 0.60 -1.88 -2.85 -0.93 3263.41 1.00 D_true[9] -0.62 0.17 -0.62 -0.87 -0.34 5084.32 1.00 D_true[10] 0.76 0.28 0.76 0.32 1.19 4423.47 1.00 D_true[11] -0.54 0.49 -0.54 -1.33 0.18 3491.72 1.00 D_true[12] 0.19 0.49 0.20 -0.59 0.99 2026.07 1.00 D_true[13] -0.87 0.22 -0.88 -1.23 -0.55 5782.09 1.00 D_true[14] 0.56 0.31 0.55 0.11 1.12 4043.40 1.00 D_true[15] 0.28 0.37 0.28 -0.34 0.83 4969.48 1.00 D_true[16] 0.49 0.39 0.49 -0.11 1.14 3902.02 1.00 D_true[17] 1.25 0.34 1.25 0.68 1.77 4268.92 1.00 D_true[18] 0.43 0.39 0.42 -0.20 1.04 3958.69 1.00 D_true[19] 0.39 0.53 0.38 -0.42 1.25 2242.96 1.00 D_true[20] -0.55 0.33 -0.56 -1.06 -0.03 4700.87 1.00 D_true[21] -1.11 0.25 -1.10 -1.49 -0.70 3459.65 1.00 D_true[22] -0.27 0.25 -0.27 -0.66 0.12 3367.35 1.00 D_true[23] -1.00 0.30 -1.00 -1.47 -0.54 3468.74 1.00 D_true[24] 0.43 0.41 0.43 -0.23 1.06 4192.29 1.00 D_true[25] -0.03 0.31 -0.02 -0.51 0.49 4341.29 1.00 D_true[26] -0.03 0.49 -0.02 -0.81 0.71 4395.28 1.00 D_true[27] -0.14 0.38 -0.15 -0.77 0.43 4835.72 1.00 D_true[28] -0.25 0.49 -0.27 -0.99 0.56 3569.33 1.00 D_true[29] -1.79 0.24 -1.80 -2.16 -1.42 4291.70 1.00 D_true[30] 0.18 0.40 0.17 -0.43 0.83 4206.43 1.00 D_true[31] -1.66 0.16 -1.66 -1.92 -1.40 5105.78 1.00 D_true[32] 0.12 0.23 0.12 -0.24 0.48 5028.23 1.00 D_true[33] -0.05 0.53 -0.04 -0.87 0.77 3471.07 1.00 D_true[34] -0.12 0.23 -0.12 -0.48 0.23 4786.81 1.00 D_true[35] 1.28 0.40 1.28 0.61 1.89 3415.97 1.00 D_true[36] 0.23 0.34 0.24 -0.34 0.75 4679.23 1.00 D_true[37] -1.03 0.22 -1.02 -1.36 -0.65 4103.00 1.00 D_true[38] -0.91 0.54 
-0.92 -1.84 -0.13 3537.37 1.00 D_true[39] -0.67 0.31 -0.67 -1.16 -0.18 3980.07 1.00 D_true[40] 0.24 0.56 0.24 -0.60 1.16 4535.89 1.00 D_true[41] 0.74 0.33 0.74 0.20 1.23 4946.62 1.00 D_true[42] 0.20 0.18 0.20 -0.09 0.47 4276.96 1.00 D_true[43] 0.81 0.44 0.82 0.12 1.49 2592.55 1.00 D_true[44] -0.42 0.53 -0.43 -1.23 0.44 3480.32 1.00 D_true[45] -0.39 0.25 -0.39 -0.78 0.01 5468.52 1.00 D_true[46] 0.14 0.29 0.14 -0.36 0.58 4116.98 1.00 D_true[47] 0.56 0.46 0.56 -0.13 1.33 3906.03 1.00 D_true[48] -0.63 0.28 -0.64 -1.05 -0.16 4596.77 1.00 D_true[49] 0.86 0.59 0.86 -0.05 1.84 3001.88 1.00 a -0.05 0.09 -0.06 -0.19 0.10 3053.48 1.00 bA -0.61 0.16 -0.62 -0.84 -0.34 2406.74 1.00 bM 0.05 0.17 0.05 -0.23 0.30 2181.48 1.00 sigma 0.58 0.10 0.57 0.41 0.74 832.29 1.00 Number of divergences: 0
# Model inputs for the variant with measurement error on both divorce and
# marriage rate: each standard error is divided by the standard deviation of
# the variable it belongs to.
dlist = {
    "D_obs": ((d.Divorce - d.Divorce.mean()) / d.Divorce.std()).values,
    "D_sd": d["Divorce SE"].values / d.Divorce.std(),
    "M_obs": ((d.Marriage - d.Marriage.mean()) / d.Marriage.std()).values,
    "M_sd": d["Marriage SE"].values / d.Marriage.std(),
    "A": (
        (d.MedianAgeMarriage - d.MedianAgeMarriage.mean())
        / d.MedianAgeMarriage.std()
    ).values,
    "N": d.shape[0],
}
def model(A, M_sd, M_obs, D_sd, D_obs, N):
    """Divorce regression with measurement error on outcome and one predictor.

    ``M_est`` is a length-``N`` latent vector with a Normal(0, 1) prior that
    is observed as ``M_obs`` with scale ``M_sd``. ``D_est`` is Normal around
    ``a + bA * A + bM * M_est`` with scale ``sigma`` and is observed as
    ``D_obs`` with scale ``D_sd``.
    """
    intercept = numpyro.sample("a", dist.Normal(0, 0.2))
    slope_age = numpyro.sample("bA", dist.Normal(0, 0.5))
    slope_marriage = numpyro.sample("bM", dist.Normal(0, 0.5))
    noise_scale = numpyro.sample("sigma", dist.Exponential(1))
    latent_marriage = numpyro.sample("M_est", dist.Normal(0, 1).expand([N]))
    numpyro.sample("M_obs", dist.Normal(latent_marriage, M_sd), obs=M_obs)
    expected = intercept + slope_age * A + slope_marriage * latent_marriage
    latent_divorce = numpyro.sample("D_est", dist.Normal(expected, noise_scale))
    numpyro.sample("D_obs", dist.Normal(latent_divorce, D_sd), obs=D_obs)


m15_2 = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
m15_2.run(random.PRNGKey(0), **dlist)
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
post = m15_2.get_samples()
# posterior means, averaging over the leading (sample) axis
D_est = post["D_est"].mean(axis=0)
M_est = post["M_est"].mean(axis=0)
# observed values in blue, posterior means as open black circles
plt.plot(dlist["M_obs"], dlist["D_obs"], "bo", alpha=0.5)
plt.gca().set(xlabel="marriage rate (std)", ylabel="divorce rate (std)")
plt.plot(M_est, D_est, "ko", mfc="none")
# join every observed point to its posterior-mean counterpart
for m_obs, d_obs, m_hat, d_hat in zip(dlist["M_obs"], dlist["D_obs"], M_est, D_est):
    plt.plot([m_obs, m_hat], [d_obs, d_hat], "k-", lw=1)
# Simulated data, one fixed PRNG key per draw: A comes from a default Normal,
# M is Normal centred at -A, D is Normal centred at A, and A_obs is a second
# Normal draw centred at A.
N = 500
A = dist.Normal().sample(random.PRNGKey(0), (N,))
M = dist.Normal(-A).sample(random.PRNGKey(1))
D = dist.Normal(A).sample(random.PRNGKey(2))
A_obs = dist.Normal(A).sample(random.PRNGKey(3))
# Simulated data: S is drawn from a default Normal and H from a Binomial with
# 10 trials whose second argument is expit(S).
N = 100
S = dist.Normal().sample(random.PRNGKey(0), (N,))
H = dist.Binomial(10, expit(S)).sample(random.PRNGKey(1))
D = dist.Bernoulli(0.5).sample(random.PRNGKey(2), (N,)) # dogs completely random
# Hm is H with NaN wherever D == 1.
Hm = jnp.where(D == 1, jnp.nan, H)
# Second scenario: D is redefined as 1 exactly where S > 0 and Hm is recomputed.
D = jnp.where(S > 0, 1, 0)
Hm = jnp.where(D == 1, jnp.nan, H)
# Seeded simulation: X and S are length-N default-Normal draws and H is
# Binomial with 10 trials and logits 2 + S - 2 * X.
N = 1000
with numpyro.handlers.seed(rng_seed=501):
    X = numpyro.sample("X", dist.Normal().expand([N]))
    S = numpyro.sample("S", dist.Normal().expand([N]))
    H = numpyro.sample("H", dist.Binomial(10, logits=2 + S - 2 * X))
# D is 1 where X > 1; Hm is H with NaN at those positions.
D = jnp.where(X > 1, 1, 0)
Hm = jnp.where(D == 1, jnp.nan, H)

dat_list = {"H": H, "S": S}
def model(S, H):
    """Binomial (10 trials) regression of H on S with logits ``a + bS * S``."""
    intercept = numpyro.sample("a", dist.Normal(0, 1))
    slope = numpyro.sample("bS", dist.Normal(0, 0.5))
    numpyro.sample("H", dist.Binomial(10, logits=intercept + slope * S), obs=H)


m15_3 = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
m15_3.run(random.PRNGKey(0), **dat_list)
m15_3.print_summary()
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
mean std median 5.0% 95.0% n_eff r_hat a 1.32 0.03 1.32 1.28 1.36 1240.12 1.00 bS 0.62 0.03 0.62 0.58 0.67 1298.89 1.00 Number of divergences: 0
# Keep only the rows where D == 0.
dat_list0 = {"H": H[D == 0], "S": S[D == 0]}


def model(S, H):
    """Binomial (10 trials) regression of H on S with logits ``a + bS * S``."""
    intercept = numpyro.sample("a", dist.Normal(0, 1))
    slope = numpyro.sample("bS", dist.Normal(0, 0.5))
    numpyro.sample("H", dist.Binomial(10, logits=intercept + slope * S), obs=H)


m15_4 = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
m15_4.run(random.PRNGKey(0), **dat_list0)
m15_4.print_summary()
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
mean std median 5.0% 95.0% n_eff r_hat a 1.92 0.03 1.92 1.86 1.97 1023.26 1.00 bS 0.73 0.03 0.73 0.67 0.78 989.63 1.00 Number of divergences: 0
# Alternative rule on the existing X: D is 1 where |X| < 1.
D = jnp.where(jnp.abs(X) < 1, 1, 0)
# Fresh simulation: S from a default Normal, H Binomial with 10 trials and
# logits S; D is 1 where H < 5, and Hm is H with NaN at those positions.
N = 100
S = dist.Normal().sample(random.PRNGKey(0), (N,))
H = dist.Binomial(10, logits=S).sample(random.PRNGKey(1))
D = jnp.where(H < 5, 1, 0)
Hm = jnp.where(D == 1, jnp.nan, H)
# Load the milk data and derive the neocortex proportion and log body mass.
milk = pd.read_csv("../data/milk.csv", sep=";")
d = milk
d["neocortex.prop"] = d["neocortex.perc"] / 100
d["logmass"] = d.mass.apply(math.log)
# Each model input is centred and divided by its own standard deviation.
dat_list = {
    "K": ((d["kcal.per.g"] - d["kcal.per.g"].mean()) / d["kcal.per.g"].std()).values,
    "B": (
        (d["neocortex.prop"] - d["neocortex.prop"].mean())
        / d["neocortex.prop"].std()
    ).values,
    "M": ((d.logmass - d.logmass.mean()) / d.logmass.std()).values,
}
def model(B, M, K):
    """Regress K on B and M while imputing the NaN entries of B.

    A ``B_impute`` site with one element per NaN in ``B`` is sampled with a
    masked (``.mask(False)``) Normal(0, 1) placeholder and written into the
    NaN positions. The merged vector is then given a Normal(nu, sigma_B)
    likelihood at the ``"B"`` site and used as a predictor of ``K``.
    """
    a = numpyro.sample("a", dist.Normal(0, 0.5))
    nu = numpyro.sample("nu", dist.Normal(0, 0.5))
    bB = numpyro.sample("bB", dist.Normal(0, 0.5))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma_B = numpyro.sample("sigma_B", dist.Exponential(1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    # positions of the missing entries, located once and reused below
    missing_idx = np.nonzero(np.isnan(B))[0]
    placeholder = dist.Normal(0, 1).expand([int(missing_idx.size)]).mask(False)
    B_impute = numpyro.sample("B_impute", placeholder)
    B_merge = jnp.asarray(B).at[missing_idx].set(B_impute)
    numpyro.sample("B", dist.Normal(nu, sigma_B), obs=B_merge)
    mu = a + bB * B_merge + bM * M
    numpyro.sample("K", dist.Normal(mu, sigma), obs=K)


m15_5 = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
m15_5.run(random.PRNGKey(0), **dat_list)
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
# Print the posterior summary of m15_5, passing 0.89 as the interval probability.
m15_5.print_summary(0.89)
mean std median 5.5% 94.5% n_eff r_hat B_impute[0] -0.50 0.91 -0.54 -1.90 0.95 1918.23 1.00 B_impute[1] -0.64 0.92 -0.65 -2.15 0.74 1424.06 1.00 B_impute[2] -0.68 0.94 -0.71 -2.27 0.78 1338.37 1.00 B_impute[3] -0.26 0.88 -0.27 -1.65 1.09 2203.85 1.00 B_impute[4] 0.43 0.88 0.42 -0.87 1.95 1664.89 1.00 B_impute[5] -0.15 0.90 -0.19 -1.57 1.29 1910.72 1.00 B_impute[6] 0.17 0.92 0.18 -1.30 1.62 2257.71 1.00 B_impute[7] 0.24 0.87 0.25 -0.99 1.80 2024.09 1.00 B_impute[8] 0.49 0.92 0.49 -0.87 2.05 1943.39 1.00 B_impute[9] -0.42 0.93 -0.42 -1.93 0.95 1966.23 1.00 B_impute[10] -0.26 0.92 -0.27 -1.70 1.21 1980.72 1.00 B_impute[11] 0.15 0.95 0.17 -1.36 1.64 1693.70 1.00 a 0.02 0.17 0.03 -0.24 0.29 1930.84 1.00 bB 0.49 0.24 0.50 0.09 0.86 775.56 1.00 bM -0.54 0.21 -0.55 -0.86 -0.20 890.11 1.00 nu -0.04 0.21 -0.04 -0.36 0.29 1704.73 1.00 sigma 0.85 0.14 0.84 0.64 1.07 1021.42 1.00 sigma_B 1.01 0.17 0.99 0.75 1.26 1246.13 1.00 Number of divergences: 0
# Complete-case data: keep only rows whose neocortex proportion is not null.
obs_idx = d["neocortex.prop"].notnull().values
dat_list_obs = {name: dat_list[name][obs_idx] for name in ("K", "B", "M")}


def model(B, M, K):
    """Regress K on B and M; B also gets a Normal(nu, sigma_B) likelihood.

    Same structure as the imputation model but with no imputed values: ``B``
    is used as passed in.
    """
    a = numpyro.sample("a", dist.Normal(0, 0.5))
    nu = numpyro.sample("nu", dist.Normal(0, 0.5))
    bB = numpyro.sample("bB", dist.Normal(0, 0.5))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma_B = numpyro.sample("sigma_B", dist.Exponential(1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    numpyro.sample("B", dist.Normal(nu, sigma_B), obs=B)
    expected = a + bB * B + bM * M
    numpyro.sample("K", dist.Normal(expected, sigma), obs=K)


m15_6 = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
m15_6.run(random.PRNGKey(0), **dat_list_obs)
m15_6.print_summary(0.89)
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
mean std median 5.5% 94.5% n_eff r_hat a 0.10 0.19 0.11 -0.18 0.42 1803.62 1.00 bB 0.60 0.28 0.61 0.17 1.04 1426.34 1.00 bM -0.63 0.25 -0.65 -1.06 -0.28 1259.25 1.00 nu -0.00 0.23 -0.00 -0.34 0.37 2292.07 1.00 sigma 0.87 0.18 0.85 0.59 1.12 1300.49 1.00 sigma_B 1.03 0.18 1.01 0.77 1.31 1956.68 1.00 Number of divergences: 0
# Forest plot of bB and bM for the two fits, chains combined, hdi_prob=0.89.
az.plot_forest(
    [az.from_numpyro(fit) for fit in (m15_5, m15_6)],
    model_names=["m15.5", "m15.6"],
    var_names=["bB", "bM"],
    combined=True,
    hdi_prob=0.89,
)
plt.show()
# Posterior mean and 5.5% / 94.5% percentiles of every imputed B value.
post = m15_5.get_samples()
B_impute_mu = jnp.mean(post["B_impute"], 0)
B_impute_ci = jnp.percentile(post["B_impute"], q=jnp.array([5.5, 94.5]), axis=0)
# positions of the missing B entries in the data
miss_idx = pd.isna(dat_list["B"]).nonzero()[0]

# B vs K
plt.plot(dat_list["B"], dat_list["K"], "o")
plt.gca().set(xlabel="neocortex percent (std)", ylabel="kcal milk (std)")
Ki = dat_list["K"][miss_idx]
plt.plot(B_impute_mu, Ki, "ko", mfc="none")
# One horizontal interval per imputed value. Iterating the columns of
# B_impute_ci instead of a fixed range(12) keeps the plot correct for any
# number of missing entries.
for interval, k_val in zip(B_impute_ci.T, Ki):
    plt.plot(interval, [k_val, k_val], "k", lw=1)
plt.show()

# M vs B
plt.plot(dat_list["M"], dat_list["B"], "o")
plt.gca().set(xlabel="log body mass (std)", ylabel="neocortex percent (std)")
Mi = dat_list["M"][miss_idx]
plt.plot(Mi, B_impute_mu, "ko", mfc="none")
# one vertical interval per imputed value
for interval, m_val in zip(B_impute_ci.T, Mi):
    plt.plot([m_val, m_val], interval, "k", lw=1)
def model(B, M, K):
    """Impute missing B jointly with M, then regress K on both.

    The NaN entries of ``B`` are replaced by a masked ``B_impute`` site. The
    pairs (M, B_merge) get a bivariate Normal likelihood with mean
    (muM, muB) and covariance ``outer(Sigma_BM, Sigma_BM) * Rho_BM``, and
    ``K`` is Normal around ``a + bB * B_merge + bM * M`` with scale ``sigma``.
    """
    # priors
    a = numpyro.sample("a", dist.Normal(0, 0.5))
    muB = numpyro.sample("muB", dist.Normal(0, 0.5))
    muM = numpyro.sample("muM", dist.Normal(0, 0.5))
    bB = numpyro.sample("bB", dist.Normal(0, 0.5))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    Rho_BM = numpyro.sample("Rho_BM", dist.LKJ(2, 2))
    Sigma_BM = numpyro.sample("Sigma_BM", dist.Exponential(1).expand([2]))

    # B_merge mixes the observed values with the imputed ones
    missing_idx = np.nonzero(np.isnan(B))[0]
    placeholder = dist.Normal(0, 1).expand([int(missing_idx.size)]).mask(False)
    B_impute = numpyro.sample("B_impute", placeholder)
    B_merge = jnp.asarray(B).at[missing_idx].set(B_impute)

    # joint likelihood of M and B through their correlation
    pairs = jnp.stack([M, B_merge], axis=1)
    centre = jnp.stack([muM, muB])
    cov = jnp.outer(Sigma_BM, Sigma_BM) * Rho_BM
    numpyro.sample("MB", dist.MultivariateNormal(centre, cov), obs=pairs)

    # K as a function of B and M
    expected = a + bB * B_merge + bM * M
    numpyro.sample("K", dist.Normal(expected, sigma), obs=K)


m15_7 = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
m15_7.run(random.PRNGKey(0), **dat_list)
post = m15_7.get_samples(group_by_chain=True)
# summarise only the slopes and the correlation matrix
print_summary(
    {site: draws for site, draws in post.items() if site in {"bM", "bB", "Rho_BM"}}
)
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
mean std median 5.0% 95.0% n_eff r_hat Rho_BM[0,0] 1.00 0.00 1.00 1.00 1.00 nan nan Rho_BM[0,1] 0.61 0.13 0.62 0.39 0.80 1387.58 1.00 Rho_BM[1,0] 0.61 0.13 0.62 0.39 0.80 1387.58 1.00 Rho_BM[1,1] 1.00 0.00 1.00 1.00 1.00 297.05 1.00 bB 0.58 0.26 0.60 0.15 0.99 766.01 1.01 bM -0.64 0.22 -0.65 -1.01 -0.30 860.13 1.00
# Indices of the missing B entries in dat_list.
B_missidx = pd.isna(dat_list["B"]).nonzero()[0]
# Load the Moralizing_gods data (semicolon-separated) and display it.
Moralizing_gods = pd.read_csv("../data/Moralizing_gods.csv", sep=";")
Moralizing_gods
polity | year | population | moralizing_gods | writing | |
---|---|---|---|---|---|
0 | Big Island Hawaii | 1000 | 3.729643 | NaN | 0 |
1 | Big Island Hawaii | 1100 | 3.729643 | NaN | 0 |
2 | Big Island Hawaii | 1200 | 3.598340 | NaN | 0 |
3 | Big Island Hawaii | 1300 | 4.026240 | NaN | 0 |
4 | Big Island Hawaii | 1400 | 4.311767 | NaN | 0 |
... | ... | ... | ... | ... | ... |
859 | Yemeni Coastal Plain | 1400 | 6.763083 | 1.0 | 1 |
860 | Yemeni Coastal Plain | 1500 | 6.519621 | 1.0 | 1 |
861 | Konya Plain | 1600 | 7.447158 | 1.0 | 1 |
862 | Yemeni Coastal Plain | 1700 | 3.882606 | 1.0 | 1 |
863 | Yemeni Coastal Plain | 1800 | 3.882606 | 1.0 | 1 |
864 rows × 5 columns
# Tabulate the moralizing_gods column, counting NaN as its own category.
Moralizing_gods.moralizing_gods.value_counts(dropna=False)
NaN 528 1.0 319 0.0 17 Name: moralizing_gods, dtype: int64
# Marker per row: "." when moralizing_gods == 1, "x" when it is missing,
# "o" otherwise. Colour per row: black when missing, blue otherwise.
symbol = Moralizing_gods.moralizing_gods.apply(lambda v: "o" if v != 1 else ".")
symbol.loc[Moralizing_gods.moralizing_gods.isna()] = "x"
color = Moralizing_gods.moralizing_gods.apply(lambda v: "b" if not pd.isna(v) else "k")
# one scatter call per marker type, since a single call takes a single marker
for pch in ("o", ".", "x"):
    plt.scatter(
        Moralizing_gods.year[symbol == pch],
        Moralizing_gods.population[symbol == pch],
        marker=pch,
        color=color[symbol == pch],
        facecolor=None if pch != "o" else "none",
        lw=1.5,
        alpha=0.7,
    )
plt.gca().set(xlabel="Time (year)", ylabel="Population size")
plt.show()
dmg = Moralizing_gods
# Cross-tabulate moralizing_gods against writing; casting to str first makes
# NaN a regular group key ("nan") so missing rows are counted too.
dmg.astype(str).groupby(["moralizing_gods", "writing"]).size().unstack(fill_value=0)
writing | 0 | 1 |
---|---|---|
moralizing_gods | ||
0.0 | 16 | 1 |
1.0 | 9 | 310 |
nan | 442 | 86 |
dmg = Moralizing_gods
# Show the selected columns for the "Big Island Hawaii" polity, transposed
# and rounded to 3 decimals.
haw = dmg.polity == "Big Island Hawaii"
dmg.loc[haw, ["year", "population", "writing", "moralizing_gods"]].T.round(3)
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | |
---|---|---|---|---|---|---|---|---|---|
year | 1000.00 | 1100.00 | 1200.000 | 1300.000 | 1400.000 | 1500.000 | 1600.000 | 1700.000 | 1800.000 |
population | 3.73 | 3.73 | 3.598 | 4.026 | 4.312 | 4.205 | 4.374 | 5.158 | 4.997 |
writing | 0.00 | 0.00 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 |
moralizing_gods | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 1.000 |
# Simulate N_houses houses: cat presence (Bernoulli(k)), number of notes
# (Poisson with rate alpha + beta * cat), and R_C (Bernoulli(r)), which flags
# the houses whose cat value is replaced by a sentinel below.
N_houses = 100
alpha = 5
beta = -3
k = 0.5
r = 0.2
with numpyro.handlers.seed(rng_seed=9):
    cat = numpyro.sample("cat", dist.Bernoulli(k).expand([N_houses]))
    notes = numpyro.sample("notes", dist.Poisson(alpha + beta * cat))
    R_C = numpyro.sample("R_C", dist.Bernoulli(r).expand([N_houses]))
# -9 marks a cat value that is treated as unobserved (R_C == 1)
cat_obs = jnp.where(R_C == 1, -9, cat)
# N is the number of simulated houses. It was previously N_houses - 1, which
# did not match the length of the arrays passed alongside it.
dat = dict(notes=notes, cat=np.asarray(cat_obs), RC=np.asarray(R_C), N=N_houses)
def model(N, RC, cat, notes):
    """Poisson model of notes with partially observed cat presence.

    Rows with ``RC == 0`` use the recorded ``cat`` value directly. Rows with
    ``RC == 1`` contribute a factor that averages the Poisson likelihood over
    cat present (weight ``k``) and absent (weight ``1 - k``). ``N`` is
    accepted but not read in the body.
    """
    # priors
    a = numpyro.sample("a", dist.Normal(0, 1))
    b = numpyro.sample("b", dist.Normal(0, 0.5))

    cat_known = RC == 0
    cat_missing = RC == 1

    # sneaking cat model
    k = numpyro.sample("k", dist.Beta(2, 2))
    numpyro.sample("cat|RC==0", dist.Bernoulli(k), obs=cat[cat_known])

    # singing bird model
    # cat NA: mix the two possible rates on the log scale
    notes_missing = notes[cat_missing]
    lp_present = jnp.log(k) + dist.Poisson(jnp.exp(a + b)).log_prob(notes_missing)
    lp_absent = jnp.log(1 - k) + dist.Poisson(jnp.exp(a)).log_prob(notes_missing)
    numpyro.factor("notes|RC==1", jnp.logaddexp(lp_present, lp_absent))

    # cat known present/absent:
    lambda_ = jnp.exp(a + b * cat[cat_known])
    numpyro.sample("notes|RC==0", dist.Poisson(lambda_), obs=notes[cat_known])


m15_8 = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
m15_8.run(random.PRNGKey(0), **dat)
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
def model(N, RC, cat, notes, link=False):
    """Poisson model of notes with partially observed cat presence.

    Rows with ``RC == 0`` use the recorded ``cat`` value. Rows with
    ``RC == 1`` contribute a factor that averages the Poisson likelihood over
    cat present (weight ``k``) and absent (weight ``1 - k``).

    Args:
        N: accepted for interface compatibility; not read in the body.
        RC: indicator array; 1 where the cat value is treated as missing.
        cat: cat indicator; only entries where ``RC == 0`` are used.
        notes: observed counts.
        link: when true, also record ``lpC0``, ``lpC1`` and ``PrC1`` (the
            probability of cat == 1 given the notes) for every row.
    """
    a = numpyro.sample("a", dist.Normal(0, 1))
    b = numpyro.sample("b", dist.Normal(0, 0.5))

    # sneaking cat model
    k = numpyro.sample("k", dist.Beta(2, 2))
    numpyro.sample("cat|RC==0", dist.Bernoulli(k), obs=cat[RC == 0])

    # singing bird model
    custom_logprob = jnp.logaddexp(
        jnp.log(k) + dist.Poisson(jnp.exp(a + b)).log_prob(notes[RC == 1]),
        jnp.log(1 - k) + dist.Poisson(jnp.exp(a)).log_prob(notes[RC == 1]),
    )
    numpyro.factor("notes|RC==1", custom_logprob)
    lambda_ = jnp.exp(a + b * cat[RC == 0])
    numpyro.sample("notes|RC==0", dist.Poisson(lambda_), obs=notes[RC == 0])

    if link:
        lpC0 = numpyro.deterministic(
            "lpC0", jnp.log(1 - k) + dist.Poisson(jnp.exp(a)).log_prob(notes)
        )
        lpC1 = numpyro.deterministic(
            "lpC1", jnp.log(k) + dist.Poisson(jnp.exp(a + b)).log_prob(notes)
        )
        # PrC1 = exp(lpC1) / (exp(lpC1) + exp(lpC0)). Normalising in log space
        # avoids 0 / 0 when both log-probabilities are very negative and
        # their exponentials underflow.
        numpyro.deterministic("PrC1", jnp.exp(lpC1 - jnp.logaddexp(lpC1, lpC0)))


m15_9 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m15_9.run(random.PRNGKey(0), **dat)
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
0%| | 0/1000 [00:00<?, ?it/s]
# Seeded simulation of 10 (x, y) pairs, with y Normal centred at x.
with numpyro.handlers.seed(rng_seed=100):
    x = numpyro.sample("x", dist.Normal().expand([10]))
    y = numpyro.sample("y", dist.Normal(x))
# Append an 11th case whose x is NaN and whose y is 100.
x = jnp.concatenate([x, jnp.array([jnp.nan])])
y = jnp.concatenate([y, jnp.array([100])])
d = {"x": x, "y": y}
# Load the Primates301 data and keep the rows where both brain and body are
# present; each variable is then divided by its own maximum.
Primates301 = pd.read_csv("../data/Primates301.csv", sep=";")
d = Primates301
cc = d.dropna(subset=["brain", "body"]).index
B = d.loc[cc, "brain"]
M = d.loc[cc, "body"]
B = B.values / B.max()
M = M.values / M.max()
# standard errors set to 10% of each rescaled value
Bse, Mse = B * 0.1, M * 0.1
dat_list = {"B": B, "M": M}
def model(M, B):
    """Log-normal regression of B on log(M): location ``a + b * log(M)``."""
    a = numpyro.sample("a", dist.Normal(0, 1))
    b = numpyro.sample("b", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    log_mass = jnp.log(M)
    numpyro.sample("B", dist.LogNormal(a + b * log_mass, sigma), obs=B)


m15H4 = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
)
m15H4.run(random.PRNGKey(0), **dat_list)
sample: 100%|██████████| 1000/1000 [00:03<00:00, 285.81it/s, 3 steps of size 2.54e-01. acc. prob=0.93]
# Initial values keyed by M_true and B_true, taken from the observed data.
# NOTE(review): no model defining sites with those names is visible in this
# section; presumably it is supplied later.
start = dict(M_true=dat_list["M"], B_true=dat_list["B"])
init_strategy = init_to_value(values=start)
# Reload the Primates301 data and count the missing values in every column.
Primates301 = pd.read_csv("../data/Primates301.csv", sep=";")
d = Primates301
d.isna().sum()
name 0 genus 0 species 0 subspecies 267 spp_id 0 genus_id 0 social_learning 98 research_effort 115 brain 117 body 63 group_size 114 gestation 161 weaning 185 longevity 181 sex_maturity 194 maternal_investment 197 dtype: int64
# Keep the rows where body is present (brain may still be NaN) and divide each
# variable by its maximum; the brain maximum skips NaN.
cc = d.dropna(subset=["body"]).index
M = d.body[cc]
M = M.values / max(M)
B = d.brain[cc]
B = B.values / B.max(skipna=True)
# Start every B_impute entry at 0.5.
# NOTE(review): 56 is hard-coded; presumably it is the number of NaN entries
# left in B -- confirm against int(np.isnan(B).sum()).
start = dict(B_impute=jnp.repeat(0.5, 56))
init_strategy = init_to_value(values=start)