%load_ext autoreload
%autoreload 2
%matplotlib inline
import random
random.seed(13847942484)
import survivalstan
import numpy as np
import pandas as pd
from stancache import stancache
from matplotlib import pyplot as plt
The autoreload extension is already loaded. To reload it, use: %reload_ext autoreload
/home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/Cython/Distutils/old_build_ext.py:30: UserWarning: Cython.Distutils.old_build_ext does not properly handle dependencies and is deprecated. "Cython.Distutils.old_build_ext does not properly handle dependencies " /home/jacquelineburos/.local/lib/python3.5/site-packages/IPython/html.py:14: ShimWarning: The `IPython.html` package has been deprecated. You should import from `notebook` instead. `IPython.html.widgets` has moved to `ipywidgets`. "`IPython.html.widgets` has moved to `ipywidgets`.", ShimWarning) INFO:stancache.seed:Setting seed to 1245502385
# Print the Stan program that is fitted below: per the printed source it has a
# per-timepoint baseline hazard (log_baseline_raw), time-varying coefficients
# (beta_time) and a Poisson likelihood on the event indicator.
# NOTE(review): in the printed program, `transformed data` computes
# log_t_dur = log(t_obs) while the declared `t_dur` input is never referenced
# -- confirm whether log(t_dur) was intended.
print(survivalstan.models.pem_survival_model_timevarying)
/* Variable naming: // dimensions N = total number of observations (length of data) S = number of sample ids T = max timepoint (number of timepoint ids) M = number of covariates // data s = sample id for each obs t = timepoint id for each obs event = integer indicating if there was an event at time t for sample s x = matrix of real-valued covariates at time t for sample n [N, X] obs_t = observed end time for interval for timepoint for that obs */ // Jacqueline Buros Novik <jackinovik@gmail.com> functions { matrix spline(vector x, int N, int H, vector xi, int P) { matrix[N, H + P] b_x; // expanded predictors for (n in 1:N) { for (p in 1:P) { b_x[n,p] <- pow(x[n],p-1); // x[n]^(p-1) } for (h in 1:H) b_x[n, h + P] <- fmax(0, pow(x[n] - xi[h],P-1)); } return b_x; } } data { // dimensions int<lower=1> N; int<lower=1> S; int<lower=1> T; int<lower=0> M; // data matrix int<lower=1, upper=N> s[N]; // sample id int<lower=1, upper=T> t[N]; // timepoint id int<lower=0, upper=1> event[N]; // 1: event, 0:censor matrix[N, M] x; // explanatory vars // timepoint data vector<lower=0>[T] t_obs; vector<lower=0>[T] t_dur; } transformed data { vector[T] log_t_dur; int n_trans[S, T]; log_t_dur = log(t_obs); // n_trans used to map each sample*timepoint to n (used in gen quantities) // map each patient/timepoint combination to n values for (n in 1:N) { n_trans[s[n], t[n]] = n; } // fill in missing values with n for max t for that patient // ie assume "last observed" state applies forward (may be problematic for TVC) // this allows us to predict failure times >= observed survival times for (samp in 1:S) { int last_value; last_value = 0; for (tp in 1:T) { // manual says ints are initialized to neg values // so <=0 is a shorthand for "unassigned" if (n_trans[samp, tp] <= 0 && last_value != 0) { n_trans[samp, tp] = last_value; } else { last_value = n_trans[samp, tp]; } } } } parameters { vector[T] log_baseline_raw; // unstructured baseline hazard for each timepoint t real<lower=0> 
baseline_sigma; real log_baseline_mu; vector[M] beta; // beta-intercept vector<lower=0>[M] beta_time_sigma; vector[T-1] raw_beta_time_deltas[M]; // for each coefficient // change in coefficient value from previous time } transformed parameters { vector[N] log_hazard; vector[T] log_baseline; vector[T] beta_time[M]; vector[T] beta_time_deltas[M]; // adjust baseline hazard for duration of each period log_baseline = log_baseline_raw + log_t_dur; // compute timepoint-specific betas // offsets from previous time for (coef in 1:M) { beta_time_deltas[coef][1] = 0; for (time in 2:T) { beta_time_deltas[coef][time] = raw_beta_time_deltas[coef][time-1]; } } // coefficients for each timepoint T for (coef in 1:M) { beta_time[coef] = beta[coef] + cumulative_sum(beta_time_deltas[coef]); } // compute log-hazard for each obs for (n in 1:N) { real log_linpred; log_linpred <- 0; for (coef in 1:M) { // for now, handle each coef separately // (to be sure we pull out the "right" beta..) log_linpred = log_linpred + x[n, coef] * beta_time[coef][t[n]]; } log_hazard[n] = log_baseline_mu + log_baseline[t[n]] + log_linpred; } } model { // priors on time-varying coefficients for (m in 1:M) { raw_beta_time_deltas[m][1] ~ normal(0, 100); for(i in 2:(T-1)){ raw_beta_time_deltas[m][i] ~ normal(raw_beta_time_deltas[m][i-1], beta_time_sigma[m]); } } beta_time_sigma ~ cauchy(0, 1); beta ~ cauchy(0, 1); // priors on baseline hazard log_baseline_mu ~ normal(0, 1); baseline_sigma ~ normal(0, 1); log_baseline_raw[1] ~ normal(0, 1); for (i in 2:T) { log_baseline_raw[i] ~ normal(log_baseline_raw[i-1], baseline_sigma); } // model event ~ poisson_log(log_hazard); } generated quantities { real log_lik[N]; vector[T] baseline; int y_hat_mat[S, T]; // ppcheck for each S*T combination real y_hat_time[S]; // predicted failure time for each sample int y_hat_event[S]; // predicted event (0:censor, 1:event) // compute raw baseline hazard, for summary/plotting baseline = exp(log_baseline_raw); // log_likelihood for 
loo-psis for (n in 1:N) { log_lik[n] <- poisson_log_lpmf(event[n] | log_hazard[n]); } // posterior predicted values for (samp in 1:S) { int sample_alive; sample_alive = 1; for (tp in 1:T) { if (sample_alive == 1) { int n; int pred_y; real log_linpred; real log_haz; // determine predicted value of y n = n_trans[samp, tp]; // (borrow code from above to calc linpred) // but use sim tp not t[n] log_linpred = 0; for (coef in 1:M) { // for now, handle each coef separately // (to be sure we pull out the "right" beta..) log_linpred = log_linpred + x[n, coef] * beta_time[coef][tp]; } log_haz = log_baseline_mu + log_baseline[tp] + log_linpred; // now, make posterior prediction if (log_haz < log(pow(2, 30))) pred_y = poisson_log_rng(log_haz); else pred_y = 9; // mark this patient as ineligible for future tps // note: deliberately make 9s ineligible if (pred_y >= 1) { sample_alive = 0; y_hat_time[samp] = t_obs[tp]; y_hat_event[samp] = 1; } // save predicted value of y to matrix y_hat_mat[samp, tp] = pred_y; } else if (sample_alive == 0) { y_hat_mat[samp, tp] = 9; } } // end per-timepoint loop // if patient still alive at max // if (sample_alive == 1) { y_hat_time[samp] = t_obs[T]; y_hat_event[samp] = 0; } } // end per-sample loop }
# Simulate the survival dataset, or load it from the stancache cache if present
# (the cache file name in the log is derived from these keyword arguments, so
# changing any of them triggers a fresh simulation).
d = stancache.cached(
survivalstan.sim.sim_data_exp_correlated,
# number of simulated subjects -- presumably one row per subject; confirm
N=100,
# in the recorded output, rows with true_t > 20 have t == 20 and event == False
censor_time=20,
# in the recorded output `rate` is exp(-3) ~= 0.0498 for female rows and
# exp(-3 + 0.5) ~= 0.0821 for male rows
rate_form='1 + sex',
rate_coefs=[-3, 0.5],
)
# Centre age on the sample mean; 'age_centered' is used in the model formula below.
d['age_centered'] = d['age'] - d['age'].mean()
# Preview the first rows.
d.head()
INFO:stancache.stancache:sim_data_exp_correlated: cache_filename set to sim_data_exp_correlated.cached.N_100.censor_time_20.rate_coefs_54462717316.rate_form_1 + sex.pkl INFO:stancache.stancache:sim_data_exp_correlated: Loading result from cache
| | age | sex | rate | true_t | t | event | index | age_centered |
---|---|---|---|---|---|---|---|---|
0 | 59 | male | 0.082085 | 20.948771 | 20.000000 | False | 0 | 4.18 |
1 | 58 | male | 0.082085 | 12.827519 | 12.827519 | True | 1 | 3.18 |
2 | 61 | female | 0.049787 | 27.018886 | 20.000000 | False | 2 | 6.18 |
3 | 57 | female | 0.049787 | 62.220296 | 20.000000 | False | 3 | 2.18 |
4 | 55 | male | 0.082085 | 10.462045 | 10.462045 | True | 4 | 0.18 |
# Overlay the observed survival curve of each sex on the same axes,
# labelling each curve with the group name, then draw the legend.
for sex_label in ('female', 'male'):
    survivalstan.utils.plot_observed_survival(
        df=d[d['sex'] == sex_label],
        event_col='event',
        time_col='t',
        label=sex_label,
    )
plt.legend()
<matplotlib.legend.Legend at 0x7fb845b45c88>
# Reshape the per-subject frame into "long" format for the model fit below
# (cached via stancache; the cache file name in the log is derived from these
# arguments). In the recorded head() output each subject's row is repeated
# with different `end_time` values plus an `end_failure` flag -- presumably
# one row per subject per observed timepoint; verify against
# prep_data_long_surv.
dlong = stancache.cached(
survivalstan.prep_data_long_surv,
df=d, event_col='event', time_col='t'
)
INFO:stancache.stancache:prep_data_long_surv: cache_filename set to prep_data_long_surv.cached.df_33772694934.event_col_event.time_col_t.pkl INFO:stancache.stancache:prep_data_long_surv: Loading result from cache
# Preview the long-format frame (adds key, end_time and end_failure columns).
dlong.head()
| | age | sex | rate | true_t | t | event | index | age_centered | key | end_time | end_failure |
---|---|---|---|---|---|---|---|---|---|---|---|
0 | 59 | male | 0.082085 | 20.948771 | 20.0 | False | 0 | 4.18 | 1 | 20.000000 | False |
1 | 59 | male | 0.082085 | 20.948771 | 20.0 | False | 0 | 4.18 | 1 | 12.827519 | False |
2 | 59 | male | 0.082085 | 20.948771 | 20.0 | False | 0 | 4.18 | 1 | 10.462045 | False |
3 | 59 | male | 0.082085 | 20.948771 | 20.0 | False | 0 | 4.18 | 1 | 0.196923 | False |
4 | 59 | male | 0.082085 | 20.948771 | 20.0 | False | 0 | 4.18 | 1 | 9.244121 | False |
# Fit the time-varying model to the long-format data.
# The fit is routed through stancache (FIT_FUN), so both the compiled model and
# the posterior draws are cached; the recorded uncached sampling run took
# about 1h54m, so avoid changing arguments that feed the cache key needlessly.
testfit = survivalstan.fit_stan_survival_model(
model_cohort = 'test model',
model_code = survivalstan.models.pem_survival_model_timevarying,
df = dlong,
# columns of dlong identifying the subject, the interval end and the event flag
sample_col = 'index',
timepoint_end_col = 'end_time',
event_col = 'end_failure',
# covariates: centred age and sex
formula = '~ age_centered + sex',
# sampler settings, passed through to the fit function (see the cache file name)
iter = 10000,
chains = 4,
seed = 9001,
FIT_FUN = stancache.cached_stan_fit,
)
INFO:stancache.stancache:Step 1: Get compiled model code, possibly from cache INFO:stancache.stancache:StanModel: cache_filename set to anon_model.cython_0_25_1.model_code_91766436784.pystan_2_12_0_0.stanmodel.pkl INFO:stancache.stancache:StanModel: Loading result from cache INFO:stancache.stancache:Step 2: Get posterior draws from model, possibly from cache INFO:stancache.stancache:sampling: cache_filename set to anon_model.cython_0_25_1.model_code_91766436784.pystan_2_12_0_0.stanfit.chains_4.data_84680791400.iter_10000.seed_9001.pkl INFO:stancache.stancache:sampling: Starting execution INFO:stancache.stancache:sampling: Execution completed (1:54:17.931309 elapsed) INFO:stancache.stancache:sampling: Saving results to cache /home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/stancache/stancache.py:251: UserWarning: Pickling fit objects is an experimental feature! The relevant StanModel instance must be pickled along with this fit object. When unpickling the StanModel must be unpickled first. pickle.dump(res, open(cache_filepath, 'wb'), pickle.HIGHEST_PROTOCOL) /home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/stanity/psis.py:228: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison elif sort == 'in-place': /home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/stanity/psis.py:246: VisibleDeprecationWarning: using a non-integer number instead of an integer will result in an error in the future bs /= 3 * x[sort[np.floor(n/4 + 0.5) - 1]] /home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/stanity/psis.py:262: RuntimeWarning: overflow encountered in exp np.exp(temp, out=temp)
# Summarise the log-posterior (lp__) as a quick convergence check.
# NOTE(review): the recorded output shows Rhat ~= 1.07 for lp__; values
# noticeably above 1 usually suggest the chains have not fully mixed --
# confirm before relying on the plots below.
survivalstan.utils.print_stan_summary([testfit], pars='lp__')
mean se_mean sd 2.5% 50% 97.5% Rhat lp__ 397.041918 12.774893 92.121064 234.208559 395.077305 612.481689 1.073435
# Plot the summary for the log_baseline parameters.
survivalstan.utils.plot_stan_summary([testfit], pars='log_baseline')
# Plot the 'baseline' element (exp(log_baseline_raw) in the Stan program).
survivalstan.utils.plot_coefs([testfit], element='baseline')
# Plot the coefficients using plot_coefs' default element.
survivalstan.utils.plot_coefs([testfit])
# Posterior-predictive check: draw the posterior-predicted survival and overlay
# the observed survival curve (green) on the same axes for comparison.
survivalstan.utils.plot_pp_survival([testfit], fill=False)
survivalstan.utils.plot_observed_survival(df=d, event_col='event', time_col='t', color='green', label='observed')
plt.legend()
<matplotlib.legend.Legend at 0x7fb6a243d630>
# Posterior-predicted survival split by sex, first with the default colours...
survivalstan.utils.plot_pp_survival([testfit], by='sex')
# ...then again with an explicit palette (presumably one colour per group, in
# group order -- verify against plot_pp_survival).
survivalstan.utils.plot_pp_survival([testfit], by='sex', pal=['red', 'blue'])
# Extract the time-varying coefficient estimates (presumably the model's
# beta_time values) and store them on the fit result under 'time_beta'.
# NOTE(review): in the recorded run this call raised
# "ValueError: cannot convert float NaN to integer" inside
# survivalstan.utils.extract_time_betas (the regex extraction of interval end
# points produced NaN before the int cast), so testfit['time_beta'] was never
# assigned. This needs a fix or upgrade in the installed survivalstan version.
testfit['time_beta'] = survivalstan.utils.extract_time_betas(testfit)
--------------------------------------------------------------------------- ValueError Traceback (most recent call last) <ipython-input-15-78d52ac9eb84> in <module>() ----> 1 testfit['time_beta'] = survivalstan.utils.extract_time_betas(testfit) /home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/survivalstan/utils.py in extract_time_betas(stanmodel, bins, element, value_name, timepoint_id_col, timepoint_end_col) 57 precision=1, 58 ) ---> 59 rename_log_timepoints = np.exp(time_betas['alt_log_timepoints'].cat.categories.str.extract(', ([\d\.]+)\]', expand=False).astype(float)).astype(int) 60 rename_timepoints = time_betas['alt_timepoints'].cat.categories.str.extract(', ([\d\.]+)\]', expand=False).astype(float).astype(int) 61 time_betas['alt_timepoint_end'] = time_betas['alt_timepoints'].cat.rename_categories(rename_timepoints) /home/jacquelineburos/.local/lib/python3.5/site-packages/pandas/indexes/numeric.py in astype(self, dtype, copy) 219 elif is_integer_dtype(dtype): 220 if self.hasnans: --> 221 raise ValueError('cannot convert float NaN to integer') 222 values = self._values.astype(dtype, copy=copy) 223 elif is_object_dtype(dtype): ValueError: cannot convert float NaN to integer