As it's popular counterparts for classification and regression, a Random Survival Forest is an ensemble of tree-based learners. A Random Survival Forest ensures that individual trees are de-correlated by 1) building each tree on a different bootstrap sample of the original training data, and 2) at each node, only evaluate the split criterion for a randomly selected subset of features and thresholds. Predictions are formed by aggregating predictions of individual trees in the ensemble.
To demonstrate Random Survival Forest, we are going to use data from the German Breast Cancer Study Group (GBSG-2) on the treatment of node-positive breast cancer patients. It contains data on 686 women and 8 prognostic factors:
The goal is to predict recurrence-free survival time.
import pandas as pd import matplotlib.pyplot as plt import numpy as np %matplotlib inline from sklearn.model_selection import train_test_split from sklearn.preprocessing import OrdinalEncoder from sksurv.datasets import load_gbsg2 from sksurv.preprocessing import OneHotEncoder from sksurv.ensemble import RandomSurvivalForest
First, we need to load the data and transform it into numeric values.
X, y = load_gbsg2() grade_str = X.loc[:, "tgrade"].astype(object).values[:, np.newaxis] grade_num = OrdinalEncoder(categories=[["I", "II", "III"]]).fit_transform(grade_str) X_no_grade = X.drop("tgrade", axis=1) Xt = OneHotEncoder().fit_transform(X_no_grade) Xt = np.column_stack((Xt.values, grade_num)) feature_names = X_no_grade.columns.tolist() + ["tgrade"]
Next, the data is split into 75% for training and 25% for testing so we can determine how well our model generalizes.
random_state = 20 X_train, X_test, y_train, y_test = train_test_split( Xt, y, test_size=0.25, random_state=random_state)
Several split criterion have been proposed in the past, but the most widespread one is based on the log-rank test, which you probably now from comparing survival curves among two or more groups. Using the training data, we fit a Random Survival Forest comprising 1000 trees.
rsf = RandomSurvivalForest(n_estimators=1000, min_samples_split=10, min_samples_leaf=15, max_features="sqrt", n_jobs=-1, random_state=random_state) rsf.fit(X_train, y_train)
RandomSurvivalForest(bootstrap=True, max_depth=None, max_features='sqrt', max_leaf_nodes=None, min_samples_leaf=15, min_samples_split=10, min_weight_fraction_leaf=0.0, n_estimators=1000, n_jobs=-1, oob_score=False, random_state=20, verbose=0, warm_start=False)
We can check how well the model performs by evaluating it on the test data.
This gives a concordance index of 0.68, which is a good a value and matches the results reported in the Random Survival Forests paper.
For prediction, a sample is dropped down each tree in the forest until it reaches a terminal node. Data in each terminal is used to non-parametrically estimate the survival and cumulative hazard function using the Kaplan-Meier and Nelson-Aalen estimator, respectively. In addition, a risk score can be computed that represents the expected number of events for one particular terminal node. The ensemble prediction is simply the average across all trees in the forest.
Let's first select a couple of patients from the test data according to the number of positive lymph nodes and age.
a = np.empty(X_test.shape, dtype=[("age", float), ("pnodes", float)]) a["age"] = X_test[:, 0] a["pnodes"] = X_test[:, 4] sort_idx = np.argsort(a, order=["pnodes", "age"]) X_test_sel = pd.DataFrame( X_test[np.concatenate((sort_idx[:3], sort_idx[-3:]))], columns=feature_names) X_test_sel
The predicted risk scores indicate that risk for the last three patients is quite a bit higher than that of the first three patients.
0 91.477609 1 102.897552 2 75.883786 3 170.502092 4 171.210066 5 148.691835 dtype: float64
We can have a more detailed insight by considering the predicted survival function. It shows that the biggest difference occurs roughly within the first 750 days.
surv = rsf.predict_survival_function(X_test_sel) for i, s in enumerate(surv): plt.step(rsf.event_times_, s, where="post", label=str(i)) plt.ylabel("Survival probability") plt.xlabel("Time in days") plt.grid(True) plt.legend()
<matplotlib.legend.Legend at 0x7f44b833bf28>
Alternatively, we can also plot the predicted cumulative hazard function.
surv = rsf.predict_cumulative_hazard_function(X_test_sel) for i, s in enumerate(surv): plt.step(rsf.event_times_, s, where="post", label=str(i)) plt.ylabel("Cumulative hazard") plt.xlabel("Time in days") plt.grid(True) plt.legend()
<matplotlib.legend.Legend at 0x7f44bab64ac8>
The implementation is based on scikit-learn's Random Forest implementation and inherits many
features, such as building trees in parallel. What's currently missing is feature importances
This is due to the way scikit-learn's implementation computes importances. It relies on
a measure of impurity for each child node, and defines importance as the amount of
decrease in impurity due to a split. For traditional regression, impurity would be measured by the variance, but for survival analysis
there is no per-node impurity measure due to censoring. Instead, one could use the
magnitude of the log-rank test statistic as an importance measure, but scikit-learn's
implementation doesn't seem to allow this.
Fortunately, this is not a big concern though, as scikit-learn's definition of feature importance is non-standard and differs from what Leo Breiman proposed in the original Random Forest paper. Instead, we can use permutation to estimate feature importance, which is preferred over scikit-learn's definition. This is implemented in the ELI5 library, which is fully compatible with scikit-survival.
import eli5 from eli5.sklearn import PermutationImportance perm = PermutationImportance(rsf, n_iter=15, random_state=random_state) perm.fit(X_test, y_test) eli5.show_weights(perm, feature_names=feature_names)
|0.0676 ± 0.0229||pnodes|
|0.0206 ± 0.0139||age|
|0.0177 ± 0.0468||progrec|
|0.0086 ± 0.0098||horTh|
|0.0032 ± 0.0198||tsize|
|0.0032 ± 0.0060||tgrade|
|-0.0007 ± 0.0018||menostat|
|-0.0063 ± 0.0207||estrec|
The result shows that the number of positive lymph nodes (
pnodes) is by far the most important
feature. If its relationship to survival time is removed (by random shuffling),
the concordance index on the test data drops on average by 0.0676 points.
Again, this agrees with the results from the original
Random Survival Forests paper.