The MPyC demo kmsurvival.py implements privacy-preserving Kaplan-Meier survival analysis, based on earlier work by Meilof Veeningen. The demo is built using the Python package lifelines, which provides extensive support for survival analysis and includes several datasets. We use lifelines for plotting Kaplan-Meier survival curves, performing logrank tests to compare survival curves, and printing survival tables.

In [1]:

```
# use pip (or conda) to make sure the lifelines package is installed:
!pip -q install lifelines
```

In [2]:

```
import os, functools
import pandas as pd
import matplotlib.pyplot as plt
import lifelines.statistics
from mpyc.runtime import mpc
mpc.logging(False)
from kmsurvival import fit_plot, events_to_table, events_from_table, logrank_test, aggregate, agg_logrank_test
```

Actual use of the lifelines package is mostly hidden inside the kmsurvival.py demo, except for the function `lifelines.statistics.logrank_test()` used below to validate the results of the secure computations.

We analyze the aml ("acute myelogenous leukemia") dataset, which is also used as a small example on Wikipedia. The file `aml.csv` (copied from the R package `KMsurv`) contains the raw data for 23 patients. Status 1 stands for the event "recurrence of aml cancer" and status 0 means no event ("censored").

In [3]:

```
df = pd.read_csv(os.path.join('data', 'surv', 'aml.csv')).rename(columns={'Unnamed: 0': 'observation', 'cens': 'status'})
df.sort_values(['time', 'observation']).style.hide_index()
```

Out[3]:

Time is in weeks. The study compares the time until recurrence among two groups of patients. Patients in group 1 received maintenance chemotherapy, while patients in group 2 did not get any maintenance treatment. To plot the Kaplan-Meier survival curves for groups 1 and 2, we use the function `fit_plot()` as follows.

In [4]:

```
T1, T2 = df['time'] [df['group'] == 1], df['time'] [df['group'] == 2]
E1, E2 = df['status'][df['group'] == 1], df['status'][df['group'] == 2]
kmf1, kmf2 = fit_plot(T1, T2, E1, E2, 'Kaplan-Meier survival curves', 'weeks', '1=Maintained', '2=Not maintained')
plt.show()
```

The vertical ticks on the graphs indicate censored events. At the bottom, the number of patients "at risk" is shown. The study concerned 11 patients in group 1 with one patient censored after 161 weeks and 12 patients in group 2. A Kaplan-Meier curve estimates the survival probability as a function of time (duration).
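For reference, the Kaplan-Meier estimator multiplies the per-time survival fractions $1 - d_t/n_t$, where $d_t$ is the number of events at time $t$ and $n_t$ the number of patients still at risk. A minimal plain-Python sketch (an illustration, not the lifelines implementation):

```python
def km_curve(d, n):
    """Kaplan-Meier estimate S(t) from event counts d[t] and at-risk counts n[t]."""
    s, curve = 1.0, []
    for dt, nt in zip(d, n):
        if nt > 0:
            s *= 1 - dt / nt  # survival fraction at this time moment
        curve.append(s)
    return curve

km_curve([1, 0, 1], [4, 3, 3])  # survival drops at the two event times
```

For the toy input above, one of four patients has an event at the first time, so the curve drops to 0.75; it drops again to 0.5 when one of the remaining three has an event.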

In this example, the two curves do not appear to differ very much. To analyze this more precisely one performs a logrank test:

In [5]:

```
lifelines.statistics.logrank_test(T1, T2, E1, E2).p_value
```

Out[5]:

The null hypothesis is that survival is the same for both groups. Since $p=0.065$ is not particularly small (e.g., not below $\alpha=0.05$), the null hypothesis is not strongly rejected. In other words, the logrank test also says that the curves do not differ significantly---with the usual caveat of small sample sizes.

In a multiparty setting, each party holds a private dataset and the goal is to perform a survival analysis on the *union* of all the datasets. To obtain the secure union of the datasets in an efficient way, each private dataset will be represented using two survival tables, one survival table per group of patients.

The survival tables for the aml dataset are available from the Kaplan-Meier fitters `kmf1` and `kmf2` output by the call to `fit_plot()` above:

In [6]:

```
display(kmf1.event_table, kmf2.event_table)
```

Using `mpc.input()`, each party will secret-share its pair of survival tables with all the other parties. To allow for a simple union of the survival tables, however, we modify the representation of the survival tables as follows.

First, we determine when the last event happens, which is at $t=161$ in the example. We compute $maxT$ securely as the maximum over all parties, using the following MPyC one-liner:

In [7]:

```
secfxp = mpc.SecFxp(64) # logrank tests will require fixed-point arithmetic
maxT = int(await mpc.output(mpc.max(mpc.input(secfxp(int(df['time'].max()))))))
timeline = range(1, maxT+1)
max(timeline)
```

Out[7]:

All parties use $1..maxT$ as the global timeline. Subsequently, each party pads its own survival tables to cover the entire timeline. This is accomplished by two calls to the function `events_to_table()`, which only keeps the essential information:

In [8]:

```
d1, n1 = events_to_table(maxT, T1, E1)
d2, n2 = events_to_table(maxT, T2, E2)
pd.DataFrame({'d1': d1, 'n1': n1, 'd2': d2, 'n2': n2}, index=timeline).head(10)
```

Out[8]:

Column `d1` records the number of observed events ("deaths") for group 1 on the entire timeline, and column `n1` records the number of patients "at risk". Similarly for group 2.
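The padding itself is straightforward. A simplified plain-Python stand-in for `events_to_table()` (hypothetical name and code, not the demo's implementation) could look like this:

```python
def table_from_events(maxT, T, E):
    """Pad durations T and event indicators E to tables over timeline 1..maxT:
    d[u-1] = number of observed events at time u, n[u-1] = number at risk."""
    d = [sum(1 for t, e in zip(T, E) if t == u and e == 1) for u in range(1, maxT + 1)]
    n = [sum(1 for t in T if t >= u) for u in range(1, maxT + 1)]
    return d, n

table_from_events(3, [1, 2, 2, 3], [1, 0, 1, 1])  # → ([1, 1, 1], [4, 3, 1])
```

Note that a censored observation (indicator 0) does not count as an event in `d`, but still keeps the patient in the at-risk counts `n` up to the censoring time.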

To obtain the secure union of the privately held datasets, we let all parties secret-share their survival tables and simply add all of these elementwise:

In [9]:

```
d1, n1, d2, n2 = (functools.reduce(mpc.vector_add, mpc.input(list(map(secfxp, _)))) for _ in (d1, n1, d2, n2))
```
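Without MPC, this union is nothing more than repeated elementwise addition of equally long lists; a plain-Python analogue of the `functools.reduce(mpc.vector_add, ...)` pattern (with made-up party inputs):

```python
from functools import reduce

def vector_add(x, y):
    """Elementwise sum of two equally long lists (plain analogue of mpc.vector_add)."""
    return [a + b for a, b in zip(x, y)]

# hypothetical d1 columns contributed by three parties
tables = [[1, 0, 2], [0, 1, 1], [1, 1, 0]]
reduce(vector_add, tables)  # → [2, 2, 3]
```

Because all parties padded their tables to the same global timeline $1..maxT$, the columns line up and addition suffices.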

The joint dataset (union) is now represented by `d1, n1, d2, n2`. Note that these values are all secret-shared, for example:

In [10]:

```
d1[0]
```

Out[10]:

We now proceed to analyze the joint dataset in a privacy-preserving way by plotting *aggregated* versions of the Kaplan-Meier curves. The exact Kaplan-Meier curves would reveal too much information about the patients in the study. By aggregating the events over longer time intervals, the amount of information revealed by the curves is reduced. At the same time, however, the aggregated curves may still be helpful to see the overall results for the study---and in any case to serve as a sanity check.

The function `aggregate()` securely adds all the observed ("death") events over intervals of a given length (stride). The aggregated values are all output publicly, and used to plot the curves via the function `events_from_table()`:

In [11]:

```
stride = 16
agg_d1, agg_n1 = aggregate(d1, n1, stride)
agg_d2, agg_n2 = aggregate(d2, n2, stride)
agg_d1, agg_n1, agg_d2, agg_n2 = [list(map(int, await mpc.output(_))) for _ in (agg_d1, agg_n1, agg_d2, agg_n2)]
T1_, E1_ = events_from_table(agg_d1, agg_n1)
T2_, E2_ = events_from_table(agg_d2, agg_n2)
T1_, T2_ = [t * stride for t in T1_], [t * stride for t in T2_]
fit_plot(T1_, T2_, E1_, E2_, 'Aggregated Kaplan-Meier survival curves', 'weeks', '1=Maintained', '2=Not maintained')
plt.show()
```
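The aggregation step can be sketched in plain Python. Assuming (this is an assumption about the demo's `aggregate()`, not taken from its code) that the aggregated at-risk count is the count at the start of each interval, the idea is:

```python
def aggregate_plain(d, n, stride):
    """Plain-Python sketch: sum events per stride of the timeline; take the
    at-risk count at each interval start (assumed behavior, for illustration)."""
    agg_d = [sum(d[k:k + stride]) for k in range(0, len(d), stride)]
    agg_n = [n[k] for k in range(0, len(n), stride)]
    return agg_d, agg_n

aggregate_plain([1, 0, 2, 0], [5, 4, 4, 2], 2)  # → ([1, 2], [5, 4])
```

In the demo, these sums are computed on secret-shared values and only the aggregated totals are opened.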

Picking `stride = 16` achieves a reasonable balance between privacy and utility. To enhance both privacy and utility at the same time, one may look for differentially private randomization techniques, adding a suitable type of noise to the Kaplan-Meier curves.

The function `logrank_test()` performs a secure logrank test on a secret-shared dataset, similar to the function `lifelines.statistics.logrank_test()` used above for a dataset in the clear. The input parameter `secfxp` specifies the secure type to be used for fixed-point arithmetic, and the output is an instance of `lifelines.statistics.StatisticalResult`:

In [12]:

```
print((await logrank_test(secfxp, d1, d2, n1, n2)).p_value)
```

Relying solely on p-values is in general not a good idea, and this is especially true when handling otherwise (mostly) hidden data. Together with the aggregated curves, however, the p-value may lead to a useful conclusion for a study.

The function `logrank_test()` uses one secure fixed-point division per time moment in $1..maxT$. Even though these divisions can all be done in parallel, the total effort is significant when $maxT$ is large. However, "most of the time" there is actually no event happening, and no divisions need to be performed for these time moments. E.g., in the survival tables for the aml dataset above, there are only 7 time moments with nonzero `d1` entries on the entire timeline $1..161$, and only 9 time moments with nonzero `d2` entries.

Therefore, it may be advantageous to first extract the nonzero rows of the survival tables, and then limit the computation of the logrank test to those rows. The extraction needs to be done obliviously, without leaking any information about (the location of) the nonzero entries of the survival tables. To prevent this oblivious extraction step from becoming a bottleneck, however, we will actually exploit the fact that the aggregated curves are revealed anyway. We may simply use `agg_d1` and `agg_d2` to bound the number of events per stride, and extract the nonzero rows obliviously and efficiently for each stride. This is basically what the function `agg_logrank_test()` does:

In [13]:

```
print((await agg_logrank_test(secfxp, d1, d2, n1, n2, agg_d1, agg_d2, stride)).p_value)
```
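The bounding idea can be illustrated in plain Python. In the demo the selection itself is performed obliviously on secret-shared values; this sketch (hypothetical helper, not the demo's code) only shows how the public per-stride bounds from `agg_d` cap the number of rows, with dummy rows padding the output so its shape reveals nothing beyond `agg_d`:

```python
def extract_bounded(d, n, stride, agg_d):
    """Keep, per stride, at most agg_d[k] event rows (a public bound), padded
    with dummy rows (d=0) that contribute nothing to the logrank statistic."""
    rows = []
    for k, bound in enumerate(agg_d):
        chunk = range(k * stride, min((k + 1) * stride, len(d)))
        nz = [(d[t], n[t]) for t in chunk if d[t] != 0]
        nz += [(0, 1)] * (bound - len(nz))  # pad up to the public bound
        rows.extend(nz[:bound])
    return rows

extract_bounded([1, 0, 2, 0], [5, 4, 4, 2], 2, [1, 2])
# → [(1, 5), (2, 4), (0, 1)]
```

The number of extracted rows is at most the sum of the `agg_d` entries, so the number of secure divisions drops from $maxT$ to (at most) the total number of observed events.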

Even for a small dataset like aml, the speedup is already noticeable. For larger datasets, the speedup gets really substantial, as can be noticed for some of the other datasets included with the kmsurvival.py demo.

We end with two complete runs of the demo on the aml dataset, showing the Chi2 test statistic and p-value for each logrank test.

The help message included with the demo shows the command line options:

In [14]:

```
!python kmsurvival.py -h
```

To show the plots of the survival curves, the `main()` function of the demo is called directly from a notebook cell:

In [15]:

```
import sys
from kmsurvival import main
sys.argv[1:] = ['-i2', '--plot-curves']
await main()
```

To run the demo with three parties on localhost, for instance, we add `-M3` as a command line option and run kmsurvival.py outside this notebook using a shell command. The plots are not shown this way, so instead we print the survival tables:

In [16]:

```
!python kmsurvival.py -M3 -i2 --print-tables --collapse
```

To try out other runs of the demo for yourself, remember to consult MPyC's help message, using the `-H` option:

In [17]:

```
!python kmsurvival.py -H
```