%load_ext autoreload
%autoreload 2
%matplotlib inline
import random
random.seed(1100038344)
import survivalstan
import numpy as np
import pandas as pd
from stancache import stancache
from matplotlib import pyplot as plt
INFO:stancache.seed:Setting seed to 1245502385
To demonstrate the use of this model, we first simulate some survival data using `survivalstan.sim.sim_data_exp_correlated`. As the name implies, this function simulates data assuming a constant hazard throughout the follow-up period, which is consistent with an Exponential survival function.
By default, this function includes two simulated covariates (`age` and `sex`). We also simulate a situation where the hazard is a function of the simulated value of `sex`.
We also center the `age` variable, since this makes estimates of the baseline hazard easier to interpret.
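The simulation design can be sketched in plain NumPy. This is a hypothetical stand-in, not the code of `sim_data_exp_correlated`: the function name `simulate_exponential_survival`, the age range, and coding `sex` as male = 1 are assumptions, and the sketch ignores any correlation the real function may introduce between covariates. Reading `rate_form='1 + sex'` with `rate_coefs=[-3, 0.5]` as an intercept of -3 plus a male effect of 0.5 gives hazards of exp(-3) ≈ 0.0498 (female) and exp(-2.5) ≈ 0.0821 (male), which match the `rate` column in the table below.

```python
import numpy as np
import pandas as pd

def simulate_exponential_survival(n, censor_time, intercept, sex_coef, seed=None):
    """Hypothetical re-creation of the simulation design described above.

    Assumes sex is coded 1 for male and 0 for female, and that the hazard
    is constant in time: rate = exp(intercept + sex_coef * male).
    """
    rng = np.random.RandomState(seed)
    age = rng.randint(50, 70, size=n)  # assumed age range, for illustration only
    sex = rng.choice(['female', 'male'], size=n)
    male = (sex == 'male').astype(float)
    rate = np.exp(intercept + sex_coef * male)
    # Constant hazard => exponential event times with mean 1 / rate
    true_t = rng.exponential(scale=1.0 / rate)
    # Administrative censoring at censor_time
    t = np.minimum(true_t, censor_time)
    event = true_t <= censor_time
    return pd.DataFrame({'age': age, 'sex': sex, 'rate': rate,
                         'true_t': true_t, 't': t, 'event': event})

sim = simulate_exponential_survival(n=100, censor_time=20, intercept=-3,
                                    sex_coef=0.5, seed=42)
# Center age, as done for the real simulated data
sim['age_centered'] = sim['age'] - sim['age'].mean()
```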
d = stancache.cached(
survivalstan.sim.sim_data_exp_correlated,
N=100,
censor_time=20,
rate_form='1 + sex',
rate_coefs=[-3, 0.5],
)
d['age_centered'] = d['age'] - d['age'].mean()
INFO:stancache.stancache:sim_data_exp_correlated: cache_filename set to sim_data_exp_correlated.cached.N_100.censor_time_20.rate_coefs_54462717316.rate_form_1 + sex.pkl INFO:stancache.stancache:sim_data_exp_correlated: Loading result from cache
Aside: to make this example more reproducible, this code uses the file-caching function `stancache.cached` to wrap the call to `survivalstan.sim.sim_data_exp_correlated`.
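The caching idea can be illustrated with a small, hypothetical helper `cached_call`. It derives a file name from the function name and a hash of its keyword arguments, loads the pickled result if that file already exists, and otherwise calls the function and saves the result. This is a simplified sketch of the concept suggested by the log message above (a cache file name built from the function and its arguments), not `stancache`'s actual implementation.

```python
import hashlib
import os
import pickle
import tempfile

def cached_call(func, cache_dir, **kwargs):
    """Hypothetical, simplified file cache illustrating the idea behind
    stancache.cached; it is not stancache's implementation."""
    # Build a cache key from the function name and its keyword arguments
    key_source = func.__name__ + repr(sorted(kwargs.items()))
    key = hashlib.md5(key_source.encode('utf-8')).hexdigest()
    path = os.path.join(cache_dir, '{}.{}.pkl'.format(func.__name__, key))
    if os.path.exists(path):
        with open(path, 'rb') as f:
            return pickle.load(f)      # cache hit: reuse the stored result
    result = func(**kwargs)            # cache miss: compute and store
    with open(path, 'wb') as f:
        pickle.dump(result, f)
    return result

# Usage: the wrapped function only runs once for identical arguments
calls = []

def slow_square(x):
    calls.append(x)
    return x * x

cache_dir = tempfile.mkdtemp()
first = cached_call(slow_square, cache_dir, x=3)
second = cached_call(slow_square, cache_dir, x=3)  # loaded from the cache file
```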
Here is what these data look like; this is the per-subject, or time-to-event, form:
d.head()
| | age | sex | rate | true_t | t | event | index | age_centered |
|---|---|---|---|---|---|---|---|---|
| 0 | 59 | male | 0.082085 | 20.948771 | 20.000000 | False | 0 | 4.18 |
| 1 | 58 | male | 0.082085 | 12.827519 | 12.827519 | True | 1 | 3.18 |
| 2 | 61 | female | 0.049787 | 27.018886 | 20.000000 | False | 2 | 6.18 |
| 3 | 57 | female | 0.049787 | 62.220296 | 20.000000 | False | 3 | 2.18 |
| 4 | 55 | male | 0.082085 | 10.462045 | 10.462045 | True | 4 | 0.18 |
It is not obvious from the field names, but in this example "subjects" are indexed by the field `index`.
We can plot these data using `lifelines`, or the rudimentary plotting functions provided by `survivalstan`.
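A standard way to summarize observed, right-censored survival is the Kaplan-Meier estimator: at each distinct event time u, the running survival estimate is multiplied by 1 - (events at u) / (subjects still at risk at u). The sketch below is a minimal NumPy version; the function name `kaplan_meier` is hypothetical, and `survivalstan`'s plotting helper may compute its curve differently.

```python
import numpy as np

def kaplan_meier(times, events):
    """Minimal Kaplan-Meier estimator (illustrative; not survivalstan's code).

    Returns the distinct event times and the estimated survival
    probability just after each of them.
    """
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=bool)
    event_times = np.unique(times[events])
    surv = []
    s = 1.0
    for u in event_times:
        at_risk = np.sum(times >= u)             # subjects still under observation at u
        deaths = np.sum((times == u) & events)   # events occurring exactly at u
        s *= 1.0 - deaths / at_risk
        surv.append(s)
    return event_times, np.array(surv)

# Toy data: 1 = event observed, 0 = censored
km_t, km_s = kaplan_meier([2, 3, 3, 5, 8], [1, 1, 0, 1, 0])
```

Applied to the simulated data, the female curve would be `kaplan_meier(d.loc[d['sex'] == 'female', 't'], d.loc[d['sex'] == 'female', 'event'])`.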
survivalstan.utils.plot_observed_survival(df=d[d['sex']=='female'], event_col='event', time_col='t', label='female')
survivalstan.utils.plot_observed_survival(df=d[d['sex']=='male'], event_col='event', time_col='t', label='male')
plt.legend()
[Figure: observed survival curves for the simulated data, labeled female and male]
model_code = '''
functions {
// Defines the log survival
vector log_S (vector t, real shape, vector rate) {
vector[num_elements(t)] log_S;
for (i in 1:num_elements(t)) {
log_S[i] = gamma_lccdf(t[i]|shape,rate[i]);
}
return log_S;
}
// Defines the log hazard
vector log_h (vector t, real shape, vector rate) {
vector[num_elements(t)] log_h;
vector[num_elements(t)] ls;
ls = log_S(t,shape,rate);
for (i in 1:num_elements(t)) {
log_h[i] = gamma_lpdf(t[i]|shape,rate[i]) - ls[i];
}
return log_h;
}
// Defines the sampling distribution
real surv_gamma_lpdf (vector t, vector d, real shape, vector rate) {
vector[num_elements(t)] log_lik;
real prob;
log_lik = d .* log_h(t,shape,rate) + log_S(t,shape,rate);
prob = sum(log_lik);
return prob;
}
}
data {
int N; // number of observations
vector<lower=0>[N] y; // observed times
vector<lower=0,upper=1>[N] event; // censoring indicator (1=observed, 0=censored)
int M; // number of covariates
matrix[N, M] x; // matrix of covariates (with N rows and M columns)
}
parameters {
vector[M] beta; // Coefficients in the linear predictor (including intercept)
real<lower=0> alpha; // shape parameter
}
transformed parameters {
vector[N] linpred;
vector[N] mu;
linpred = x*beta;
for (i in 1:N) {
mu[i] = exp(linpred[i]);
}
}
model {
alpha ~ gamma(0.01,0.01);
beta ~ normal(0,5);
y ~ surv_gamma(event, alpha, mu);
}
'''
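The log-likelihood in `surv_gamma_lpdf` adds `log_h(t) + log_S(t)` for each observed event and only `log_S(t)` for each censored observation; since h(t) = f(t) / S(t), an event effectively contributes log f(t). The NumPy/SciPy mirror below (hypothetical helper `surv_gamma_loglik`) can be handy for checking the Stan function outside of Stan. It assumes Stan's `gamma(shape, rate)` corresponds to `scipy.stats.gamma(a=shape, scale=1/rate)`.

```python
import numpy as np
from scipy import stats

def surv_gamma_loglik(t, d, shape, rate):
    """NumPy/SciPy mirror of surv_gamma_lpdf in model_code (for checking only).

    Stan's gamma(shape, rate) corresponds to scipy.stats.gamma(a=shape, scale=1/rate).
    """
    t = np.asarray(t, dtype=float)
    d = np.asarray(d, dtype=float)
    rate = np.asarray(rate, dtype=float)
    log_S = stats.gamma.logsf(t, a=shape, scale=1.0 / rate)   # gamma_lccdf
    log_f = stats.gamma.logpdf(t, a=shape, scale=1.0 / rate)  # gamma_lpdf
    log_h = log_f - log_S                                      # log hazard
    # Events contribute log h(t) + log S(t) = log f(t); censored rows contribute log S(t)
    return np.sum(d * log_h + log_S)

ll_example = surv_gamma_loglik(t=[1.0, 2.0, 3.0], d=[1, 0, 1],
                               shape=1.0, rate=[0.5, 0.5, 0.2])
```

With shape = 1 the gamma reduces to the exponential case, where an event contributes log(rate) - rate * t and a censored observation contributes -rate * t; the example above is chosen so it can be checked by hand that way.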
Now we are ready to fit our model using `survivalstan.fit_stan_survival_model`.
We pass a few parameters to the fit function, many of which are required. See ?survivalstan.fit_stan_survival_model for details.
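With `drop_intercept = False`, the covariate matrix `x` keeps an intercept column, so `beta[0]` in the output below is the intercept (the Stan comment notes that `beta` includes it). Assuming the `formula` is expanded by patsy with its default treatment coding and term ordering (categorical terms before numeric ones), the columns would be `Intercept`, `sex[T.male]`, `age_centered`; that would make `beta[1]` the male effect and `beta[2]` the age effect, which is consistent with the estimates (near the simulated 0.5 and 0, respectively). The hypothetical helper below builds that matrix by hand to make the assumed layout explicit.

```python
import numpy as np
import pandas as pd

def build_design_matrix(df):
    """Hand-built stand-in for the design matrix of '~ age_centered + sex'.

    Assumes patsy-style treatment coding (female is the reference level) and
    patsy's term ordering, in which categorical terms precede numeric ones.
    """
    return pd.DataFrame({
        'Intercept': np.ones(len(df)),
        'sex[T.male]': (df['sex'] == 'male').astype(float),
        'age_centered': df['age_centered'].astype(float),
    }, columns=['Intercept', 'sex[T.male]', 'age_centered'])

# Two rows mirroring subjects 1 and 3 from the table above
toy = pd.DataFrame({'sex': ['male', 'female'], 'age_centered': [3.18, 2.18]})
x_toy = build_design_matrix(toy)
```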
As we did above, we ask `survivalstan` to cache this model fit object; see `stancache` for details on how this works. If you don't want to use the cache, omit the `FIT_FUN` parameter and `survivalstan` will use the standard `pystan` functionality.
testfit = survivalstan.fit_stan_survival_model(
model_cohort = 'model 1',
model_code = model_code,
df = d,
time_col = 't',
event_col = 'event',
formula = '~ age_centered + sex',
iter = 5000,
chains = 4,
seed = 9001,
FIT_FUN = stancache.cached_stan_fit,
drop_intercept = False,
)
INFO:stancache.stancache:Step 1: Get compiled model code, possibly from cache INFO:stancache.stancache:StanModel: cache_filename set to anon_model.cython_0_25_1.model_code_14429915565770599621.pystan_2_12_0_0.stanmodel.pkl INFO:stancache.stancache:StanModel: Loading result from cache INFO:stancache.stancache:Step 2: Get posterior draws from model, possibly from cache INFO:stancache.stancache:sampling: cache_filename set to anon_model.cython_0_25_1.model_code_14429915565770599621.pystan_2_12_0_0.stanfit.chains_4.data_25476010973.iter_5000.seed_9001.pkl INFO:stancache.stancache:sampling: Loading result from cache
# 0:00:40.518775 elapsed
survivalstan.utils.print_stan_summary([testfit], pars=['lp__', 'alpha', 'beta'])
| | mean | se_mean | sd | 2.5% | 50% | 97.5% | Rhat |
|---|---|---|---|---|---|---|---|
| lp__ | -278.113380 | 0.021718 | 1.358348 | -281.579969 | -277.815032 | -276.424942 | 1.000465 |
| alpha | 1.220129 | 0.002561 | 0.170488 | 0.913076 | 1.209632 | 1.581949 | 1.000245 |
| beta[0] | -2.703902 | 0.003699 | 0.226779 | -3.173168 | -2.695839 | -2.275841 | 1.000271 |
| beta[1] | 0.608266 | 0.002897 | 0.199169 | 0.221442 | 0.606373 | 1.003776 | 1.000226 |
| beta[2] | 0.006018 | 0.000183 | 0.014844 | -0.023068 | 0.005959 | 0.035536 | 1.000075 |
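Because `mu = exp(x * beta)` enters the gamma distribution as a rate, a larger linear predictor means shorter survival: the mean survival time is `alpha / mu`, so a coefficient beta_k multiplies expected survival time by exp(-beta_k). The sketch below plugs the model 1 posterior means into that formula, assuming `beta[1]` is the male coefficient and evaluating at `age_centered = 0`. Plugging in posterior means is only a rough summary; a full treatment would compute the quantity for every posterior draw.

```python
import numpy as np

def gamma_mean_survival(shape, linpred):
    """Mean of a Gamma(shape, rate) survival time with rate = exp(linpred)."""
    return shape / np.exp(linpred)

# Plug-in values taken from the model 1 posterior means printed above
alpha_hat, b0, b_male = 1.220129, -2.703902, 0.608266
female_mean = gamma_mean_survival(alpha_hat, b0)
male_mean = gamma_mean_survival(alpha_hat, b0 + b_male)
time_ratio = male_mean / female_mean  # equals exp(-b_male)
```

For comparison, the true mean survival times in the simulation (before censoring) are 1 / exp(-3) = exp(3) ≈ 20.1 for female and exp(2.5) ≈ 12.2 for male subjects.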
model_code2 = '''
functions {
// Defines the sampling distribution (log likelihood)
real surv_gamma_lpdf (vector t, vector d, real shape, vector rate) {
vector[num_elements(t)] log_lik;
real prob;
for (i in 1:num_elements(t)) {
log_lik[i] = d[i] * (gamma_lpdf(t[i]|shape,rate[i]) - gamma_lccdf(t[i]|shape,rate[i]))
+ gamma_lccdf(t[i]|shape,rate[i]);
}
prob = sum(log_lik);
return prob;
}
}
data {
int N; // number of observations
vector<lower=0>[N] y; // observed times
vector<lower=0,upper=1>[N] event; // censoring indicator (1=observed, 0=censored)
int M; // number of covariates
matrix[N, M] x; // matrix of covariates (with N rows and M columns)
}
parameters {
vector[M] beta; // Coefficients in the linear predictor (including intercept)
real<lower=0> alpha; // shape parameter
}
transformed parameters {
vector<lower=0>[N] mu;
{
vector[N] linpred;
linpred = x*beta;
mu = exp(linpred);
}
}
model {
alpha ~ gamma(0.01,0.01);
beta ~ normal(0,5);
y ~ surv_gamma(event, alpha, mu);
}
'''
testfit2 = survivalstan.fit_stan_survival_model(
model_cohort = 'model 2',
model_code = model_code2,
df = d,
time_col = 't',
event_col = 'event',
formula = '~ age_centered + sex',
iter = 5000,
chains = 4,
seed = 9001,
FIT_FUN = stancache.cached_stan_fit,
drop_intercept = False,
)
INFO:stancache.stancache:Step 1: Get compiled model code, possibly from cache INFO:stancache.stancache:StanModel: cache_filename set to anon_model.cython_0_25_1.model_code_9177012762674257483.pystan_2_12_0_0.stanmodel.pkl INFO:stancache.stancache:StanModel: Loading result from cache INFO:stancache.stancache:Step 2: Get posterior draws from model, possibly from cache INFO:stancache.stancache:sampling: cache_filename set to anon_model.cython_0_25_1.model_code_9177012762674257483.pystan_2_12_0_0.stanfit.chains_4.data_25476010973.iter_5000.seed_9001.pkl INFO:stancache.stancache:sampling: Loading result from cache
# 0:00:21.081723 elapsed
survivalstan.utils.print_stan_summary([testfit2], pars=['lp__', 'alpha', 'beta'])
| | mean | se_mean | sd | 2.5% | 50% | 97.5% | Rhat |
|---|---|---|---|---|---|---|---|
| lp__ | -278.122019 | 0.022342 | 1.367441 | -281.523430 | -277.808097 | -276.440635 | 1.001051 |
| alpha | 1.218976 | 0.002590 | 0.171529 | 0.906313 | 1.208893 | 1.586105 | 1.000714 |
| beta[0] | -2.704073 | 0.003788 | 0.228848 | -3.187193 | -2.693108 | -2.287256 | 1.000886 |
| beta[1] | 0.604867 | 0.002814 | 0.201056 | 0.208092 | 0.605508 | 0.993031 | 1.000143 |
| beta[2] | 0.006629 | 0.000188 | 0.014733 | -0.021872 | 0.006506 | 0.036200 | 1.000063 |
Using log_mix inside surv_gamma_lpdf
model_code3 = '''
functions {
// Defines the sampling distribution (log likelihood)
real surv_gamma_lpdf (vector t, vector d, real shape, vector rate) {
vector[num_elements(t)] log_lik;
real prob;
for (i in 1:num_elements(t)) {
log_lik[i] = log_mix(d[i], gamma_lpdf(t[i]|shape,rate[i]), gamma_lccdf(t[i]|shape,rate[i]));
}
prob = sum(log_lik);
return prob;
}
}
data {
int N; // number of observations
vector<lower=0>[N] y; // observed times
vector<lower=0,upper=1>[N] event; // censoring indicator (1=observed, 0=censored)
int M; // number of covariates
matrix[N, M] x; // matrix of covariates (with N rows and M columns)
}
parameters {
vector[M] beta; // Coefficients in the linear predictor (including intercept)
real<lower=0> alpha; // shape parameter
}
transformed parameters {
vector[N] linpred;
vector[N] mu;
linpred = x*beta;
mu = exp(linpred);
}
model {
alpha ~ gamma(0.01,0.01);
beta ~ normal(0,5);
y ~ surv_gamma(event, alpha, mu);
}
'''
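Stan's `log_mix(theta, lp1, lp2)` computes log(theta * exp(lp1) + (1 - theta) * exp(lp2)). Because the event indicator is 0 or 1 here, it simply selects `gamma_lpdf` for events and `gamma_lccdf` for censored observations, which is why model 3 gives the same likelihood as models 1 and 2. A NumPy version (a hypothetical helper, for illustration only) makes the selection explicit:

```python
import numpy as np

def log_mix(theta, lp1, lp2):
    """NumPy version of Stan's log_mix: log(theta*exp(lp1) + (1-theta)*exp(lp2))."""
    # log(0) = -inf is intended when theta is exactly 0 or 1
    with np.errstate(divide='ignore'):
        return np.logaddexp(np.log(theta) + lp1, np.log1p(-theta) + lp2)

lp_event, lp_cens = -2.3, -0.7  # stand-ins for gamma_lpdf and gamma_lccdf values
picked_if_event = log_mix(1.0, lp_event, lp_cens)     # selects lp_event
picked_if_censored = log_mix(0.0, lp_event, lp_cens)  # selects lp_cens
halfway = log_mix(0.5, lp_event, lp_cens)             # a genuine 50/50 mixture
```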
testfit3 = survivalstan.fit_stan_survival_model(
model_cohort = 'model 3',
model_code = model_code3,
df = d,
time_col = 't',
event_col = 'event',
formula = '~ age_centered + sex',
iter = 5000,
chains = 4,
seed = 9001,
FIT_FUN = stancache.cached_stan_fit,
drop_intercept = False,
)
INFO:stancache.stancache:Step 1: Get compiled model code, possibly from cache INFO:stancache.stancache:StanModel: cache_filename set to anon_model.cython_0_25_1.model_code_1293841621968646714.pystan_2_12_0_0.stanmodel.pkl INFO:stancache.stancache:StanModel: Loading result from cache INFO:stancache.stancache:Step 2: Get posterior draws from model, possibly from cache INFO:stancache.stancache:sampling: cache_filename set to anon_model.cython_0_25_1.model_code_1293841621968646714.pystan_2_12_0_0.stanfit.chains_4.data_25476010973.iter_5000.seed_9001.pkl INFO:stancache.stancache:sampling: Loading result from cache
# 0:00:20.284146 elapsed
survivalstan.utils.print_stan_summary([testfit3], pars=['lp__', 'alpha', 'beta'])
| | mean | se_mean | sd | 2.5% | 50% | 97.5% | Rhat |
|---|---|---|---|---|---|---|---|
| lp__ | -278.092768 | 0.021055 | 1.344719 | -281.491359 | -277.783232 | -276.425885 | 1.000812 |
| alpha | 1.216429 | 0.002640 | 0.166944 | 0.905732 | 1.207800 | 1.576275 | 1.000432 |
| beta[0] | -2.708385 | 0.003770 | 0.220966 | -3.173480 | -2.697460 | -2.307317 | 1.000622 |
| beta[1] | 0.609834 | 0.002799 | 0.199563 | 0.223234 | 0.609396 | 1.008451 | 1.000295 |
| beta[2] | 0.006028 | 0.000180 | 0.014812 | -0.022443 | 0.006036 | 0.036197 | 1.000265 |
model_code4 = '''
functions {
int count_value(vector a, real val) {
int s;
s = 0;
for (i in 1:num_elements(a))
if (a[i] == val)
s = s + 1;
return s;
}
// Defines the sampling distribution (log likelihood)
real surv_gamma_lpdf (vector t, vector d, real shape, vector rate, int num_cens, int num_obs) {
vector[2] log_lik;
int idx_obs[num_obs];
int idx_cens[num_cens];
real prob;
int i_cens;
int i_obs;
i_cens = 1;
i_obs = 1;
for (i in 1:num_elements(t)) {
if (d[i] == 1) {
idx_obs[i_obs] = i;
i_obs = i_obs+1;
}
else {
idx_cens[i_cens] = i;
i_cens = i_cens+1;
}
}
log_lik[1] = gamma_lpdf(t[idx_obs] | shape, rate[idx_obs]);
log_lik[2] = gamma_lccdf(t[idx_cens] | shape, rate[idx_cens]);
prob = sum(log_lik);
return prob;
}
}
data {
int N; // number of observations
vector<lower=0>[N] y; // observed times
vector<lower=0,upper=1>[N] event; // censoring indicator (1=observed, 0=censored)
int M; // number of covariates
matrix[N, M] x; // matrix of covariates (with N rows and M columns)
}
transformed data {
int num_cens;
int num_obs;
num_obs = count_value(event, 1);
num_cens = N - num_obs;
}
parameters {
vector[M] beta; // Coefficients in the linear predictor (including intercept)
real<lower=0> alpha; // shape parameter
}
transformed parameters {
vector[N] linpred;
vector[N] mu;
linpred = x*beta;
mu = exp(linpred);
}
model {
alpha ~ gamma(0.01,0.01);
beta ~ normal(0,5);
y ~ surv_gamma(event, alpha, mu, num_cens, num_obs);
}
'''
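Model 4 first splits the row indices into observed and censored groups and then makes one vectorized `gamma_lpdf` call and one vectorized `gamma_lccdf` call, instead of evaluating each row separately. The total is the same sum as the per-row form; only the bookkeeping differs. The sketch below checks that equivalence with SciPy on random toy data; the helper names are hypothetical, and Stan's gamma(shape, rate) is again mapped to `scipy.stats.gamma(a=shape, scale=1/rate)`.

```python
import numpy as np
from scipy import stats

def loglik_per_row(t, d, shape, rate):
    """Per-observation form, as in model_code2 / model_code3."""
    log_f = stats.gamma.logpdf(t, a=shape, scale=1.0 / rate)
    log_S = stats.gamma.logsf(t, a=shape, scale=1.0 / rate)
    return np.sum(d * log_f + (1 - d) * log_S)

def loglik_split(t, d, shape, rate):
    """Index-splitting form, as in model_code4: one vectorized call per group."""
    obs = d == 1
    return (np.sum(stats.gamma.logpdf(t[obs], a=shape, scale=1.0 / rate[obs]))
            + np.sum(stats.gamma.logsf(t[~obs], a=shape, scale=1.0 / rate[~obs])))

rng = np.random.RandomState(0)
t_toy = rng.exponential(scale=10.0, size=50)
d_toy = (rng.uniform(size=50) < 0.6).astype(float)  # roughly 60% observed events
rate_toy = np.exp(rng.normal(-2.5, 0.3, size=50))
per_row = loglik_per_row(t_toy, d_toy, 1.2, rate_toy)
split = loglik_split(t_toy, d_toy, 1.2, rate_toy)
```

The elapsed times noted in this notebook are shorter for model 4, which is consistent with vectorization helping, though timings from single runs are only indicative.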
testfit4 = survivalstan.fit_stan_survival_model(
model_cohort = 'model 4',
model_code = model_code4,
df = d,
time_col = 't',
event_col = 'event',
formula = '~ age_centered + sex',
iter = 5000,
chains = 4,
seed = 9001,
FIT_FUN = stancache.cached_stan_fit,
drop_intercept = False,
)
INFO:stancache.stancache:Step 1: Get compiled model code, possibly from cache INFO:stancache.stancache:StanModel: cache_filename set to anon_model.cython_0_25_1.model_code_16881928540873162731.pystan_2_12_0_0.stanmodel.pkl INFO:stancache.stancache:StanModel: Starting execution INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_b18a6495e568fcff90662e16a3d2aa85 NOW. INFO:stancache.stancache:StanModel: Execution completed (0:01:09.439292 elapsed) INFO:stancache.stancache:StanModel: Saving results to cache INFO:stancache.stancache:Step 2: Get posterior draws from model, possibly from cache INFO:stancache.stancache:sampling: cache_filename set to anon_model.cython_0_25_1.model_code_16881928540873162731.pystan_2_12_0_0.stanfit.chains_4.data_25476010973.iter_5000.seed_9001.pkl INFO:stancache.stancache:sampling: Starting execution INFO:stancache.stancache:sampling: Execution completed (0:00:06.245552 elapsed) INFO:stancache.stancache:sampling: Saving results to cache /home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/stancache/stancache.py:284: UserWarning: Pickling fit objects is an experimental feature! The relevant StanModel instance must be pickled along with this fit object. When unpickling the StanModel must be unpickled first. pickle.dump(res, open(cache_filepath, 'wb'), pickle.HIGHEST_PROTOCOL)
# 0:00:06.245552 elapsed
survivalstan.utils.print_stan_summary([testfit4], pars=['lp__', 'alpha', 'beta'])
| | mean | se_mean | sd | 2.5% | 50% | 97.5% | Rhat |
|---|---|---|---|---|---|---|---|
| lp__ | -278.097416 | 0.020380 | 1.353386 | -281.543138 | -277.782205 | -276.445123 | 0.999867 |
| alpha | 1.216149 | 0.002525 | 0.169578 | 0.913312 | 1.206908 | 1.574024 | 1.000702 |
| beta[0] | -2.708518 | 0.003627 | 0.228658 | -3.188661 | -2.700493 | -2.278788 | 1.000723 |
| beta[1] | 0.610593 | 0.002824 | 0.201974 | 0.224438 | 0.607700 | 1.030164 | 1.000744 |
| beta[2] | 0.006211 | 0.000171 | 0.014493 | -0.022229 | 0.006113 | 0.035095 | 1.000222 |
survivalstan.utils.plot_coefs([testfit, testfit2, testfit3, testfit4])
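The four formulations should give the same posterior up to Monte Carlo error. Besides plotting the coefficients, this can be checked numerically from the printed summaries: the sketch below copies the posterior means reported above for each model and computes how much each parameter's mean varies across the four fits. The spreads come out below 0.01, well under the posterior standard deviations (about 0.015 to 0.23).

```python
import pandas as pd

# Posterior means copied from the four summaries printed above
posterior_means = pd.DataFrame(
    {'alpha':   [1.220129, 1.218976, 1.216429, 1.216149],
     'beta[0]': [-2.703902, -2.704073, -2.708385, -2.708518],
     'beta[1]': [0.608266, 0.604867, 0.609834, 0.610593],
     'beta[2]': [0.006018, 0.006629, 0.006028, 0.006211]},
    index=['model 1', 'model 2', 'model 3', 'model 4'])

# Spread (max - min) of each posterior mean across the four formulations
spread = posterior_means.max() - posterior_means.min()
```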