import math
import pandas as pd
import seaborn as sns
import torch
from torch.distributions import transforms
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS
from rethinking import (LM, MAP, coef, extract_samples, glimmer,
link, precis, replicate, sim, vcov)
# Metropolis "island hopping" simulation: a walker moves around a ring of
# islands numbered 0..9, where island k has relative weight k + 1.
num_weeks = int(1e5)
num_islands = 10
current = 9

# Draw every coin flip and every acceptance threshold in two batched calls
# instead of allocating fresh tensors inside each of the num_weeks iterations.
# The draws are i.i.d., so the chain has the same distribution as before.
coin_flips = torch.multinomial(torch.ones(2), num_samples=num_weeks,
                               replacement=True)
steps = (2 * coin_flips - 1).tolist()  # map {0, 1} -> {-1, +1}
thresholds = torch.rand(num_weeks).tolist()

visited = []
for step, threshold in zip(steps, thresholds):
    # record current position
    visited.append(current)
    # propose a neighbouring island, looping around the archipelago
    proposal = (current + step) % num_islands
    # move? (a ratio above 1 is always accepted)
    prob_move = (proposal + 1) / (current + 1)
    if threshold < prob_move:
        current = proposal

positions = torch.tensor(visited, dtype=torch.long)
# Load the semicolon-delimited ruggedness data; `rugged` and `d` are two
# names for the same DataFrame.
rugged = d = pd.read_csv("../data/rugged.csv", sep=";")

# Add the natural log of GDP as a new column.
gdp = d["rgdppc_2000"]
d["log_gdp"] = gdp.apply(math.log)

# Keep only the rows that have a GDP value.
has_gdp = d["rgdppc_2000"].notnull()
dd = d[has_gdp]
def model(rugged, cont_africa, log_gdp):
    """Linear model of log GDP with a rugged x cont_africa interaction.

    Samples the intercept "a", the slopes "bR", "bA", "bAR" and the noise
    scale "sigma" from their priors, then conditions a Normal likelihood on
    the observed ``log_gdp``.
    """
    a = pyro.sample("a", dist.Normal(0, 100))
    bR = pyro.sample("bR", dist.Normal(0, 10))
    bA = pyro.sample("bA", dist.Normal(0, 10))
    bAR = pyro.sample("bAR", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.Uniform(0, 10))

    # Linear predictor, accumulated term by term from left to right.
    mu = a + bR * rugged
    mu = mu + bA * cont_africa
    mu = mu + bAR * rugged * cont_africa

    likelihood = dist.Normal(mu, sigma)
    with pyro.plate("plate"):
        pyro.sample("log_gdp", likelihood, obs=log_gdp)
# Renumber the rows 0..n-1 after the incomplete rows were dropped.
dd.index = range(len(dd))

# One float tensor per model input, in the order the model expects them.
dd_rugged, dd_cont_africa, dd_log_gdp = (
    torch.tensor(dd[column], dtype=torch.float)
    for column in ("rugged", "cont_africa", "log_gdp")
)

# Fit with the MAP helper and display the summary.
map_fit = MAP(model)
m8_1 = map_fit.run(dd_rugged, dd_cont_africa, dd_log_gdp)
precis(m8_1)
 | Mean | StdDev | \|0.89 | 0.89\| |
---|---|---|---|---|
a | 9.23 | 0.14 | 9.01 | 9.46 |
bR | -0.20 | 0.08 | -0.33 | -0.09 |
bA | -1.95 | 0.22 | -2.30 | -1.59 |
bAR | 0.40 | 0.13 | 0.20 | 0.61 |
sigma | 0.94 | 0.05 | 0.85 | 1.01 |
# Keep only the three columns the model uses, then inspect the result.
model_columns = ["log_gdp", "rugged", "cont_africa"]
dd_trim = dd[model_columns]
dd_trim.info()
dd_trim.head()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 170 entries, 0 to 169 Data columns (total 3 columns): log_gdp 170 non-null float64 rugged 170 non-null float64 cont_africa 170 non-null int64 dtypes: float64(2), int64(1) memory usage: 4.1 KB
 | log_gdp | rugged | cont_africa |
---|---|---|---|
0 | 7.492609 | 0.858 | 1 |
1 | 8.216929 | 3.427 | 0 |
2 | 9.933263 | 0.769 | 0 |
3 | 9.407032 | 0.775 | 0 |
4 | 7.792343 | 2.688 | 0 |
def model(rugged, cont_africa, log_gdp):
    """Interaction regression for log GDP, as sampled by NUTS.

    Latent sites: "a" (intercept), "bR", "bA", "bAR" (slopes) and "sigma"
    (noise scale). The observed site is "log_gdp".
    """
    # Priors, drawn in a fixed order.
    a = pyro.sample("a", dist.Normal(0, 100))
    slope_prior = dist.Normal(0, 10)
    bR = pyro.sample("bR", slope_prior)
    bA = pyro.sample("bA", slope_prior)
    bAR = pyro.sample("bAR", slope_prior)

    # Mean of the likelihood.
    mu = a + bR * rugged + bA * cont_africa + bAR * rugged * cont_africa

    sigma = pyro.sample("sigma", dist.Uniform(0, 10))
    with pyro.plate("plate"):
        pyro.sample("log_gdp", dist.Normal(mu, sigma), obs=log_gdp)
# Fit the interaction model with one NUTS chain of 1000 posterior draws.
kernel = NUTS(model)
single_chain = MCMC(kernel, num_samples=1000)
m8_1stan = single_chain.run(dd_rugged, dd_cont_africa, dd_log_gdp)
HBox(children=(IntProgress(value=0, description='Warmup', max=2000, style=ProgressStyle(description_width='ini…
# Display the summary table for the single-chain NUTS fit m8_1stan.
precis(m8_1stan)
 | Mean | StdDev | \|0.89 | 0.89\| | n_eff | r_hat |
---|---|---|---|---|---|---|
a | 9.21 | 0.14 | 8.99 | 9.44 | 231.13 | 1.0 |
bR | -0.20 | 0.08 | -0.33 | -0.07 | 237.89 | 1.0 |
bA | -1.93 | 0.23 | -2.27 | -1.54 | 239.02 | 1.0 |
bAR | 0.38 | 0.13 | 0.17 | 0.56 | 283.33 | 1.0 |
sigma | 0.95 | 0.05 | 0.86 | 1.03 | 468.30 | 1.0 |
def model(rugged, cont_africa, log_gdp):
    """Interaction regression for log GDP, used for the multi-chain run.

    Draws "a", "bR", "bA", "bAR" and "sigma" from their priors and observes
    "log_gdp" under a Normal likelihood centred on the linear predictor.
    """
    a = pyro.sample("a", dist.Normal(0, 100))
    bR = pyro.sample("bR", dist.Normal(0, 10))
    bA = pyro.sample("bA", dist.Normal(0, 10))
    bAR = pyro.sample("bAR", dist.Normal(0, 10))

    # The same left-to-right sum, with each term named.
    rugged_term = bR * rugged
    africa_term = bA * cont_africa
    interaction_term = bAR * rugged * cont_africa
    mu = a + rugged_term + africa_term + interaction_term

    sigma = pyro.sample("sigma", dist.Uniform(0, 10))
    with pyro.plate("plate"):
        pyro.sample("log_gdp", dist.Normal(mu, sigma), obs=log_gdp)
# Same model, four chains of 1000 draws each.
kernel = NUTS(model)
chain_settings = dict(num_samples=1000, num_chains=4)
m8_1stan_4chains = MCMC(kernel, **chain_settings)

fit_inputs = (dd_rugged, dd_cont_africa, dd_log_gdp)
m8_1stan_4chains.run(*fit_inputs)
precis(m8_1stan_4chains)
HBox(children=(IntProgress(value=0, description='Warmup [1]', max=2000, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=0, description='Warmup [2]', max=2000, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=0, description='Warmup [3]', max=2000, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=0, description='Warmup [4]', max=2000, style=ProgressStyle(description_width=…
 | Mean | StdDev | \|0.89 | 0.89\| | n_eff | r_hat |
---|---|---|---|---|---|---|
a | 9.22 | 0.15 | 9.01 | 9.47 | 1063.48 | 1.0 |
bR | -0.20 | 0.08 | -0.34 | -0.08 | 1080.07 | 1.0 |
bA | -1.95 | 0.24 | -2.35 | -1.60 | 1060.95 | 1.0 |
bAR | 0.39 | 0.14 | 0.17 | 0.61 | 1179.82 | 1.0 |
sigma | 0.95 | 0.05 | 0.87 | 1.03 | 1778.89 | 1.0 |
# Pull the posterior draws and preview the first five of each latent.
post = extract_samples(m8_1stan)
{name: draws[:5] for name, draws in post.items()}
{'a': tensor([9.2374, 9.3048, 9.1149, 9.0362, 9.3176]), 'bR': tensor([-0.2022, -0.3170, -0.1933, -0.1671, -0.1185]), 'bA': tensor([-2.1425, -2.0199, -1.9134, -1.8863, -2.1787]), 'bAR': tensor([0.4066, 0.4874, 0.3268, 0.4041, 0.3728]), 'sigma': tensor([0.8741, 0.9995, 1.0940, 1.0770, 0.9178])}
# Pairwise posterior scatter plots with KDEs on the diagonal.
sns.pairplot(
    pd.DataFrame(post),
    diag_kind="kde",
    plot_kws=dict(edgecolor="none", alpha=0.2),
)

# Re-extract the draws and draw the same figure again.
post = extract_samples(m8_1stan)
posterior_frame = pd.DataFrame(post)
sns.pairplot(
    posterior_frame,
    diag_kind="kde",
    plot_kws=dict(edgecolor="none", alpha=0.2),
)

# WAIC and its effective number of parameters for the single-chain fit.
m8_1stan.information_criterion()
OrderedDict([('waic', tensor(469.6760)), ('p_waic', tensor(5.3495))])
# Trace plot of every latent in the single-chain fit, titled with its n_eff.
post = extract_samples(m8_1stan)
precis_df = precis(m8_1stan)
for latent in post:
    draws = post[latent]
    # Pass x/y by keyword (newer seaborn releases reject positional data
    # vectors) and size the x-axis from the draws instead of assuming 1000.
    ax = sns.lineplot(x=range(len(draws)), y=draws)
    ax.set(title="{} | n_eff = {}"
           .format(latent, int(precis_df.loc[latent, "n_eff"])))
    # One figure per latent.
    sns.mpl.pyplot.show()
# Two observations for the flat-prior example.
y = torch.tensor([-1.0, 1.0])


def model(y):
    """Normal likelihood for ``y`` with very wide priors on location and scale.

    Latent sites: "alpha" (location) and "sigma" (scale); observed site "y".
    """
    alpha = pyro.sample("alpha", dist.Normal(0, 1000000))
    sigma = pyro.sample("sigma", dist.HalfCauchy(200000))
    likelihood = dist.Normal(alpha, sigma)
    with pyro.plate("plate"):
        pyro.sample("y", likelihood, obs=y)
# JIT-compiled NUTS, started from a hand-set trace (alpha = 0, sigma = 1).
kernel = NUTS(model, jit_compile=True)
trace = pyro.poutine.trace(model).get_trace(y)
for site, start in (("alpha", 0.), ("sigma", 1.)):
    trace.nodes[site]["value"] = torch.tensor(start)
kernel.initial_trace = trace

# Two short chains: 100 warmup steps, 50 kept draws each.
sampler = MCMC(kernel, num_samples=50, warmup_steps=100, num_chains=2)
m8_2 = sampler.run(y)
precis(m8_2)
def model(y):
    """Normal likelihood for ``y`` with weakly informative priors.

    "alpha" ~ Normal(1, 10) is the location and "sigma" ~ HalfCauchy(1) the
    scale of the Normal likelihood on the observed site "y".
    """
    alpha = pyro.sample("alpha", dist.Normal(1, 10))
    sigma = pyro.sample("sigma", dist.HalfCauchy(1))
    # The location of the likelihood is alpha itself.
    mu = alpha
    with pyro.plate("plate"):
        pyro.sample("y", dist.Normal(mu, sigma), obs=y)
# NUTS started from alpha = 0, sigma = 1.
kernel = NUTS(model)
trace = pyro.poutine.trace(model).get_trace(y)
initial_values = {"alpha": 0., "sigma": 1.}
for site in initial_values:
    trace.nodes[site]["value"] = torch.tensor(initial_values[site])
kernel.initial_trace = trace

# Two chains: 1000 warmup steps, 3000 kept draws each.
sampler = MCMC(kernel, num_samples=3000, warmup_steps=1000, num_chains=2)
m8_3 = sampler.run(y)
precis(m8_3)
# Running mean of 10,000 Cauchy(0, 5) draws.
y = dist.Cauchy(0, 5).sample(torch.Size([int(1e4)]))

# Entry k (0-based) is the mean of the first k + 1 draws. The earlier form
# summed y[:k] but divided by k + 1, so it dropped the newest draw (and its
# first entry was always 0), and it re-summed each prefix in Python.
draw_counts = torch.arange(1, y.size(0) + 1, dtype=y.dtype)
mu = (y.cumsum(0) / draw_counts).tolist()

# NOTE(review): pointplot aggregates its input into point estimates; if a
# trace of the running mean is wanted, a line plot looks more appropriate -
# confirm the intended figure.
sns.pointplot(mu)
# 100 standard-normal observations.
y = torch.randn(100)


def model(y):
    """Normal likelihood whose mean is the sum of two latents with very wide priors.

    Only the sum a1 + a2 enters the likelihood on the observed site "y".
    """
    wide_prior = dist.Normal(0, 1000000)
    a1 = pyro.sample("a1", wide_prior)
    a2 = pyro.sample("a2", wide_prior)
    sigma = pyro.sample("sigma", dist.HalfCauchy(1))
    mu = a1 + a2
    with pyro.plate("plate"):
        pyro.sample("y", dist.Normal(mu, sigma), obs=y)
# JIT-compiled NUTS with the tree depth capped at 7.
kernel = NUTS(model, jit_compile=True, max_tree_depth=7)

# Start every chain from a1 = 0, a2 = 0, sigma = 1.
trace = pyro.poutine.trace(model).get_trace(y)
for site, start in (("a1", 0.), ("a2", 0.), ("sigma", 1.)):
    trace.nodes[site]["value"] = torch.tensor(start)
kernel.initial_trace = trace

sampler = MCMC(kernel, num_samples=3000, warmup_steps=1000, num_chains=2)
m8_4 = sampler.run(y)
precis(m8_4)
def model(y):
    """Normal likelihood whose mean is a1 + a2, each with a Normal(0, 10) prior.

    Latent sites: "a1", "a2" and "sigma" ~ HalfCauchy(1); observed site "y".
    """
    a1 = pyro.sample("a1", dist.Normal(0, 10))
    a2 = pyro.sample("a2", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.HalfCauchy(1))
    likelihood = dist.Normal(a1 + a2, sigma)
    with pyro.plate("plate"):
        pyro.sample("y", likelihood, obs=y)
# JIT-compiled NUTS started from a1 = 0, a2 = 0, sigma = 1.
kernel = NUTS(model, jit_compile=True)
trace = pyro.poutine.trace(model).get_trace(y)
start_at = {"a1": 0., "a2": 0., "sigma": 1.}
for site, start in start_at.items():
    trace.nodes[site]["value"] = torch.tensor(start)
kernel.initial_trace = trace

# Two chains: 1000 warmup steps, 3000 kept draws each.
run_settings = dict(num_samples=3000, warmup_steps=1000, num_chains=2)
m8_4 = MCMC(kernel, **run_settings).run(y)
precis(m8_4)
def model(y):
    """Draw "a" from Normal(0, 1) and "b" from HalfCauchy(1); ``y`` is not used."""
    prior_a = dist.Normal(0, 1)
    pyro.sample("a", prior_a)
    prior_b = dist.HalfCauchy(1)
    pyro.sample("b", prior_b)
# Sample the two priors with NUTS; y is passed to the model but not observed.
y = torch.tensor(1.)
kernel = NUTS(model)
trace = pyro.poutine.trace(model).get_trace(y)
trace.nodes["a"]["value"] = torch.tensor(0.)
# Start "b" strictly inside the HalfCauchy support. A start of exactly 0 sits
# on the boundary, where the log transform used to unconstrain a positive
# site evaluates to -inf, so the sampler would begin from a non-finite point.
trace.nodes["b"]["value"] = torch.tensor(1.)
kernel.initial_trace = trace
# One chain: 100 warmup steps, 1e4 kept draws.
mp = MCMC(kernel, num_samples=int(1e4), warmup_steps=100).run(y)
precis(mp)
# Simulate a height and two nearly identical leg lengths per individual.
N = 100  # number of individuals
leg_noise_sd = 0.02  # sd of the error added to each leg

# sim total height of each
height = torch.empty(N).normal_(10, 2)
# leg as proportion of height
leg_prop = torch.empty(N).uniform_(0.4, 0.5)

# sim left leg as proportion + error
leg_left = leg_prop * height + torch.empty(N).normal_(0, leg_noise_sd)
# sim right leg as proportion + error
leg_right = leg_prop * height + torch.empty(N).normal_(0, leg_noise_sd)

# combine into data dict
d = dict(height=height, leg_left=leg_left, leg_right=leg_right)
def model(leg_left, leg_right, height):
    """Regress height on both leg lengths.

    Latent sites: intercept "a", slopes "bl" and "br", noise scale "sigma".
    The observed site is "height".
    """
    a = pyro.sample("a", dist.Normal(10, 100))
    slope_prior = dist.Normal(2, 10)
    bl = pyro.sample("bl", slope_prior)
    br = pyro.sample("br", slope_prior)
    sigma = pyro.sample("sigma", dist.HalfCauchy(1))

    mu = a + bl * leg_left + br * leg_right
    likelihood = dist.Normal(mu, sigma)
    with pyro.plate("plate"):
        pyro.sample("height", likelihood, obs=height)
# JIT-compiled NUTS for the two-leg model; the data dict supplies the
# model arguments by keyword.
kernel = NUTS(model, jit_compile=True)
trace = pyro.poutine.trace(model).get_trace(**d)

# Start from a = 10, bl = 0, br = 0, sigma = 1.
for site, start in (("a", 10.), ("bl", 0.), ("br", 0.), ("sigma", 1.)):
    trace.nodes[site]["value"] = torch.tensor(start)
kernel.initial_trace = trace

# Four chains of 1000 draws each.
sampler = MCMC(kernel, num_samples=1000, num_chains=4)
m5_8s = sampler.run(**d)
def model(leg_left, leg_right, height):
    """Regress height on both leg lengths, with an exp-transformed prior on "br".

    "br" is drawn from Normal(2, 10) pushed through ExpTransform; "a", "bl"
    and "sigma" keep Normal / HalfCauchy priors. The observed site is "height".
    """
    a = pyro.sample("a", dist.Normal(10, 100))
    bl = pyro.sample("bl", dist.Normal(2, 10))

    br_prior = dist.TransformedDistribution(
        dist.Normal(2, 10), transforms.ExpTransform())
    br = pyro.sample("br", br_prior)

    sigma = pyro.sample("sigma", dist.HalfCauchy(1))
    mu = a + bl * leg_left + br * leg_right
    with pyro.plate("plate"):
        pyro.sample("height", dist.Normal(mu, sigma), obs=height)
# JIT-compiled NUTS for the model with the exp-transformed "br" prior.
kernel = NUTS(model, jit_compile=True)
trace = pyro.poutine.trace(model).get_trace(**d)

# Start from a = 10, bl = 0, sigma = 1; "br" starts at exp(0).
exp_transform = transforms.ExpTransform()
trace.nodes["a"]["value"] = torch.tensor(10.)
trace.nodes["bl"]["value"] = torch.tensor(0.)
trace.nodes["br"]["value"] = exp_transform(torch.tensor(0.))
trace.nodes["sigma"]["value"] = torch.tensor(1.)
kernel.initial_trace = trace

# Four chains of 1000 draws each.
sampler = MCMC(kernel, num_samples=1000, num_chains=4)
m5_8s2 = sampler.run(**d)