Chapter 8. Markov Chain Monte Carlo

In [1]:
import math
import pandas as pd
import seaborn as sns
import torch
from torch.distributions import transforms

import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS

from rethinking import (LM, MAP, coef, extract_samples, glimmer,
                        link, precis, replicate, sim, vcov)

Code 8.1

In [2]:
def simulate_island_hopping(num_weeks, num_islands=10, start=None):
    """Simulate King Markov's Metropolis walk around a ring of islands.

    Island ``k`` (0-based) has relative population ``k + 1``. Each week the
    king flips a coin to propose the neighbouring island on the left or
    right (wrapping around the ring). He moves there with probability
    ``min(1, population(proposal) / population(current))``.

    Args:
        num_weeks: number of weeks (chain steps) to record.
        num_islands: number of islands on the ring.
        start: starting island; defaults to the last island, ``num_islands - 1``.

    Returns:
        A 1-D ``torch.long`` tensor of length ``num_weeks``. Entry ``i`` is
        the island occupied at the start of week ``i``.
    """
    current = num_islands - 1 if start is None else start
    # Pre-draw every coin flip (-1 or +1) and every acceptance uniform in two
    # vectorized calls instead of three tiny tensor ops per week.
    steps = (torch.randint(0, 2, (num_weeks,)) * 2 - 1).tolist()
    uniforms = torch.rand(num_weeks).tolist()
    history = []
    for step, u in zip(steps, uniforms):
        # record current position
        history.append(current)
        # The modulo wraps the ring: -1 -> num_islands - 1, num_islands -> 0.
        proposal = (current + step) % num_islands
        # Metropolis rule: a ratio >= 1 always moves, otherwise move with
        # probability population(proposal) / population(current).
        if u < (proposal + 1) / (current + 1):
            current = proposal
    return torch.tensor(history, dtype=torch.long)


num_weeks = int(1e5)
positions = simulate_island_hopping(num_weeks)

Code 8.2

In [3]:
rugged = pd.read_csv("../data/rugged.csv", sep=";")
# Add log GDP on a copy so the raw `rugged` frame stays untouched; the old
# `d = rugged` only aliased it, so the new column leaked into `rugged` too.
d = rugged.assign(log_gdp=rugged["rgdppc_2000"].apply(math.log))
# Keep only rows with a recorded GDP (the R `complete.cases` subset) and
# renumber them 0..n-1 so `dd` is an independent frame with a clean index.
dd = d[d["rgdppc_2000"].notnull()].reset_index(drop=True)

Code 8.3

In [4]:
def model(rugged, cont_africa, log_gdp):
    """Regress log GDP on ruggedness, ``cont_africa`` and their interaction.

    Priors: ``a ~ Normal(0, 100)``; ``bR, bA, bAR ~ Normal(0, 10)``;
    ``sigma ~ Uniform(0, 10)``.
    Likelihood:
    ``log_gdp ~ Normal(a + bR*rugged + bA*cont_africa + bAR*rugged*cont_africa, sigma)``.

    Args:
        rugged: tensor of ruggedness values, one per observation.
        cont_africa: tensor of the same length holding the ``cont_africa``
            column (0/1 in the data shown in this notebook).
        log_gdp: tensor of observed log GDP, or ``None`` to leave the
            ``log_gdp`` site unobserved.
    """
    a = pyro.sample("a", dist.Normal(0, 100))
    bR = pyro.sample("bR", dist.Normal(0, 10))
    bA = pyro.sample("bA", dist.Normal(0, 10))
    bAR = pyro.sample("bAR", dist.Normal(0, 10))
    mu = a + bR * rugged + bA * cont_africa + bAR * rugged * cont_africa
    sigma = pyro.sample("sigma", dist.Uniform(0, 10))
    # Size the plate from the observations when they are given, so Pyro can
    # check the likelihood's batch shape against the data length; without
    # observations the plate stays unsized, exactly as before.
    num_obs = None if log_gdp is None else log_gdp.shape[0]
    with pyro.plate("plate", num_obs):
        pyro.sample("log_gdp", dist.Normal(mu, sigma), obs=log_gdp)

# Renumber rows 0..n-1 without mutating the frame in place (a no-op when
# `dd` already has a clean RangeIndex).
dd = dd.reset_index(drop=True)
# `.values` hands torch a plain NumPy array, so tensor construction never
# depends on how the Series is indexed.
dd_rugged = torch.tensor(dd["rugged"].values, dtype=torch.float)
dd_cont_africa = torch.tensor(dd["cont_africa"].values, dtype=torch.float)
dd_log_gdp = torch.tensor(dd["log_gdp"].values, dtype=torch.float)
# MAP fit of `model` on the full data, summarised with precis.
m8_1 = MAP(model).run(dd_rugged, dd_cont_africa, dd_log_gdp)
precis(m8_1)
Out[4]:
Mean StdDev |0.89 0.89|
a 9.23 0.14 9.01 9.46
bR -0.20 0.08 -0.33 -0.09
bA -1.95 0.22 -2.30 -1.59
bAR 0.40 0.13 0.20 0.61
sigma 0.94 0.05 0.85 1.01

Code 8.4

In [5]:
trim_columns = ["log_gdp", "rugged", "cont_africa"]
# Only the three variables the model uses.
dd_trim = dd.loc[:, trim_columns]
# Column dtypes / non-null counts, then the first few rows.
dd_trim.info()
dd_trim.head()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 170 entries, 0 to 169
Data columns (total 3 columns):
log_gdp        170 non-null float64
rugged         170 non-null float64
cont_africa    170 non-null int64
dtypes: float64(2), int64(1)
memory usage: 4.1 KB
Out[5]:
log_gdp rugged cont_africa
0 7.492609 0.858 1
1 8.216929 3.427 0
2 9.933263 0.769 0
3 9.407032 0.775 0
4 7.792343 2.688 0

Code 8.5

In [6]:
def model(rugged, cont_africa, log_gdp):
    """Regress log GDP on ruggedness, ``cont_africa`` and their interaction.

    Priors: ``a ~ Normal(0, 100)``; ``bR, bA, bAR ~ Normal(0, 10)``;
    ``sigma ~ Uniform(0, 10)``.
    Likelihood:
    ``log_gdp ~ Normal(a + bR*rugged + bA*cont_africa + bAR*rugged*cont_africa, sigma)``.

    Args:
        rugged: tensor of ruggedness values, one per observation.
        cont_africa: tensor of the same length holding the ``cont_africa``
            column (0/1 in the data shown in this notebook).
        log_gdp: tensor of observed log GDP, or ``None`` to leave the
            ``log_gdp`` site unobserved.
    """
    a = pyro.sample("a", dist.Normal(0, 100))
    bR = pyro.sample("bR", dist.Normal(0, 10))
    bA = pyro.sample("bA", dist.Normal(0, 10))
    bAR = pyro.sample("bAR", dist.Normal(0, 10))
    mu = a + bR * rugged + bA * cont_africa + bAR * rugged * cont_africa
    sigma = pyro.sample("sigma", dist.Uniform(0, 10))
    # Size the plate from the observations when they are given, so Pyro can
    # check the likelihood's batch shape against the data length; without
    # observations the plate stays unsized, exactly as before.
    num_obs = None if log_gdp is None else log_gdp.shape[0]
    with pyro.plate("plate", num_obs):
        pyro.sample("log_gdp", dist.Normal(mu, sigma), obs=log_gdp)

# NUTS over the same model, keeping 1000 draws.
kernel = NUTS(model)
m8_1stan = MCMC(kernel, num_samples=1000)
# Call `run` as its own statement (as the 4-chain cell does) instead of
# binding its return value: newer Pyro releases return None from `run`.
m8_1stan.run(dd_rugged, dd_cont_africa, dd_log_gdp)

Code 8.6

In [7]:
# Summarise the NUTS draws: mean, std. dev. and 89% interval per parameter,
# plus the n_eff and r_hat sampling diagnostics.
precis(m8_1stan)
Out[7]:
Mean StdDev |0.89 0.89| n_eff r_hat
a 9.21 0.14 8.99 9.44 231.13 1.0
bR -0.20 0.08 -0.33 -0.07 237.89 1.0
bA -1.93 0.23 -2.27 -1.54 239.02 1.0
bAR 0.38 0.13 0.17 0.56 283.33 1.0
sigma 0.95 0.05 0.86 1.03 468.30 1.0

Code 8.7

In [8]:
def model(rugged, cont_africa, log_gdp):
    """Regress log GDP on ruggedness, ``cont_africa`` and their interaction.

    Priors: ``a ~ Normal(0, 100)``; ``bR, bA, bAR ~ Normal(0, 10)``;
    ``sigma ~ Uniform(0, 10)``.
    Likelihood:
    ``log_gdp ~ Normal(a + bR*rugged + bA*cont_africa + bAR*rugged*cont_africa, sigma)``.

    Args:
        rugged: tensor of ruggedness values, one per observation.
        cont_africa: tensor of the same length holding the ``cont_africa``
            column (0/1 in the data shown in this notebook).
        log_gdp: tensor of observed log GDP, or ``None`` to leave the
            ``log_gdp`` site unobserved.
    """
    a = pyro.sample("a", dist.Normal(0, 100))
    bR = pyro.sample("bR", dist.Normal(0, 10))
    bA = pyro.sample("bA", dist.Normal(0, 10))
    bAR = pyro.sample("bAR", dist.Normal(0, 10))
    mu = a + bR * rugged + bA * cont_africa + bAR * rugged * cont_africa
    sigma = pyro.sample("sigma", dist.Uniform(0, 10))
    # Size the plate from the observations when they are given, so Pyro can
    # check the likelihood's batch shape against the data length; without
    # observations the plate stays unsized, exactly as before.
    num_obs = None if log_gdp is None else log_gdp.shape[0]
    with pyro.plate("plate", num_obs):
        pyro.sample("log_gdp", dist.Normal(mu, sigma), obs=log_gdp)

model_args = (dd_rugged, dd_cont_africa, dd_log_gdp)
kernel = NUTS(model)
# Same sampler configuration, now run as four chains.
m8_1stan_4chains = MCMC(kernel, num_chains=4, num_samples=1000)
m8_1stan_4chains.run(*model_args)
precis(m8_1stan_4chains)
Out[8]:
Mean StdDev |0.89 0.89| n_eff r_hat
a 9.22 0.15 9.01 9.47 1063.48 1.0
bR -0.20 0.08 -0.34 -0.08 1080.07 1.0
bA -1.95 0.24 -2.35 -1.60 1060.95 1.0
bAR 0.39 0.14 0.17 0.61 1179.82 1.0
sigma 0.95 0.05 0.87 1.03 1778.89 1.0

Code 8.8

In [9]:
post = extract_samples(m8_1stan)
# First five draws of every latent variable.
{name: post[name][:5] for name in post}
Out[9]:
{'a': tensor([9.2374, 9.3048, 9.1149, 9.0362, 9.3176]),
 'bR': tensor([-0.2022, -0.3170, -0.1933, -0.1671, -0.1185]),
 'bA': tensor([-2.1425, -2.0199, -1.9134, -1.8863, -2.1787]),
 'bAR': tensor([0.4066, 0.4874, 0.3268, 0.4041, 0.3728]),
 'sigma': tensor([0.8741, 0.9995, 1.0940, 1.0770, 0.9178])}

Code 8.9

In [10]:
posterior_df = pd.DataFrame(post)
scatter_style = {"edgecolor": "none", "alpha": 0.2}
# Pairwise scatter of the draws, with KDEs of each marginal on the diagonal.
sns.pairplot(posterior_df, diag_kind="kde", plot_kws=scatter_style);

Code 8.10

In [11]:
post = extract_samples(m8_1stan)
posterior_df = pd.DataFrame(post)
# Pairs plot of the NUTS posterior (KDE marginals on the diagonal).
sns.pairplot(posterior_df,
             diag_kind="kde",
             plot_kws=dict(edgecolor="none", alpha=0.2));

Code 8.12

In [12]:
# WAIC for the NUTS fit, together with p_waic (its effective number of
# parameters penalty).
m8_1stan.information_criterion()
Out[12]:
OrderedDict([('waic', tensor(469.6760)), ('p_waic', tensor(5.3495))])

Code 8.13

In [13]:
post = extract_samples(m8_1stan)
precis_df = precis(m8_1stan)
# One trace plot per latent variable, titled with its effective sample size.
for latent in post:
    draws = post[latent]
    # seaborn >= 0.12 accepts x/y only as keywords; the x-axis length comes
    # from the draws themselves so it tracks whatever num_samples was used.
    ax = sns.lineplot(x=range(len(draws)), y=draws)
    n_eff = int(precis_df.loc[latent, "n_eff"])
    ax.set(title="{}  |  n_eff = {}".format(latent, n_eff),
           xlabel="iteration", ylabel=latent)
    sns.mpl.pyplot.show()