%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
import pystan
import statsmodels.api as sm
import statsmodels.formula.api as smf
from scipy.stats import logistic, distributions as dst
import rpy2
np.random.seed(1)
pystan.__version__
'2.12.0.0'
%load_ext rpy2.ipython
%%R -o data
# Simulate a single trial with sample size n.
#
# Subjects are assigned to arm "A" (the first floor(n / 2)) and arm "B"
# (the remaining n - floor(n / 2)), so odd n no longer produces columns of
# different lengths; for even n the arms are equal, exactly as before.
# Returns a data.frame with columns:
#   trt  - treatment arm, "A" or "B"
#   sbp0 - baseline value drawn from N(140, 7^2)
#   sbp  - follow-up value: sbp0 - 5, a further -3 in arm B, plus N(0, 7^2)
#   ds   - binary outcome (1/0) from a logistic model in trt, sbp0 and sbp
sim <- function(n) {
  stopifnot(is.numeric(n), length(n) == 1L, n >= 0)
  n_a <- n %/% 2
  trt <- rep(c("A", "B"), times = c(n_a, n - n_a))
  # Draw order (sbp0, then sbp noise, then uniforms) is kept fixed so a given
  # seed reproduces the same dataset.
  sbp0 <- rnorm(n, 140, 7)
  sbp <- sbp0 - 5 - 3 * (trt == "B") + rnorm(n, sd = 7)
  log_odds <- -2.6 + log(0.8) * (trt == "B") + 0.05 * (sbp0 - 140) +
    0.05 * (sbp - 130)
  ds <- ifelse(runif(n) <= plogis(log_odds), 1, 0)
  data.frame(trt, sbp0, sbp, ds)
}

# Fixed seed so the simulated dataset (and everything fitted to it) is
# reproducible.
set.seed(7)
data <- sim(n = 1500)
data.iloc[:5]  # first five simulated subjects
 | trt | sbp0 | sbp | ds |
---|---|---|---|---|
1 | A | 156.010730 | 145.740946 | 0.0 |
2 | A | 131.622598 | 125.179766 | 0.0 |
3 | A | 135.139952 | 128.711171 | 0.0 |
4 | A | 137.113949 | 127.264851 | 0.0 |
5 | A | 133.205287 | 131.413852 | 0.0 |
# Baseline vs follow-up SBP, coloured by arm (treat: 0 = A, 1 = B).
(data
 .assign(treat=(data.trt == 'B').astype(int))
 .plot.scatter('sbp0', 'sbp', c='treat', cmap='plasma', alpha=0.4));
# Stan program: a normal regression for y1 and a logistic regression for y2
# share the covariate x and treatment indicator; the two treatment effects
# beta get a zero-mean multivariate-normal prior with fixed SDs sigma_b and an
# LKJ prior on their correlation.
model = """
data {
  int<lower=1> n;                  // number of subjects
  vector[n] x;                     // covariate entering both linear predictors
  real y1[n];                      // continuous outcome (normal likelihood)
  int<lower=0, upper=1> y2[n];     // binary outcome (Bernoulli-logit likelihood)
  vector[n] treat;                 // treatment indicator
  vector[2] Zero;                  // prior mean of beta
  vector<lower=0>[2] sigma_b;      // fixed prior SDs of beta
}
parameters {
  vector[2] alpha;                 // slopes on x, one per outcome
  vector[2] beta;                  // treatment effects, one per outcome
  vector[2] mu;                    // intercepts, one per outcome
  real<lower=0> sigma_y;           // residual SD of y1
  cholesky_factor_corr[2] L_b;     // Cholesky factor of the correlation of beta
}
transformed parameters {
  // Linear predictors; declared here (not in model) so their draws are saved.
  vector[n] theta1;
  vector[n] theta2;
  theta1 = mu[1] + alpha[1] * x + beta[1] * treat;
  theta2 = mu[2] + alpha[2] * x + beta[2] * treat;
}
model {
  // alpha, mu and sigma_y have no explicit prior (flat over their support).
  beta ~ multi_normal_cholesky(Zero, diag_pre_multiply(sigma_b, L_b));
  L_b ~ lkj_corr_cholesky(1);      // LKJ(1) prior on the correlation of beta
  y1 ~ normal(theta1, sigma_y);
  y2 ~ bernoulli_logit(theta2);
}
generated quantities {
  matrix[2, 2] Omega;              // correlation matrix of beta
  matrix[2, 2] Sigma;              // implied covariance matrix of beta
  Omega = multiply_lower_tri_self_transpose(L_b);
  Sigma = quad_form_diag(Omega, sigma_b);
}
"""
# Fit the joint model to the simulated trial.
# The prior SDs in sigma_b are chosen so that, marginally,
#   P(beta[1] < -10)      = 0.10
#   P(beta[2] < log(0.5)) = 0.05
fit = pystan.stan(
    model_code=model,
    data=dict(
        n=data.shape[0],
        x=data.sbp0.values,
        y1=data.sbp.values,
        y2=data.ds.astype(int).values,
        treat=(data.trt == 'B').astype(int).values,
        Zero=np.zeros(2),
        sigma_b=(-10/dst.norm.ppf(0.1), np.log(0.5)/dst.norm.ppf(0.05)),
    ),
    seed=7,
    n_jobs=1,
)
INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_e939446386a69132052970fed98900ee NOW.
# Diagnostic plot of the draws of the treatment-effect vector beta.
fit.plot(pars=['beta']);
# Posterior draws as a dict of arrays keyed by parameter name; with
# permuted=True the chains are merged so draws run along axis 0.
output = fit.extract(permuted=True)
Look at the correlation between the posterior means of `theta1` and `theta2` across subjects!
# Per-subject posterior means of the two linear predictors (theta1: y1 scale,
# theta2: logit scale). x/y are passed by keyword: seaborn >= 0.12 rejects them
# positionally, and older seaborn accepts the keywords as well.
sns.jointplot(x=output['theta1'].mean(axis=0), y=output['theta2'].mean(axis=0), kind='kde')
/Users/fonnescj/anaconda3/envs/dev/lib/python3.5/site-packages/statsmodels/nonparametric/kdetools.py:20: VisibleDeprecationWarning: using a non-integer number instead of an integer will result in an error in the future y = X[:m/2+1] + np.r_[0,X[m/2+1:],0]*1j
<seaborn.axisgrid.JointGrid at 0x12711e390>
Meanwhile, here are the joint posteriors of the regression parameters `beta`, `mu`, and `alpha`:
# Joint posterior of beta[1] vs beta[2]; keyword x/y for seaborn >= 0.12.
sns.jointplot(x=output['beta'][:, 0], y=output['beta'][:, 1], kind='kde')
/Users/fonnescj/anaconda3/envs/dev/lib/python3.5/site-packages/statsmodels/nonparametric/kdetools.py:20: VisibleDeprecationWarning: using a non-integer number instead of an integer will result in an error in the future y = X[:m/2+1] + np.r_[0,X[m/2+1:],0]*1j
<seaborn.axisgrid.JointGrid at 0x11a8b4668>
# Joint posterior of alpha[1] vs alpha[2]; keyword x/y for seaborn >= 0.12.
sns.jointplot(x=output['alpha'][:, 0], y=output['alpha'][:, 1], kind='kde')
/Users/fonnescj/anaconda3/envs/dev/lib/python3.5/site-packages/statsmodels/nonparametric/kdetools.py:20: VisibleDeprecationWarning: using a non-integer number instead of an integer will result in an error in the future y = X[:m/2+1] + np.r_[0,X[m/2+1:],0]*1j
<seaborn.axisgrid.JointGrid at 0x11bc682e8>
# Joint posterior of mu[1] vs mu[2]; keyword x/y for seaborn >= 0.12.
sns.jointplot(x=output['mu'][:, 0], y=output['mu'][:, 1], kind='kde')
/Users/fonnescj/anaconda3/envs/dev/lib/python3.5/site-packages/statsmodels/nonparametric/kdetools.py:20: VisibleDeprecationWarning: using a non-integer number instead of an integer will result in an error in the future y = X[:m/2+1] + np.r_[0,X[m/2+1:],0]*1j
<seaborn.axisgrid.JointGrid at 0x12712be48>
np.mean(output['Omega'], axis=0)  # posterior mean correlation matrix of beta
array([[ 1. , 0.18826785], [ 0.18826785, 1. ]])
np.std(output['Omega'], axis=0)  # posterior SD per entry; diagonal is ~0 (unit diagonal)
array([[ 0.00000000e+00, 5.56033789e-01], [ 5.56033789e-01, 7.09587996e-17]])
np.mean(output['Sigma'], axis=0)  # posterior mean implied covariance of beta
array([[ 60.88745604, 0.61906789], [ 0.61906789, 0.17758096]])