# Notebook setup: enable module autoreloading and inline matplotlib output,
# seed Python's `random` module, and import the libraries used below.
%load_ext autoreload
%autoreload 2
%matplotlib inline
import random
# Fixed seed so that draws from Python's `random` module are reproducible.
# NOTE(review): stancache logs setting its own seed when this cell runs
# (see the INFO line in the output); whether that value is derived from this
# seed is not visible here.
random.seed(13847942484)
import survivalstan
import numpy as np
import pandas as pd
from stancache import stancache
from matplotlib import pyplot as plt
INFO:stancache.seed:Setting seed to 1245502385
# Print the Stan source of survivalstan's time-varying-coefficient model
# (a per-timepoint baseline hazard with a Poisson likelihood on the event
# indicator); this is the `model_code` passed to the fit further down.
print(survivalstan.models.pem_survival_model_timevarying)
/* Variable naming: // dimensions N = total number of observations (length of data) S = number of sample ids T = max timepoint (number of timepoint ids) M = number of covariates // data s = sample id for each obs t = timepoint id for each obs event = integer indicating if there was an event at time t for sample s x = matrix of real-valued covariates at time t for sample n [N, X] obs_t = observed end time for interval for timepoint for that obs */ // Jacqueline Buros Novik <jackinovik@gmail.com> functions { matrix spline(vector x, int N, int H, vector xi, int P) { matrix[N, H + P] b_x; // expanded predictors for (n in 1:N) { for (p in 1:P) { b_x[n,p] <- pow(x[n],p-1); // x[n]^(p-1) } for (h in 1:H) b_x[n, h + P] <- fmax(0, pow(x[n] - xi[h],P-1)); } return b_x; } } data { // dimensions int<lower=1> N; int<lower=1> S; int<lower=1> T; int<lower=0> M; // data matrix int<lower=1, upper=N> s[N]; // sample id int<lower=1, upper=T> t[N]; // timepoint id int<lower=0, upper=1> event[N]; // 1: event, 0:censor matrix[N, M] x; // explanatory vars // timepoint data vector<lower=0>[T] t_obs; vector<lower=0>[T] t_dur; } transformed data { vector[T] log_t_dur; int n_trans[S, T]; log_t_dur = log(t_dur); // n_trans used to map each sample*timepoint to n (used in gen quantities) // map each patient/timepoint combination to n values for (n in 1:N) { n_trans[s[n], t[n]] = n; } // fill in missing values with n for max t for that patient // ie assume "last observed" state applies forward (may be problematic for TVC) // this allows us to predict failure times >= observed survival times for (samp in 1:S) { int last_value; last_value = 0; for (tp in 1:T) { // manual says ints are initialized to neg values // so <=0 is a shorthand for "unassigned" if (n_trans[samp, tp] <= 0 && last_value != 0) { n_trans[samp, tp] = last_value; } else { last_value = n_trans[samp, tp]; } } } } parameters { vector[T] log_baseline_raw; // unstructured baseline hazard for each timepoint t real<lower=0> 
baseline_sigma; real log_baseline_mu; vector[M] beta; // beta-intercept vector<lower=0>[M] beta_time_sigma; vector[T-1] raw_beta_time_deltas[M]; // for each coefficient // change in coefficient value from previous time } transformed parameters { vector[N] log_hazard; vector[T] log_baseline; vector[T] beta_time[M]; vector[T] beta_time_deltas[M]; // adjust baseline hazard for duration of each period log_baseline = log_baseline_raw + log_t_dur; // compute timepoint-specific betas // offsets from previous time for (coef in 1:M) { beta_time_deltas[coef][1] = 0; for (time in 2:T) { beta_time_deltas[coef][time] = raw_beta_time_deltas[coef][time-1]; } } // coefficients for each timepoint T for (coef in 1:M) { beta_time[coef] = beta[coef] + cumulative_sum(beta_time_deltas[coef]); } // compute log-hazard for each obs for (n in 1:N) { real log_linpred; log_linpred <- 0; for (coef in 1:M) { // for now, handle each coef separately // (to be sure we pull out the "right" beta..) log_linpred = log_linpred + x[n, coef] * beta_time[coef][t[n]]; } log_hazard[n] = log_baseline_mu + log_baseline[t[n]] + log_linpred; } } model { // priors on time-varying coefficients for (m in 1:M) { raw_beta_time_deltas[m][1] ~ normal(0, 100); for(i in 2:(T-1)){ raw_beta_time_deltas[m][i] ~ normal(raw_beta_time_deltas[m][i-1], beta_time_sigma[m]); } } beta_time_sigma ~ cauchy(0, 1); beta ~ cauchy(0, 1); // priors on baseline hazard log_baseline_mu ~ normal(0, 1); baseline_sigma ~ normal(0, 1); log_baseline_raw[1] ~ normal(0, 1); for (i in 2:T) { log_baseline_raw[i] ~ normal(log_baseline_raw[i-1], baseline_sigma); } // model event ~ poisson_log(log_hazard); } generated quantities { real log_lik[N]; vector[T] baseline; int y_hat_mat[S, T]; // ppcheck for each S*T combination real y_hat_time[S]; // predicted failure time for each sample int y_hat_event[S]; // predicted event (0:censor, 1:event) // compute raw baseline hazard, for summary/plotting baseline = exp(log_baseline_raw); // log_likelihood for 
loo-psis for (n in 1:N) { log_lik[n] <- poisson_log_lpmf(event[n] | log_hazard[n]); } // posterior predicted values for (samp in 1:S) { int sample_alive; sample_alive = 1; for (tp in 1:T) { if (sample_alive == 1) { int n; int pred_y; real log_linpred; real log_haz; // determine predicted value of y n = n_trans[samp, tp]; // (borrow code from above to calc linpred) // but use sim tp not t[n] log_linpred = 0; for (coef in 1:M) { // for now, handle each coef separately // (to be sure we pull out the "right" beta..) log_linpred = log_linpred + x[n, coef] * beta_time[coef][tp]; } log_haz = log_baseline_mu + log_baseline[tp] + log_linpred; // now, make posterior prediction if (log_haz < log(pow(2, 30))) pred_y = poisson_log_rng(log_haz); else pred_y = 9; // mark this patient as ineligible for future tps // note: deliberately make 9s ineligible if (pred_y >= 1) { sample_alive = 0; y_hat_time[samp] = t_obs[tp]; y_hat_event[samp] = 1; } // save predicted value of y to matrix y_hat_mat[samp, tp] = pred_y; } else if (sample_alive == 0) { y_hat_mat[samp, tp] = 9; } } // end per-timepoint loop // if patient still alive at max // if (sample_alive == 1) { y_hat_time[samp] = t_obs[T]; y_hat_event[samp] = 0; } } // end per-sample loop }
# Simulate a survival dataset of N=100 subjects, or load it from the
# stancache cache if it was generated before. The event rate depends on sex
# via rate_form='1 + sex' with coefficients [-3, 0.5]; the `rate` column in
# the output matches exp(-3 + 0.5 * male).
# NOTE(review): censor_time=20 presumably censors follow-up at t=20 — confirm.
d = stancache.cached(
survivalstan.sim.sim_data_exp_correlated,
N=100,
censor_time=20,
rate_form='1 + sex',
rate_coefs=[-3, 0.5],
)
# Mean-centre age so it can be used as a covariate on a centred scale.
d['age_centered'] = d['age'] - d['age'].mean()
d.head()
INFO:stancache.stancache:sim_data_exp_correlated: cache_filename set to sim_data_exp_correlated.cached.N_100.censor_time_20.rate_coefs_54462717316.rate_form_1 + sex.pkl INFO:stancache.stancache:sim_data_exp_correlated: Loading result from cache
sex | age | rate | true_t | t | event | index | age_centered | |
---|---|---|---|---|---|---|---|---|
0 | male | 54 | 0.082085 | 1.013855 | 1.013855 | True | 0 | -1.12 |
1 | male | 39 | 0.082085 | 4.890597 | 4.890597 | True | 1 | -16.12 |
2 | female | 45 | 0.049787 | 4.093404 | 4.093404 | True | 2 | -10.12 |
3 | female | 43 | 0.049787 | 7.036226 | 7.036226 | True | 3 | -12.12 |
4 | female | 57 | 0.049787 | 5.712299 | 5.712299 | True | 4 | 1.88 |
# Observed survival curves for each sex, drawn on the current axes and
# labelled so that plt.legend() can distinguish them.
survivalstan.utils.plot_observed_survival(df=d[d['sex']=='female'], event_col='event', time_col='t', label='female')
survivalstan.utils.plot_observed_survival(df=d[d['sex']=='male'], event_col='event', time_col='t', label='male')
plt.legend()
<matplotlib.legend.Legend at 0x7f9d6c4069e8>
# Convert `d` to the long format the model expects (cached by stancache).
# The output shows several rows per subject, each with an interval `end_time`
# and an `end_failure` flag — presumably one row per subject per distinct
# observed time up to that subject's own time `t`; confirm in survivalstan docs.
dlong = stancache.cached(
survivalstan.prep_data_long_surv,
df=d, event_col='event', time_col='t'
)
# Order rows by subject, then chronologically by interval end time.
dlong.sort_values(['index', 'end_time'], inplace=True)
INFO:stancache.stancache:prep_data_long_surv: cache_filename set to prep_data_long_surv.cached.df_14209590808.event_col_event.time_col_t.pkl INFO:stancache.stancache:prep_data_long_surv: Loading result from cache
# Preview the long-format data. For subject 0, `end_failure` is True only on
# the row whose `end_time` equals the subject's observed time `t`.
dlong.head()
sex | age | rate | true_t | t | event | index | age_centered | end_time | end_failure | |
---|---|---|---|---|---|---|---|---|---|---|
73 | male | 54 | 0.082085 | 1.013855 | 1.013855 | True | 0 | -1.12 | 0.009787 | False |
65 | male | 54 | 0.082085 | 1.013855 | 1.013855 | True | 0 | -1.12 | 0.377535 | False |
72 | male | 54 | 0.082085 | 1.013855 | 1.013855 | True | 0 | -1.12 | 0.791192 | False |
58 | male | 54 | 0.082085 | 1.013855 | 1.013855 | True | 0 | -1.12 | 0.808987 | False |
0 | male | 54 | 0.082085 | 1.013855 | 1.013855 | True | 0 | -1.12 | 1.013855 | True |
# Fit the time-varying-coefficient model to the long-format data.
#   sample_col        – column identifying the subject each row belongs to
#   timepoint_end_col – interval end time for each row
#   event_col         – whether the event occurred at the end of that interval
#   formula           – covariates entering the linear predictor
# Sampling uses 4 chains x 10000 iterations with a fixed seed. FIT_FUN routes
# compilation and sampling through stancache, so both the compiled model and
# the posterior draws are loaded from cache when available (see log output).
testfit = survivalstan.fit_stan_survival_model(
model_cohort = 'test model',
model_code = survivalstan.models.pem_survival_model_timevarying,
df = dlong,
sample_col = 'index',
timepoint_end_col = 'end_time',
event_col = 'end_failure',
formula = '~ age_centered + sex',
iter = 10000,
chains = 4,
seed = 9001,
FIT_FUN = stancache.cached_stan_fit,
)
INFO:stancache.stancache:Step 1: Get compiled model code, possibly from cache INFO:stancache.stancache:StanModel: cache_filename set to anon_model.cython_0_29_2.model_code_10216236489136838232.pystan_2_18_1_0.stanmodel.pkl INFO:stancache.stancache:StanModel: Loading result from cache INFO:stancache.stancache:Step 2: Get posterior draws from model, possibly from cache INFO:stancache.stancache:sampling: cache_filename set to anon_model.cython_0_29_2.model_code_10216236489136838232.pystan_2_18_1_0.stanfit.chains_4.data_98562805320.iter_10000.seed_9001.pkl INFO:stancache.stancache:sampling: Loading result from cache
# Summarise the log posterior density (lp__) across chains.
# NOTE(review): the reported Rhat for lp__ is ~1.25, well above 1, which
# suggests the chains have not converged; treat the estimates plotted below
# with caution (longer runs or reparameterisation may be needed).
survivalstan.utils.print_stan_summary([testfit], pars='lp__')
mean se_mean sd 2.5% 50% 97.5% Rhat lp__ 438.367242 33.037071 136.999711 60.442022 454.291446 665.11812 1.245577
# Posterior summary of `log_baseline`, the per-timepoint log baseline hazard
# (in the model: log_baseline_raw + log(interval duration)).
survivalstan.utils.plot_stan_summary([testfit], pars='log_baseline')
# Posterior estimates of `baseline`, computed in the model as exp(log_baseline_raw).
survivalstan.utils.plot_coefs([testfit], element='baseline')
# Coefficient estimates using plot_coefs' default element
# (presumably the `beta` intercepts — confirm).
survivalstan.utils.plot_coefs([testfit])
# Posterior-predictive check: overlay posterior-predicted survival curves on
# the observed survival of the full dataset (drawn in green).
survivalstan.utils.plot_pp_survival([testfit], fill=False)
survivalstan.utils.plot_observed_survival(df=d, event_col='event', time_col='t', color='green', label='observed')
plt.legend()
<matplotlib.legend.Legend at 0x7f9c947dcb00>
# Posterior-predicted survival stratified by sex: first with default colours,
# then with an explicit palette (red, blue).
survivalstan.utils.plot_pp_survival([testfit], by='sex')
survivalstan.utils.plot_pp_survival([testfit], by='sex', pal=['red', 'blue'])
The standard approach is to plot the estimated betas at each timepoint for each coefficient in the model.
# Per-timepoint estimates of the time-varying coefficients (`beta_time`),
# with the y-axis limited to [-1, 2.5].
survivalstan.utils.plot_coefs([testfit], element='beta_time', ylim=[-1, 2.5])
# Same coefficients on the exp(beta) scale — a multiplicative effect on the
# hazard, since betas enter log_hazard linearly — grouped by coefficient.
survivalstan.utils.plot_time_betas(models=[testfit], by=['coef'], y='exp(beta)', ylim=[0, 10])
Alternatively, you can extract the beta estimates for each timepoint and plot them yourself.
# Extract posterior draws of the time-varying betas for all coefficients into
# a DataFrame and keep it under testfit['time_beta']. The output shows columns
# iter, _timepoint_id, beta, coef, end_time, exp(beta) and model_cohort —
# presumably one row per draw x timepoint x coefficient.
testfit['time_beta'] = survivalstan.utils.extract_time_betas([testfit])
testfit['time_beta'].head()
iter | _timepoint_id | beta | coef | end_time | exp(beta) | model_cohort | |
---|---|---|---|---|---|---|---|
0 | 0 | 1 | 1.303891 | sex | 0.009787 | 3.683601 | test model |
1 | 1 | 1 | 0.349997 | sex | 0.009787 | 1.419063 | test model |
2 | 2 | 1 | 0.348042 | sex | 0.009787 | 1.416292 | test model |
3 | 3 | 1 | 0.200147 | sex | 0.009787 | 1.221583 | test model |
4 | 4 | 1 | 0.189725 | sex | 0.009787 | 1.208917 | test model |
You can also extract and/or plot data for a single coefficient of interest at a time.
# Extract the time-varying beta draws for a single coefficient of interest.
# The name passed in `coefs` must match the `coef` column of the extracted
# betas, which is 'sex' for this model (see testfit['time_beta'] above).
# Passing the patsy-style name 'sex[T.male]' matched no coefficient, so
# extract_time_betas had nothing to concatenate and raised
# "ValueError: No objects to concatenate".
first_beta = survivalstan.utils.extract_time_betas([testfit], coefs=['sex'])
first_beta.head()
--------------------------------------------------------------------------- ValueError Traceback (most recent call last) <ipython-input-18-9167c46e7967> in <module> ----> 1 first_beta = survivalstan.utils.extract_time_betas([testfit], coefs=['sex[T.male]']) 2 first_beta.head() ~/projects/survivalstan/survivalstan/utils.py in extract_time_betas(models, element, value_name, **kwargs) 90 element=element, 91 value_name=value_name, **kwargs) ---> 92 for model in models] 93 return pd.concat(data) 94 ~/projects/survivalstan/survivalstan/utils.py in <listcomp>(.0) 90 element=element, 91 value_name=value_name, **kwargs) ---> 92 for model in models] 93 return pd.concat(data) 94 ~/projects/survivalstan/survivalstan/utils.py in _extract_time_betas_single_model(stanmodel, element, coefs, value_name, timepoint_id_col, timepoint_end_col) 139 tb_df['coef'] = coef_names[i] 140 time_data.append(tb_df) --> 141 time_data = pd.concat(time_data) 142 timepoint_data = (stanmodel['df'] 143 .loc[:, [timepoint_id_col, timepoint_end_col]] /srv/conda/lib/python3.6/site-packages/pandas/core/reshape/concat.py in concat(objs, axis, join, join_axes, ignore_index, keys, levels, names, verify_integrity, sort, copy) 223 keys=keys, levels=levels, names=names, 224 verify_integrity=verify_integrity, --> 225 copy=copy, sort=sort) 226 return op.get_result() 227 /srv/conda/lib/python3.6/site-packages/pandas/core/reshape/concat.py in __init__(self, objs, axis, join, join_axes, keys, levels, names, ignore_index, verify_integrity, copy, sort) 257 258 if len(objs) == 0: --> 259 raise ValueError('No objects to concatenate') 260 261 if keys is None: ValueError: No objects to concatenate
import seaborn as sns

# Box plot of the posterior draws of the selected coefficient at each timepoint.
# extract_time_betas names its timepoint-id column '_timepoint_id' (see the
# column listing of testfit['time_beta'] above); 'timepoint_id' is not a
# column of `first_beta`.
sns.boxplot(data=first_beta, x='_timepoint_id', y='beta')
<matplotlib.axes._subplots.AxesSubplot at 0x7f6f0d2f3518>
# Plot the time-varying estimate of a single coefficient against interval end
# time, directly from the fitted model. The coefficient is named 'sex' in this
# model's extracted betas (the `coef` column of testfit['time_beta']);
# 'sex[T.male]' matches no coefficient.
survivalstan.utils.plot_time_betas(models=[testfit], y='beta', x='end_time', coefs=['sex'])
Note that the same plot can be produced by passing the extracted data directly to `plot_time_betas`.
# Equivalent plot built from the pre-extracted `first_beta` DataFrame rather
# than from the fitted model object.
survivalstan.utils.plot_time_betas(df=first_beta, by=['coef'], y='beta', x='end_time')