#!/usr/bin/env python
# coding: utf-8

# # Secure Kaplan-Meier Survival Analysis Explained
#
# The MPyC demo [kmsurvival.py](kmsurvival.py) implements privacy-preserving Kaplan-Meier [survival analysis](https://en.wikipedia.org/wiki/Survival_analysis), based on earlier work by Meilof Veeningen. The demo is built using the Python package [lifelines](https://pypi.org/project/lifelines/), which provides extensive support for survival analysis and includes several datasets. We use lifelines for plotting Kaplan-Meier survival curves, performing logrank tests to compare survival curves, and printing survival tables.

# In[1]:


# use pip (or conda) to make sure the lifelines package is installed:
get_ipython().system('pip -q install lifelines')


# In[2]:


import os, functools
import pandas as pd
import matplotlib.pyplot as plt
import lifelines.statistics
from mpyc.runtime import mpc
mpc.logging(False)
from lifelines import KaplanMeierFitter
from kmsurvival import plot_fits, events_to_table, events_from_table, logrank_test, aggregate, agg_logrank_test


# Actual use of the lifelines package is mostly hidden inside the [kmsurvival.py](kmsurvival.py) demo, except for the function `lifelines.statistics.logrank_test()`, which is used below to validate the results of the secure computations.
#
# ## Kaplan-Meier Analysis
#
# We analyze the aml ("acute myelogenous leukemia") dataset, which is also used as a [small example on Wikipedia](https://en.wikipedia.org/wiki/Survival_analysis#Example:_Acute_myelogenous_leukemia_survival_data). The file `aml.csv` (copied from the R package `KMsurv`) contains the raw data for 23 patients. Status 1 stands for the event "recurrence of aml cancer" and status 0 means no event ("censored").

# In[3]:


df = pd.read_csv(os.path.join('data', 'surv', 'aml.csv')).rename(columns={'Unnamed: 0': 'observation', 'cens': 'status'})
df.sort_values(['time', 'observation']).style.hide(axis='index')


# Time is in weeks.
# The study compares the time until recurrence between two groups of patients. Patients in group 1 received maintenance chemotherapy, while patients in group 2 did not get any maintenance treatment. To plot the [Kaplan–Meier survival curves](https://en.wikipedia.org/wiki/Kaplan–Meier_estimator) for groups 1 and 2, we use the function `plot_fits()` as follows.

# In[4]:


T1, T2 = df['time'][df['group'] == 1], df['time'][df['group'] == 2]
E1, E2 = df['status'][df['group'] == 1], df['status'][df['group'] == 2]
kmf1 = KaplanMeierFitter(label='1=Maintained').fit(T1, E1)
kmf2 = KaplanMeierFitter(label='2=Not maintained').fit(T2, E2)
plot_fits(kmf1, kmf2, 'Kaplan-Meier survival curves', 'weeks')
plt.show()


# The vertical ticks on the graphs indicate censored events. At the bottom, the number of patients "at risk" is shown. The study concerned 11 patients in group 1 with one patient censored after 161 weeks and 12 patients in group 2. A Kaplan-Meier curve estimates the survival probability as a function of time (duration).
#
# In this example, the two curves do not appear to differ very much. To analyze this more precisely, one performs a [logrank test](https://en.wikipedia.org/wiki/Logrank_test):

# In[5]:


lifelines.statistics.logrank_test(T1, T2, E1, E2).p_value


# The null hypothesis is that survival is the same for both groups. Since $p=0.065$ is not particularly small (e.g., not below $\alpha=0.05$), the null hypothesis is not rejected at that level. In other words, the logrank test also says that the curves do not differ significantly, with the usual caveat of small sample sizes.

# ## Privacy-Preserving Survival Analysis

# In a multiparty setting, each party holds a private dataset and the goal is to perform a survival analysis on the *union* of all the datasets. To obtain the secure union of the datasets in an efficient way, each private dataset will be represented using two survival tables, one survival table per group of patients.
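# As an illustration of what such a survival table contains, here is a minimal plain-Python sketch, working in the clear on a toy dataset. The helper names `survival_table` and `kaplan_meier` are hypothetical (they are not functions of lifelines or of the kmsurvival.py demo), and the integer timeline starting at 1 is an assumption made for the example.

```python
from collections import Counter


def survival_table(max_t, durations, events):
    """Hypothetical helper: per-time event counts d and at-risk counts n on timeline 1..max_t."""
    deaths = Counter(t for t, e in zip(durations, events) if e)
    leaving = Counter(durations)  # observed events and censorings both leave the risk set
    d, n = [], []
    at_risk = len(durations)
    for t in range(1, max_t + 1):
        d.append(deaths[t])   # number of observed events at time t
        n.append(at_risk)     # number of patients still at risk just before time t
        at_risk -= leaving[t]
    return d, n


def kaplan_meier(d, n):
    """Product-limit estimate S(t) = prod over i <= t of (1 - d_i / n_i)."""
    s, curve = 1.0, []
    for d_t, n_t in zip(d, n):
        if n_t > 0:
            s *= 1 - d_t / n_t
        curve.append(s)
    return curve


# Toy data (made up for illustration): four patients, one censored at t=3.
durations = [2, 3, 3, 5]
events = [1, 1, 0, 1]
d, n = survival_table(5, durations, events)
curve = kaplan_meier(d, n)
print(d, n)   # [0, 1, 1, 0, 1] [4, 4, 3, 1, 1]
print(curve)  # approximately [1.0, 0.75, 0.5, 0.5, 0.0]
```

# The two lists `d` and `n` carry everything needed for the Kaplan-Meier estimate, which is why a pair of such columns per group suffices as the representation of a dataset.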
#
# The survival tables for the aml dataset are available from the Kaplan-Meier fitters `kmf1` and `kmf2`:

# In[6]:


display(kmf1.event_table, kmf2.event_table)


# Using `mpc.input()`, each party will secret-share its pair of survival tables with all the other parties. To allow for a simple union of the survival tables, however, we modify the representation of the survival tables as follows.
#
# First, we determine when the last event happens, which is at $t=161$ in the example. We compute $maxT$ securely as the maximum over all parties, using the following MPyC one-liner:

# In[7]:


secfxp = mpc.SecFxp(64)  # logrank tests will require fixed-point arithmetic
maxT = int(await mpc.output(mpc.max(mpc.input(secfxp(int(df['time'].max()))))))
timeline = range(1, maxT+1)
max(timeline)


# All parties use $1..maxT$ as the global timeline. Subsequently, each party pads its own survival tables to cover the entire timeline. This is accomplished by two calls to the function `events_to_table()`, which only keeps the essential information:

# In[8]:


d1, n1 = events_to_table(maxT, T1, E1)
d2, n2 = events_to_table(maxT, T2, E2)
pd.DataFrame({'d1': d1, 'n1': n1, 'd2': d2, 'n2': n2}, index=timeline).head(10)


# Column `d1` records the number of observed events ("deaths") for group 1 on the entire timeline, and column `n1` records the number of patients "at risk". Columns `d2` and `n2` do the same for group 2.
#
# To obtain the secure union of the privately held datasets, we let all parties secret-share their survival tables and simply add all of these elementwise:

# In[9]:


d1, n1, d2, n2 = (functools.reduce(mpc.vector_add, mpc.input(list(map(secfxp, _)))) for _ in (d1, n1, d2, n2))


# The joint dataset (union) is now represented by `d1, n1, d2, n2`. Note that these values are all secret-shared, for example:

# In[10]:


d1[0]


# ### Aggregate Kaplan-Meier Curves
#
# We now proceed to analyze the joint dataset in a privacy-preserving way by plotting *aggregated* versions of the Kaplan-Meier curves.
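# To make the idea of aggregation concrete, the sketch below works in the clear on plain Python lists (no secret sharing). The helper name `aggregate_clear` is hypothetical and is not the demo's `aggregate()`; in particular, taking the at-risk count at the start of each interval is an assumption made here for illustration.

```python
def aggregate_clear(d, n, stride):
    """Hypothetical in-the-clear analogue of aggregating a padded survival table."""
    # Sum the event counts over consecutive intervals of length `stride`.
    agg_d = [sum(d[i:i + stride]) for i in range(0, len(d), stride)]
    # Assumption: keep the number at risk at the start of each interval.
    agg_n = n[::stride]
    return agg_d, agg_n


# Toy padded table on timeline 1..5 (made up for illustration).
d_toy = [0, 1, 1, 0, 1]
n_toy = [4, 4, 3, 1, 1]
agg_d_toy, agg_n_toy = aggregate_clear(d_toy, n_toy, stride=2)
print(agg_d_toy, agg_n_toy)  # [1, 1, 1] [4, 3, 1]
```

# A larger stride merges more events into each interval, so the published counts reveal less about the individual event times, while the total number of events stays the same.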
# The exact Kaplan-Meier curves would reveal too much information about the patients in the study. By aggregating the events over longer time intervals, the amount of information revealed by the curves is reduced. At the same time, however, the aggregated curves may still be helpful to see the overall results for the study, and in any case they serve as a sanity check.
#
# The function `aggregate()` securely adds all the observed ("death") events over intervals of a given length (stride). The aggregated values are all output publicly and are used to plot the curves via the function `events_from_table()`:

# In[11]:


stride = 16
agg_d1, agg_n1 = aggregate(d1, n1, stride)
agg_d2, agg_n2 = aggregate(d2, n2, stride)
agg_d1, agg_n1, agg_d2, agg_n2 = [list(map(int, await mpc.output(_))) for _ in (agg_d1, agg_n1, agg_d2, agg_n2)]
T1_, E1_ = events_from_table(agg_d1, agg_n1)
T2_, E2_ = events_from_table(agg_d2, agg_n2)
T1_, T2_ = [t * stride for t in T1_], [t * stride for t in T2_]
kmf1_ = KaplanMeierFitter(label='1=Maintained').fit(T1_, E1_)
kmf2_ = KaplanMeierFitter(label='2=Not maintained').fit(T2_, E2_)
plot_fits(kmf1_, kmf2_, 'Aggregated Kaplan-Meier survival curves', 'weeks')
plt.show()


# Picking `stride = 16` achieves a reasonable balance between privacy and utility. To enhance both privacy and utility at the same time, one may look for differentially private randomization techniques, adding a suitable type of noise to the Kaplan-Meier curves.

# ### Secure Logrank Tests
#
# The function `logrank_test()` performs a secure logrank test on a secret-shared dataset, similar to the function `lifelines.statistics.logrank_test()` used above for a dataset in the clear.
# The input parameter `secfxp` specifies the secure type to be used for fixed-point arithmetic, and the output is an instance of `lifelines.statistics.StatisticalResult`:

# In[12]:


print((await logrank_test(secfxp, d1, d2, n1, n2)).p_value)


# Relying solely on p-values is in general not a good idea, and this is especially true when handling otherwise (mostly) hidden data. Together with the aggregated curves, however, the p-value may lead to a useful conclusion for a study.
#
# The function `logrank_test()` uses one secure fixed-point division per time moment in $1..maxT$. Even though these divisions can all be done in parallel, the total effort is significant when $maxT$ is large. However, "most of the time" there is actually no event happening, and no divisions need to be performed for these time moments. For example, in the survival tables for the aml dataset above, there are only 7 time moments with nonzero `d1` entries on the entire timeline $1..161$, and only 9 time moments with nonzero `d2` entries.
#
# Therefore, it may be advantageous to first extract the nonzero rows of the survival tables, and then limit the computation of the logrank test to those rows. The extraction needs to be done obliviously, not leaking any information about (the location of) the nonzero entries of the survival tables. To prevent this oblivious extraction step from becoming a bottleneck, however, we will actually exploit the fact that the aggregate curves are revealed anyway. We may simply use `agg_d1` and `agg_d2` to bound the number of events per stride, and extract the nonzero rows obliviously and efficiently for each stride.

# This is basically what the function `agg_logrank_test()` does:

# In[13]:


print((await agg_logrank_test(secfxp, d1, d2, n1, n2, agg_d1, agg_d2, stride)).p_value)


# Even for a small dataset like aml, the speedup is already noticeable.
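# To see why restricting the computation to rows with events does not change the result, here is a sketch in the clear of the textbook two-group logrank statistic computed from padded survival tables. It is not the demo's secure implementation; the function name `logrank_from_tables` and the toy tables are hypothetical.

```python
import math


def logrank_from_tables(d1, d2, n1, n2):
    """Textbook two-group logrank test: returns (chi-squared statistic, p-value)."""
    observed_minus_expected = 0.0
    variance = 0.0
    for d1_t, d2_t, n1_t, n2_t in zip(d1, d2, n1, n2):
        d_t, n_t = d1_t + d2_t, n1_t + n2_t
        if d_t == 0 or n_t < 2:
            continue  # such rows add zero to both sums, so they can be skipped
        # Expected number of group-1 events under the null hypothesis.
        observed_minus_expected += d1_t - d_t * n1_t / n_t
        # Hypergeometric variance of the number of group-1 events.
        variance += d_t * (n1_t / n_t) * (n2_t / n_t) * (n_t - d_t) / (n_t - 1)
    if variance == 0:
        return 0.0, 1.0
    chi2 = observed_minus_expected ** 2 / variance
    # Survival function of the chi-squared distribution with 1 degree of freedom.
    p_value = math.erfc(math.sqrt(chi2 / 2))
    return chi2, p_value


# Toy padded tables on timeline 1..5 (made up for illustration).
d1_toy, n1_toy = [0, 1, 0, 1, 0], [3, 3, 2, 2, 1]
d2_toy, n2_toy = [1, 0, 0, 1, 0], [3, 2, 2, 2, 1]
full = logrank_from_tables(d1_toy, d2_toy, n1_toy, n2_toy)

# Keep only the rows in which at least one event occurs.
rows = [r for r in zip(d1_toy, d2_toy, n1_toy, n2_toy) if r[0] + r[1] > 0]
reduced = logrank_from_tables(*zip(*rows))
print(full == reduced)  # True
```

# In the clear this filtering is trivial; in the secure setting the positions of the nonzero rows are themselves private, which is why the extraction has to be oblivious.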
# For larger datasets, the speedup becomes substantial, as can be seen with some of the other datasets included with the [kmsurvival.py](kmsurvival.py) demo.

# ## Summary

# We end with two complete runs of the demo on the aml dataset, showing the Chi2 test statistic and p-value for each logrank test.
#
# The help message included with the demo shows the command line options:

# In[14]:


get_ipython().system('python kmsurvival.py -h')


# ### Complete Run: 5 logrank tests + survival curves

# To show the plots of the survival curves, the `main()` function of the demo is called directly from a notebook cell:

# In[15]:


import sys
from kmsurvival import main
sys.argv[1:] = ['-i2', '--plot-curves']
await main()


# ### Complete Run: 5 logrank tests + survival tables

# To run the demo with three parties on localhost, for instance, we add `-M3` as a command line option and run [kmsurvival.py](kmsurvival.py) outside this notebook using a shell command. The plots are not shown this way, so instead we print the survival tables:

# In[17]:


get_ipython().system('python kmsurvival.py -M3 -i2 --print-tables --collapse')


# To try out other runs of the demo for yourself, remember to consult MPyC's help message, using the `-H` option:

# In[18]:


get_ipython().system('python kmsurvival.py -H')