# Chapter 8. Markov Chain Monte Carlo¶

In [1]:
import math
import pandas as pd
import seaborn as sns
import torch
from torch.distributions import transforms

import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS

# NOTE: `precis` is required below (Codes 8.3, 8.6, 8.7, 8.13); the original
# import list was truncated and left the parenthesis unclosed.
from rethinking import (LM, MAP, coef, extract_samples, glimmer, precis)


### Code 8.1¶

In [2]:
# Metropolis algorithm on a ring of 10 islands (indexed 0..9), where island i
# has population proportional to i + 1.  The chain's long-run occupancy of
# each island converges to its population share.
num_weeks = int(1e5)
positions = torch.tensor(0).repeat(num_weeks)  # int64 buffer of visited islands
current = 9
for i in range(num_weeks):
    # record current position
    positions[i] = current

    # flip coin to generate proposal: step -1 or +1 with equal probability
    sample = torch.multinomial(torch.ones(2), num_samples=1)
    proposal = current + torch.tensor([-1, 1])[sample].item()
    # now make sure he loops around the archipelago
    if proposal < 0:
        proposal = 9
    if proposal > 9:
        proposal = 0

    # move?  Accept with probability pop(proposal) / pop(current); when the
    # ratio exceeds 1 the comparison below is always True (always accept).
    prob_move = (proposal + 1) / (current + 1)
    current = proposal if torch.rand(1) < prob_move else current


### Code 8.2¶

In [3]:
# Terrain ruggedness vs. GDP data; the file is semicolon-delimited.
rugged = pd.read_csv("../data/rugged.csv", sep=";")
d = rugged
# Log-transform the outcome; math.log(nan) is nan, so rows with missing GDP
# simply carry a missing log_gdp as well.
d["log_gdp"] = d["rgdppc_2000"].map(math.log)
# Keep only countries with an observed year-2000 GDP figure.
dd = d[d["rgdppc_2000"].notna()]


### Code 8.3¶

In [4]:
def model(rugged, cont_africa, log_gdp):
    """Regression of log GDP on ruggedness with a ruggedness-by-Africa
    interaction term; sample sites a, bR, bA, bAR, sigma, and the
    observed site "log_gdp"."""
    # priors
    a = pyro.sample("a", dist.Normal(0, 100))
    bR = pyro.sample("bR", dist.Normal(0, 10))
    bA = pyro.sample("bA", dist.Normal(0, 10))
    bAR = pyro.sample("bAR", dist.Normal(0, 10))
    # linear predictor: main effects plus the interaction
    interaction = bAR * rugged * cont_africa
    mu = a + bR * rugged + bA * cont_africa + interaction
    sigma = pyro.sample("sigma", dist.Uniform(0, 10))
    # likelihood, vectorized over observations
    with pyro.plate("plate"):
        pyro.sample("log_gdp", dist.Normal(mu, sigma), obs=log_gdp)

# Reset to a clean 0..N-1 integer index after dropping missing-GDP rows.
dd.index = range(len(dd))
# Convert the predictors and outcome to float tensors for Pyro.
dd_rugged = torch.tensor(dd["rugged"].to_numpy(), dtype=torch.float)
dd_cont_africa = torch.tensor(dd["cont_africa"].to_numpy(), dtype=torch.float)
dd_log_gdp = torch.tensor(dd["log_gdp"].to_numpy(), dtype=torch.float)
# Fit by quadratic approximation (MAP) and summarize the posterior.
m8_1 = MAP(model).run(dd_rugged, dd_cont_africa, dd_log_gdp)
precis(m8_1)

Out[4]:
Mean StdDev |0.89 0.89|
a 9.23 0.14 9.01 9.46
bR -0.20 0.08 -0.33 -0.09
bA -1.95 0.22 -2.30 -1.59
bAR 0.40 0.13 0.20 0.61
sigma 0.94 0.05 0.85 1.01

### Code 8.4¶

In [5]:
# Keep only the three variables the model actually uses, then inspect dtypes.
dd_trim = dd.loc[:, ["log_gdp", "rugged", "cont_africa"]]
dd_trim.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 170 entries, 0 to 169
Data columns (total 3 columns):
log_gdp        170 non-null float64
rugged         170 non-null float64
cont_africa    170 non-null int64
dtypes: float64(2), int64(1)
memory usage: 4.1 KB

Out[5]:
log_gdp rugged cont_africa
0 7.492609 0.858 1
1 8.216929 3.427 0
2 9.933263 0.769 0
3 9.407032 0.775 0
4 7.792343 2.688 0

### Code 8.5¶

In [6]:
def model(rugged, cont_africa, log_gdp):
    """Interaction model re-declared for the NUTS/MCMC fit.

    Sample sites: a, bR, bA, bAR (Normal priors), sigma (Uniform prior),
    and the observed site "log_gdp".
    """
    a = pyro.sample("a", dist.Normal(0, 100))
    bR = pyro.sample("bR", dist.Normal(0, 10))
    bA = pyro.sample("bA", dist.Normal(0, 10))
    bAR = pyro.sample("bAR", dist.Normal(0, 10))
    # linear predictor: main effects plus ruggedness-within-Africa interaction
    mu = a + bR * rugged + bA * cont_africa + bAR * rugged * cont_africa
    sigma = pyro.sample("sigma", dist.Uniform(0, 10))
    with pyro.plate("plate"):
        # likelihood, vectorized over all observations
        pyro.sample("log_gdp", dist.Normal(mu, sigma), obs=log_gdp)

# Fit the same model with NUTS (Hamiltonian Monte Carlo), 1000 samples.
nuts_kernel = NUTS(model)
m8_1stan = MCMC(nuts_kernel, num_samples=1000).run(dd_rugged, dd_cont_africa, dd_log_gdp)


### Code 8.6¶

In [7]:
# Posterior summary of the NUTS fit; includes n_eff and r_hat diagnostics.
precis(m8_1stan)

Out[7]:
Mean StdDev |0.89 0.89| n_eff r_hat
a 9.21 0.14 8.99 9.44 231.13 1.0
bR -0.20 0.08 -0.33 -0.07 237.89 1.0
bA -1.93 0.23 -2.27 -1.54 239.02 1.0
bAR 0.38 0.13 0.17 0.56 283.33 1.0
sigma 0.95 0.05 0.86 1.03 468.30 1.0

### Code 8.7¶

In [8]:
def model(rugged, cont_africa, log_gdp):
    """Interaction model (log GDP ~ ruggedness * cont_africa), declared
    again for the four-chain MCMC run.

    Sample sites: a, bR, bA, bAR (Normal priors), sigma (Uniform prior),
    and the observed site "log_gdp".
    """
    a = pyro.sample("a", dist.Normal(0, 100))
    bR = pyro.sample("bR", dist.Normal(0, 10))
    bA = pyro.sample("bA", dist.Normal(0, 10))
    bAR = pyro.sample("bAR", dist.Normal(0, 10))
    # linear predictor with interaction term
    mu = a + bR * rugged + bA * cont_africa + bAR * rugged * cont_africa
    sigma = pyro.sample("sigma", dist.Uniform(0, 10))
    with pyro.plate("plate"):
        # likelihood, vectorized over all observations
        pyro.sample("log_gdp", dist.Normal(mu, sigma), obs=log_gdp)

# Run four independent NUTS chains so r_hat can compare between chains.
hmc_kernel = NUTS(model)
m8_1stan_4chains = MCMC(hmc_kernel, num_samples=1000, num_chains=4)
m8_1stan_4chains.run(dd_rugged, dd_cont_africa, dd_log_gdp)
precis(m8_1stan_4chains)

Out[8]:
Mean StdDev |0.89 0.89| n_eff r_hat
a 9.22 0.15 9.01 9.47 1063.48 1.0
bR -0.20 0.08 -0.34 -0.08 1080.07 1.0
bA -1.95 0.24 -2.35 -1.60 1060.95 1.0
bAR 0.39 0.14 0.17 0.61 1179.82 1.0
sigma 0.95 0.05 0.87 1.03 1778.89 1.0

### Code 8.8¶

In [9]:
# Pull the raw posterior draws and peek at the first five of each latent.
post = extract_samples(m8_1stan)
{name: draws[:5] for name, draws in post.items()}

Out[9]:
{'a': tensor([9.2374, 9.3048, 9.1149, 9.0362, 9.3176]),
'bR': tensor([-0.2022, -0.3170, -0.1933, -0.1671, -0.1185]),
'bA': tensor([-2.1425, -2.0199, -1.9134, -1.8863, -2.1787]),
'bAR': tensor([0.4066, 0.4874, 0.3268, 0.4041, 0.3728]),
'sigma': tensor([0.8741, 0.9995, 1.0940, 1.0770, 0.9178])}

### Code 8.9¶

In [10]:
# Pairwise scatter plots of the posterior draws with KDE marginals;
# the trailing ";" suppresses the PairGrid repr.
sns.pairplot(pd.DataFrame(post), diag_kind="kde",
             plot_kws={"edgecolor": "none", "alpha": 0.2});


### Code 8.10¶

In [11]:
# Re-extract the posterior draws and render the KDE pair plot again.
post = extract_samples(m8_1stan)
samples_df = pd.DataFrame(post)
sns.pairplot(samples_df, diag_kind="kde",
             plot_kws={"edgecolor": "none", "alpha": 0.2});


### Code 8.12¶

In [12]:
# WAIC and effective number of parameters (p_waic) for model comparison.
m8_1stan.information_criterion()

Out[12]:
OrderedDict([('waic', tensor(469.6760)), ('p_waic', tensor(5.3495))])

### Code 8.13¶

In [13]:
post = extract_samples(m8_1stan)
precis_df = precis(m8_1stan)
for latent in post:
ax = sns.lineplot(range(1000), post[latent])
ax.set(title="{}  |  n_eff = {}"
.format(latent, int(precis_df.loc[latent, "n_eff"])))
sns.mpl.pyplot.show()