Stan implements Hamiltonian Monte Carlo with a No-U-Turn (NUTS) sampler. We treat this as a black box and focus on using Stan. While the algorithmic details are more complicated than Gibbs or Metropolis, many of the major ideas are the same.
To demonstrate Stan we work through three examples.
import pystan
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
plt.style.use('ggplot')
Models are written in the Stan modeling language (see the Stan Language Guide for the full specification). The language is the same whether you call Stan from Python or R.
We fit the normal conjugate model: a normal distribution with an unknown mean, known variance, and normal prior on the mean.
normalconj = """
data {
  int<lower=0> N;   // number of observations
  real y[N];        // observations
}
parameters {
  real mu;
}
model {
  mu ~ normal(0, 10);  // prior on the mean; 10 is the *sd*, not the variance
  y ~ normal(mu, 1);   // likelihood, known sd 1
}
"""
The model must be compiled into C++ code, which can take a while. However, this only needs to be done once per model: you can fit new data sets or run additional chains with the compiled model sm below.
sm = pystan.StanModel(model_code=normalconj)
INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_c2c6db8fbfd7f5ebb758fa012881d1b6 NOW.
The data we fit the model on may be stored as a Python dictionary whose keys are the variable names and whose values are the data.
normal_dat = {'N': 5,
'y': [6.4,6.1,5.0,5.8,5.88]}
fit = sm.sampling(data=normal_dat, iter=1000, chains=4)
print(fit)
Inference for Stan model: anon_model_c2c6db8fbfd7f5ebb758fa012881d1b6.
4 chains, each with iter=1000; warmup=500; thin=1;
post-warmup draws per chain=500, total post-warmup draws=2000.

       mean se_mean   sd  2.5%   25%   50%   75% 97.5% n_eff Rhat
mu     5.82    0.01 0.42  5.05  5.53   5.8  6.11  6.67  1434  1.0
lp__  -1.15    0.02 0.61 -2.85 -1.31 -0.92 -0.77 -0.72  1070  1.0

Samples were drawn using NUTS at Tue Apr 10 13:01:19 2018.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
We obtain the chains by calling the method extract. The permuted argument tells Stan to return the draws with the sampling order permuted; this removes correlation between adjacent samples.
la = fit.extract(permuted=True) # return a dictionary of arrays
mu = la['mu']
np.mean(mu)
5.824647437520733
## posterior mean (for each chain)
fit.get_posterior_mean()
array([[ 5.84010739,  5.82631721,  5.80810106,  5.82406408],
       [-1.16478986, -1.20447218, -1.0952655 , -1.12838365]])
fit.plot();
From statistical theory we know that the posterior is normal with parameters
\begin{align*}
\mu &= \left(\frac{1}{\frac{1}{10^2} + \frac{5}{1^2}}\right) \sum y_i \\
\sigma^2 &= \frac{1}{\frac{1}{10^2} + \frac{5}{1^2}}
\end{align*}
See the first row of the continuous-distributions table in the Conjugate Priors notes for this result.
y = np.array(normal_dat['y'])
mupost = (1.0 / (1/10**2 + 5.0/1))*(np.sum(y))
sigma2post = 1.0/(1.0/10.0**2 + 5.0 / 1.0)
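As a sanity check, the posterior mean can also be written as a precision-weighted average of the prior mean (0) and the sample mean. A minimal sketch (variable names here are ours, introduced for illustration):

```python
import numpy as np

y = np.array([6.4, 6.1, 5.0, 5.8, 5.88])

# Precision (inverse variance) of the prior and of the data
prior_prec = 1.0 / 10**2      # prior N(0, 10^2)
data_prec = len(y) / 1.0**2   # likelihood sd known to be 1

# Posterior mean as a shrinkage (precision-weighted) average
w = data_prec / (data_prec + prior_prec)
mupost = w * y.mean() + (1 - w) * 0.0

# Agrees with the direct formula used above
direct = (1.0 / (1/10**2 + 5.0/1)) * np.sum(y)
```

With 5 observations and a vague prior, w is roughly 0.998, so the posterior mean is pulled only slightly from the sample mean toward the prior mean of 0.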
from scipy.stats import norm
We plot draws from the posterior (mu) against the density derived using conjugacy.
plt.hist(mu, bins=30, density=True, alpha=0.5,   # density= replaces the deprecated normed=
         histtype='stepfilled', color='steelblue',
         edgecolor='none');
mus = np.linspace(4.0, 7.5, 100);
plt.plot(mus, norm.pdf(mus, mupost, np.sqrt(sigma2post)));
We rerun the logistic regression example from earlier lectures. First we code the one-predictor logistic regression model with intercept, taken directly from the Stan Reference guide.
logistic = """
data {
  int<lower=0> N;
  vector[N] x;
  int<lower=0,upper=1> y[N];
}
parameters {
  real alpha;
  real beta;
}
model {
  y ~ bernoulli_logit(alpha + beta * x);
  // no prior specified, so a flat (improper) prior is assumed: p(alpha,beta) = 1
}
"""
sm = pystan.StanModel(model_code=logistic)
INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_4f1364be733844df2bd1b8d4168052ea NOW.
Here x is dosage and y indicates death.
x = np.array([-0.86,-0.86,-0.86,-0.86,-0.86,
-0.3,-0.3,-0.3,-0.3,-0.3,
-0.05,-0.05,-0.05,-0.05,-0.05,
0.73,0.73,0.73,0.73,0.73]) ## log scale
y = np.array([0,0,0,0,0,1,0,0,0,0,1,1,1,0,0,1,1,1,1,1])
x
array([-0.86, -0.86, -0.86, -0.86, -0.86, -0.3 , -0.3 , -0.3 , -0.3 , -0.3 , -0.05, -0.05, -0.05, -0.05, -0.05, 0.73, 0.73, 0.73, 0.73, 0.73])
y
array([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1])
np.size(y)
20
logistic_dat = {'N': np.size(y),
'x': x,
'y': y}
fit = sm.sampling(data=logistic_dat, iter=1000, chains=4)
print(fit)
Inference for Stan model: anon_model_4f1364be733844df2bd1b8d4168052ea.
4 chains, each with iter=1000; warmup=500; thin=1;
post-warmup draws per chain=500, total post-warmup draws=2000.

        mean se_mean   sd  2.5%   25%   50%   75% 97.5% n_eff Rhat
alpha   1.24    0.05 1.06  -0.6   0.5  1.18  1.88  3.67   520  1.0
beta   11.24    0.24 5.22  3.49  7.32 10.36 14.47 23.54   457 1.01
lp__   -6.94    0.05 1.05  -9.6 -7.32 -6.63 -6.19 -5.92   524  1.0

Samples were drawn using NUTS at Tue Apr 10 13:04:20 2018.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
These $\alpha$ and $\beta$ posterior means closely match those from our Metropolis chain, which were 1.27 and 11.71 respectively.
We now obtain the entire $\alpha$ and $\beta$ chains.
la = fit.extract(permuted=True) # return a dictionary of arrays
alpha = la['alpha']
beta = la['beta']
plt.plot(alpha, beta, '.', color='black',alpha=0.3);
plt.xlabel("alpha");
plt.ylabel("beta");
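One common use of the fitted model in this bioassay setting is the LD50, the dose at which the fitted death probability is 50%, which is $-\alpha/\beta$. A hedged plug-in sketch using the posterior means read off the summary above (alpha_hat, beta_hat, and inv_logit are our names, not from the notebook; in practice you could use the full alpha and beta chains instead of point estimates):

```python
import numpy as np

def inv_logit(t):
    """Inverse logit (sigmoid) function."""
    return 1.0 / (1.0 + np.exp(-t))

# Posterior means read from the printed Stan summary
alpha_hat, beta_hat = 1.24, 11.24

# Fitted death-probability curve over a range of (log) doses
doses = np.linspace(-1.0, 1.0, 201)
p_death = inv_logit(alpha_hat + beta_hat * doses)

# LD50: the dose at which the fitted death probability is exactly 0.5
ld50 = -alpha_hat / beta_hat
```

Because beta_hat is positive, the fitted curve is increasing in dose, and the LD50 here is a small negative log-dose.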
Recall that using Metropolis on this example we had to tune the jump size; Stan takes care of this for us.
We discussed plotting the time series of iterates of multiple chains to monitor mixing and stationarity. Stan also outputs Rhat, which compares between-chain parameter variance to within-chain parameter variance. Rhat is always at least $1$, and lower values suggest better mixing; Rhat > 1.1 suggests that running the chain longer could improve mixing. See Section 11.4 of BDA by Gelman et al. for details.
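To make the diagnostic concrete, here is a minimal numpy sketch of split-Rhat along the lines of BDA Section 11.4. This is a simplified version of the idea, not Stan's exact implementation:

```python
import numpy as np

def split_rhat(chains):
    """Split each chain in half and compare between- and within-chain
    variance (simplified split-Rhat, after BDA Section 11.4)."""
    # Split every chain into two half-chains
    halves = [h for c in chains for h in np.array_split(np.asarray(c), 2)]
    n = min(len(h) for h in halves)           # draws per split chain
    halves = np.array([h[:n] for h in halves])
    chain_means = halves.mean(axis=1)
    B = n * chain_means.var(ddof=1)           # between-chain variance
    W = halves.var(axis=1, ddof=1).mean()     # mean within-chain variance
    var_hat = (n - 1) / n * W + B / n         # pooled variance estimate
    return np.sqrt(var_hat / W)
```

Well-mixed chains drawn from the same distribution give values near 1, while a chain stuck around a different mean pushes the statistic well above 1.1.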
print(fit)
Inference for Stan model: anon_model_4f1364be733844df2bd1b8d4168052ea.
4 chains, each with iter=1000; warmup=500; thin=1;
post-warmup draws per chain=500, total post-warmup draws=2000.

        mean se_mean   sd  2.5%   25%   50%   75% 97.5% n_eff Rhat
alpha   1.24    0.05 1.06  -0.6   0.5  1.18  1.88  3.67   520  1.0
beta   11.24    0.24 5.22  3.49  7.32 10.36 14.47 23.54   457 1.01
lp__   -6.94    0.05 1.05  -9.6 -7.32 -6.63 -6.19 -5.92   524  1.0

Samples were drawn using NUTS at Tue Apr 10 13:04:20 2018.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
The Rhat values look good.
We now discuss a famous hierarchical model, fit here to coagulation-time measurements under four diets (the example from BDA Table 11.3).
A = np.array([62.0,60.0,63.0,59.0])
B = np.array([63.0,67.0,71.0,64.0,65.0,66.0])
C = np.array([68.0,66.0,71.0,67.0,68.0,68.0])
D = np.array([56.,62.,60.,61.,63.,64.,63.,59.])
A.size
4
coag_dat = {'N': 4,
'y': [np.mean(A),np.mean(B),np.mean(C),np.mean(D)],
'ns': [A.size, B.size, C.size, D.size]}
coag_dat
{'N': 4, 'ns': [4, 6, 6, 8], 'y': [61.0, 66.0, 68.0, 61.0]}
coag_code = """
data {
  int<lower=0> N;        // number of diets
  real y[N];             // mean coagulation time within each diet
  int<lower=0> ns[N];    // number of measurements in each group
}
parameters {
  real mu;               // population mean
  real theta[N];         // mean for a particular diet
  real<lower=0> tau;     // sd of the prior on each diet mean
  real<lower=0> sigma;   // sd on y
}
transformed parameters {
  real<lower=0> sigmat[N];
  for (n in 1:N)
    sigmat[n] = sigma / sqrt(ns[n]);  // sd of each group mean
}
model {
  theta ~ normal(mu, tau);
  y ~ normal(theta, sigmat);
  target += log(1.0/sigma);  // custom improper 1/sigma prior on sigma
}
"""
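A note on the target += line: Stan's target is the log posterior density, so target += log(1.0/sigma) adds $-\log\sigma$, i.e. it imposes the improper prior
\begin{align*}
p(\sigma) \propto \frac{1}{\sigma},
\end{align*}
which is equivalent to a flat prior on $\log\sigma$: with $\phi = \log\sigma$, the change of variables gives $p(\phi) = p(\sigma)\left|\frac{d\sigma}{d\phi}\right| \propto \frac{1}{\sigma}\cdot\sigma = 1$.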
sm = pystan.StanModel(model_code=coag_code)
INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_b48d584fb6bb1e9847dde34d3b4b7bac NOW.
fit = sm.sampling(data=coag_dat, iter=10000, chains=20)
print(fit)
Inference for Stan model: anon_model_b48d584fb6bb1e9847dde34d3b4b7bac.
20 chains, each with iter=10000; warmup=5000; thin=1;
post-warmup draws per chain=5000, total post-warmup draws=100000.

            mean se_mean    sd   2.5%    25%   50%   75% 97.5% n_eff Rhat
mu         64.09     0.2  6.71  54.25  62.16 63.89 65.67  75.5  1124 1.02
theta[0]   61.77    0.05  4.01  56.21  60.69 61.15  62.5 69.49  5511  1.0
theta[1]   65.63    0.04  3.53  59.22  65.09 65.93 66.34 70.93  9356  1.0
theta[2]   67.21    0.04  3.77  60.11  66.56 67.84  68.2 72.16  8304  1.0
theta[3]    61.6    0.05  3.36  57.06  60.78  61.1 62.03 68.17  4492  1.0
tau         7.34    0.27 12.88   0.97   3.09  4.81  7.79 31.01  2243 1.01
sigma       5.99    0.27 10.81   0.24   0.96  2.71  7.25 29.57  1599 1.01
sigmat[0]   2.99    0.14  5.41   0.12   0.48  1.35  3.62 14.79  1599 1.01
sigmat[1]   2.44    0.11  4.41    0.1   0.39   1.1  2.96 12.07  1599 1.01
sigmat[2]   2.44    0.11  4.41    0.1   0.39   1.1  2.96 12.07  1599 1.01
sigmat[3]   2.12     0.1  3.82   0.09   0.34  0.96  2.56 10.46  1599 1.01
lp__       -8.49    0.33  5.46 -20.07 -11.85  -8.5 -4.62  1.63   268 1.08

Samples were drawn using NUTS at Tue Apr 10 14:24:50 2018.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
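The partial pooling induced by the hierarchical prior is visible in the summary: each diet-level posterior mean is pulled from its raw group mean toward the population mean. A quick numeric check, with the numbers read off the Stan output above:

```python
import numpy as np

# Raw group means (from coag_dat) and posterior summaries read from
# the Stan output above
group_means = np.array([61.0, 66.0, 68.0, 61.0])
theta_post = np.array([61.77, 65.63, 67.21, 61.60])
mu_post = 64.09

# Each theta[j] lies between its raw group mean and the population
# mean mu: this is the shrinkage from the hierarchical prior
lo = np.minimum(group_means, mu_post)
hi = np.maximum(group_means, mu_post)
shrunk_toward_mu = (lo <= theta_post) & (theta_post <= hi)
```

Because each group has several observations, the shrinkage is modest: each theta[j] stays closer to its own group mean than to mu.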
The above output reproduces parts of Table 11.3 in BDA. A Gibbs sampler is more efficient here; however, Stan also works with priors that are not conditionally conjugate.
Effective use of MCMC and Stan is a bit of an art, especially for challenging examples where convergence issues can occur.