#!/usr/bin/env python
# coding: utf-8
"""Fit survivalstan's random-walk PEM survival model to simulated data.

Exported from a Jupyter notebook; the ``# In[n]:`` markers delimit the
original cells.  The script:

1. simulates N=100 subjects with ``survivalstan.sim.sim_data_exp_correlated``
   (rate form ``'1 + sex'``, coefficients ``[-3, 0.5]``, censored at t=20),
   caching the result with ``stancache``;
2. plots observed survival by sex;
3. reshapes the data to long format with ``survivalstan.prep_data_long_surv``;
4. fits ``survivalstan.models.pem_survival_model_randomwalk`` (4 chains,
   5000 iterations, seed 9001) through ``stancache.cached_stan_fit``;
5. prints / plots posterior summaries, coefficients and posterior-predictive
   survival curves.

It runs both inside IPython/Jupyter (where the autoreload and inline-plotting
magics are applied) and as a plain script (where the magics are skipped and
figures are shown at the end).
"""

# In[1]:

# `get_ipython` is injected into builtins only by IPython/Jupyter; under plain
# `python` it does not exist, so detect it instead of failing with NameError.
try:
    _ipython = get_ipython()  # noqa: F821 - provided by IPython at runtime
except NameError:
    _ipython = None

if _ipython is not None:
    _ipython.run_line_magic('load_ext', 'autoreload')
    _ipython.run_line_magic('autoreload', '2')
    _ipython.run_line_magic('matplotlib', 'inline')

import random
random.seed(1100038344)
# NOTE(review): this seeds only Python's `random` module. If the simulation
# draws from NumPy's RNG it is not covered by this seed; confirm if exact
# reproducibility without the stancache cache matters.

import survivalstan
import numpy as np
import pandas as pd
from stancache import stancache
from matplotlib import pyplot as plt

# In[2]:

model_code = survivalstan.models.pem_survival_model_randomwalk

# In[3]:

print(model_code)

# In[4]:

# Simulated hazard depends on sex only (rate_form='1 + sex'); the result is
# cached by stancache, keyed on these arguments.
d = stancache.cached(
    survivalstan.sim.sim_data_exp_correlated,
    N=100,
    censor_time=20,
    rate_form='1 + sex',
    rate_coefs=[-3, 0.5],
)
# Center age so its coefficient is interpretable at the mean age.
d['age_centered'] = d['age'] - d['age'].mean()
d.head()  # displayed only when run as a notebook cell

# In[5]:

survivalstan.utils.plot_observed_survival(
    df=d[d['sex'] == 'female'],
    event_col='event',
    time_col='t',
    label='female',
)
survivalstan.utils.plot_observed_survival(
    df=d[d['sex'] == 'male'],
    event_col='event',
    time_col='t',
    label='male',
)
plt.legend()

# In[6]:

# Reshape to long format (one row per subject per time interval), as required
# by the piecewise model below; cached like the simulation.
dlong = stancache.cached(
    survivalstan.prep_data_long_surv,
    df=d,
    event_col='event',
    time_col='t',
)

# In[7]:

dlong.head()  # displayed only when run as a notebook cell

# In[8]:

# `age_centered` is included in the formula even though the simulated rate
# does not depend on age, so its estimated coefficient should sit near zero.
testfit = survivalstan.fit_stan_survival_model(
    model_cohort='test model',
    model_code=model_code,
    df=dlong,
    sample_col='index',
    timepoint_end_col='end_time',
    event_col='end_failure',
    formula='~ age_centered + sex',
    iter=5000,
    chains=4,
    seed=9001,
    FIT_FUN=stancache.cached_stan_fit,
)

# In[9]:

survivalstan.utils.print_stan_summary([testfit], pars='lp__')

# In[10]:

survivalstan.utils.print_stan_summary([testfit], pars='log_baseline_raw')

# In[11]:

survivalstan.utils.plot_stan_summary([testfit], pars='log_baseline_raw')

# In[12]:

survivalstan.utils.plot_coefs([testfit], element='baseline')

# In[13]:

survivalstan.utils.plot_coefs([testfit])

# In[14]:

# Overlay observed survival (green) on the posterior-predictive curves.
survivalstan.utils.plot_pp_survival([testfit], fill=False)
survivalstan.utils.plot_observed_survival(
    df=d,
    event_col='event',
    time_col='t',
    color='green',
    label='observed',
)
plt.legend()

# In[15]:

survivalstan.utils.plot_pp_survival([testfit], by='sex')

# In[ ]:

# Outside a notebook there is no inline backend, so show figures explicitly.
# NOTE(review): as a plain script, plots from different cells may share the
# current figure unless the survivalstan helpers open new ones; confirm.
if _ipython is None:
    plt.show()