Most machine learning algorithms have been developed to perform classification or regression. However, in clinical research we often want to estimate the time to and event, such as death or recurrence of cancer, which leads to a special type of learning task that is distinct from classification and regression. This task is termed survival analysis, but is also referred to as time-to-event analysis or reliability analysis. Many machine learning algorithms have been adopted to perform survival analysis: Support Vector Machines, Random Forest, or Boosting. It has only been recently that survival analysis entered the era of deep learning, which is the focus of this post.
You will learn how to train a convolutional neural network to predict time to a (generated) event from MNIST images, using a loss function specific to survival analysis. The first part, will cover some basic terms and quantities used in survival analysis (feel free to skip this part if you are already familiar). In the second part, we will generate synthetic survival data from MNIST images and visualize it. In the third part, we will briefly revisit the most popular survival model of them all and learn how it can be used as a loss function for training a neural network. Finally, we put all the pieces together and train a convolutional neural network on MNIST and predict survival functions on the test data.
Please make sure you have the following packages installed. All are available via PyPI or Anaconda.
You can also run this notebook in Google Colaboratory and install scikit-survival using the command below.
!pip uninstall --yes --quiet osqp
!pip install scikit-survival
from typing import Any, Dict, Iterable, Sequence, Tuple, Optional, Union
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
from sksurv.nonparametric import kaplan_meier_estimator
from sksurv.metrics import concordance_index_censored
import tensorflow as tf
from tensorflow.keras.datasets import mnist
print("Using Tensorflow:", tf.__version__)
Using Tensorflow: 2.2.0
from distutils.version import LooseVersion
assert LooseVersion(tf.__version__) >= LooseVersion("2.0.0"), \
"This notebook requires TensorFlow 2.0 or above."
The objective in survival analysis is to establish a connection between covariates and the time of an event. The name survival analysis originates from clinical research, where predicting the time to death, i.e., survival, is often the main objective. Survival analysis is a type of regression problem (one wants to predict a continuous value), but with a twist. It differs from traditional regression by the fact that parts of the training data can only be partially observed – they are censored.
As an example, consider a clinical study that has been carried out over a 1 year period as in the figure below.
Patient A was lost to follow-up after three months with no recorded event, patient B experienced an event four and a half months after enrollment, patient C withdrew from the study two months after enrollment, and patient E did not experience any event before the study ended. Consequently, the exact time of an event could only be recorded for patients B and D; their records are uncensored. For the remaining patients it is unknown whether they did or did not experience an event after termination of the study. The only valid information that is available for patients A, C, and E is that they were event-free up to their last follow-up. Therefore, their records are censored.
Formally, each patient record consists of the time $t>0$ when an event occurred or the time $c>0$ of censoring. Since censoring and experiencing and event are mutually exclusive, it is common to define an event indicator $\delta \in \{0;1\}$ and the observable survival time $y>0$. The observable time $y$ of a right censored time of event is defined as
$$ y = \min(t, c) = \begin{cases} t & \text{if } \delta = 1 , \\ c & \text{if } \delta = 0 . \end{cases} $$Consequently, survival analysis demands for models that take partially observed, i.e., censored, event times into account.
Typically, the survival time is modelled as a continuous non-negative random variable $T$, from which basic quantities for time-to-event analysis can be derived, most importantly, the survival function and the hazard function.
Alternative names for the hazard function are conditional failure rate, conditional mortality rate, or instantaneous failure rate. In contrast to the survival function, which describes the absence of an event, the hazard function provides information about the occurrence of an event.
To start off, we are using images from the MNIST dataset and will synthetically generate survival times based on the digit each image represents. We associate a survival time (or risk score) with each class of the ten digits in MNIST. First, we randomly assign each class label to one of four overall risk groups, such that some digits will correspond to better and others to worse survival. Next, we generate risk scores that indicate how big the risk of experiencing an event is, relative to each other.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype(np.float32) / 255.
x_test = x_test.astype(np.float32) / 255.
y_train = y_train.astype(np.int32)
y_test = y_test.astype(np.int32)
y = np.concatenate((y_train, y_test))
def make_risk_score_for_groups(y: np.ndarray,
n_groups: int = 4,
seed: int = 89) -> Tuple[pd.DataFrame, np.ndarray]:
rnd = np.random.RandomState(seed)
# assign class labels `y` to one of `n_groups` risk groups
classes = np.unique(y)
group_assignment = {}
group_members = {}
groups = rnd.randint(n_groups, size=classes.shape)
for label, group in zip(classes, groups):
group_assignment[label] = group
group_members.setdefault(group, []).append(label)
# assign risk score to each class label in `y`
risk_per_class = {}
for label in classes:
group_idx = group_assignment[label]
group = group_members[group_idx]
label_idx = group.index(label)
group_size = len(group)
# allow risk scores in each group to vary slightly
risk_score = np.sqrt(group_idx + 1e-4) * 1.75
risk_score -= (label_idx - (group_size // 2)) / 25.
risk_per_class[label] = risk_score
assignment = pd.concat((
pd.Series(risk_per_class, name="risk_score"),
pd.Series(group_assignment, name="risk_group")
), axis=1).rename_axis("class_label")
risk_scores = np.array([risk_per_class[yy] for yy in y])
return assignment, risk_scores
risk_score_assignment, risk_scores = make_risk_score_for_groups(y)
risk_score_assignment.round(3)
risk_score | risk_group | |
---|---|---|
class_label | ||
0 | 3.071 | 3 |
1 | 2.555 | 2 |
2 | 0.058 | 0 |
3 | 1.790 | 1 |
4 | 2.515 | 2 |
5 | 3.031 | 3 |
6 | 1.750 | 1 |
7 | 2.475 | 2 |
8 | 0.018 | 0 |
9 | 2.435 | 2 |
We can see that class labels 2 and 8 belong to risk group 0, which has the lowest risk (close to zero). Risk group 1 corresponds to a risk score of about 1.7, risk group 2 of about 2.5, and risk group 3 is the group with the highest risk score of about 3.
To generate survival times from risk scores, we are going to follow the protocol of Bender et al. We choose the exponential distribution for the survival time. Its probability density function is $f(t\,|\,\lambda) = \lambda \exp(-\lambda t)$, where $\lambda > 0$ is a scale parameter that is the inverse of the expectation: $E(T) = \frac{1}{\lambda}$. The exponential distribution results in a relatively simple time-to-event model with no memory, because the hazard rate is constant: $h(t) = \lambda$. For more complex cases, refer to the paper by Bender et al.
Here, we choose $\lambda$ such that the mean survival time is 365 days. Finally, we randomly censor survival times drawing times of censoring from a uniform distribution such that we approximately obtain the desired amount of 45% censoring. The generated survival data comprises an observed time and a boolean event indicator for each MNIST image.
class SurvivalTimeGenerator:
def __init__(self,
num_samples: int,
mean_survival_time: float,
prob_censored: float) -> None:
self.num_samples = num_samples
self.mean_survival_time = mean_survival_time
self.prob_censored = prob_censored
def gen_censored_time(self,
risk_score: np.ndarray,
seed: int = 89) -> Tuple[np.ndarray,np.ndarray]:
rnd = np.random.RandomState(seed)
# generate survival time
baseline_hazard = 1. / self.mean_survival_time
scale = baseline_hazard * np.exp(risk_score)
u = rnd.uniform(low=0, high=1, size=risk_score.shape[0])
t = -np.log(u) / scale
# generate time of censoring
qt = np.quantile(t, 1.0 - self.prob_censored)
c = rnd.uniform(low=t.min(), high=qt)
# apply censoring
observed_event = t <= c
observed_time = np.where(observed_event, t, c)
return observed_time, observed_event
surv_gen = SurvivalTimeGenerator(
num_samples=y.shape[0],
mean_survival_time=365.,
prob_censored=.45
)
time, event = surv_gen.gen_censored_time(risk_scores)
time_train = time[:y_train.shape[0]]
event_train = event[:y_train.shape[0]]
time_test = time[y_train.shape[0]:]
event_test = event[y_train.shape[0]:]
print("%.2f%% samples are right censored in training data." % (np.sum(~event_train) * 100. / len(event_train)))
print("%.2f%% samples are right censored in test data." % (np.sum(~event_test) * 100. / len(event_test)))
46.19% samples are right censored in training data. 46.33% samples are right censored in test data.
We can use the generated censored data and estimate the survival function $S(t)$ to see what the risk scores actually mean in terms of survival. We stratify the training data by class label, and estimate the corresponding survival function using the non-parametric Kaplan-Meier estimator.
styles = ('-', '--', '-.', ':')
plt.figure(figsize=(6, 4.5))
for row in risk_score_assignment.itertuples():
mask = y_train == row.Index
coord_x, coord_y = kaplan_meier_estimator(event_train[mask], time_train[mask])
ls = styles[row.risk_group]
plt.step(coord_x, coord_y, where="post", label=f"Class {row.Index}", linestyle=ls)
plt.ylim(0, 1)
plt.ylabel("Probability of survival $P(T > t)$")
plt.xlabel("Time $t$")
plt.grid()
plt.legend()
<matplotlib.legend.Legend at 0x7f9e7c717c50>
Classes 0 and 5 (dotted lines) correspond to risk group 3, which has the highest risk score. The corresponding survival functions drop most quickly, which is exactly what we wanted. On the other end of the spectrum are classes 2 and 8 (solid lines) belonging to risk group 0 with the lowest risk.
One important aspect for survival analysis is that both the training data and the test data are subject to censoring, because we are unable to observe the exact time of an event no matter how the data was split. Therefore, performance measures need to account for censoring. The most widely used performance measure is Harrell's concordance index. Given a set of (predicted) risk scores and observed times, it checks whether the ordering by risk scores is concordant with the ordering by actual survival time. While Harrell's concordance index is widely used, it has its flaws, in particular when data is highly censored. Please refer to my previous post on evaluating survival models for more details.
We can take the risk score from which we generated survival times to check how good a model would perform if we knew the actual risk score.
cindex = concordance_index_censored(event_test, time_test, risk_scores[y_train.shape[0]:])
print(f"Concordance index on test data with actual risk scores: {cindex[0]:.3f}")
Concordance index on test data with actual risk scores: 0.705
Surprisingly, we do not obtain a perfect result of 1.0. The reason for this is that generated survival times are randomly distributed based on risk scores and not deterministic functions of the risk score. Therefore, any model we will train on this data should not be able to exceed this performance value.
By far the most widely used model to learn from censored survival data, is Cox's proportional hazards model model. It models the hazard function $h(t_i)$ of the $i$-th subject, conditional on the feature vector $\mathbf{x}_i \in \mathbb{R}^p$, as the product of an unspecified baseline hazard function $h_0$ (more on that later) and an exponential function of the linear model $\mathbf{x}_i^\top \mathbf{\beta}$: $$ h(t | x_{i1}, \ldots, x_{ip}) = h_0(t) \exp \left( \sum_{j=1}^p x_{ij} \beta_j \right) \Leftrightarrow \log \frac{h(t | \mathbf{x}_i)}{h_0 (t)} = \mathbf{x}_i^\top \mathbf{\beta} , $$ where $\mathbf{\beta} \in \mathbb{R}^p$ are the coefficients associated with each of the $p$ features, and no intercept term is included in the model. The key is that the hazard function is split into two parts: the baseline hazard function $h_0$ only depends on the time $t$, whereas the exponential is independent of time and only depends on the covariates $\mathbf{x}_i$.
Cox's proportional hazards model is fitted by maximizing the partial likelihood function, which is based on the probability that the $i$-th individual experiences an event at time $t_i$, given that there is one event at time point $t_i$. As we will see, by specifying the hazard function as above, the baseline hazard function $h_0$ can be eliminated and does not need be defined for finding the coefficients $\mathbf{\beta}$. Let $\mathcal{R}_i = \{ j\,|\,y_j \geq y_i \}$ be the risk set, i.e., the set of subjects who remained event-free shortly before time point $y_i$, and $I(\cdot)$ the indicator function, then we have
$$ \begin{split} &P(\text{subject experiences event at $y_i$} \mid \text{one event at $y_i$}) \\ =& \frac{P(\text{subject experiences event at $y_i$} \mid \text{event-free up to $y_i$})} {P (\text{one event at $y_i$} \mid \text{event-free up to $y_i$})} \\ =& \frac{h(y_i | \mathbf{x}_i)}{ \sum_{j=1}^n I(y_j \geq y_i) h(y_j | \mathbf{x}_j) } \\ =& \frac{h_0(y_i) \exp(\mathbf{x}_i^\top \mathbf{\beta})} { \sum_{j=1}^n I(y_j \geq y_i) h_0(y_j) \exp(\mathbf{x}_j^\top \mathbf{\beta}) } \\ =& \frac{\exp( \mathbf{x}_i^\top \beta)}{\sum_{j \in \mathcal{R}_i} \exp( \mathbf{x}_j^\top \beta)} . \end{split} $$By multiplying the conditional probability from above for all patients who experienced an event, and taking the logarithm, we obtain the partial likelihood function:
$$ \widehat{\mathbf{\beta}} = \arg\max_{\mathbf{\beta}}~ \log\,PL(\mathbf{\beta}) = \sum_{i=1}^n \delta_i \left[ \mathbf{x}_i^\top \mathbf{\beta} - \log \left( \sum_{j \in \mathcal{R}_i} \exp( \mathbf{x}_j^\top \mathbf{\beta}) \right) \right] . $$Cox's proportional hazards model as described above is a linear model, i.e., the predicted risk score is a linear combination of features. However, the model can easily be extended to the non-linear case by just replacing the linear predictor with the output of a neural network with parameters $\mathbf{\Theta}$.
This has been realized early on and was originally proposed in the work of Faraggi and Simon back in 1995. Farragi and Simon explore multilayer perceptrons, but the same loss can be used in combination with more advanced architectures such as convolutional neural networks or recurrent neural networks. Therefore, it is natural to also use the same loss function in the era of deep learning. However, this transition is not so easy as it may seem and comes with some caveats, both for training and for evaluation.
When implementing the Cox PH loss function, the problematic part is the inner sum over the risk set: $\sum_{j \in \mathcal{R}_i} \exp( \mathbf{x}_j^\top \mathbf{\beta})$. Note that the risk set is defined as $\mathcal{R}_i = \{ j\,|\,y_j \geq y_i \}$, which implies an ordering according to observed times $y_i$, which may lead to quadratic complexity if implemented naively. Ideally, we want to sort the data once in descending order by survival time and then incrementally update the inner sum, which leads to a linear complexity to compute the loss (ignoring the time for sorting).
Another problem is that the risk set for the subject with the smallest uncensored survival time is over the whole dataset. This is usually impractical, because we may not be able to keep the whole dataset in GPU memory. If we use mini-batches instead, as it's the norm, (i) we cannot compute the exact loss, because we may not have access to all samples in the risk set, and (ii) we need to sort each mini-batch by observed time, instead of sorting the whole data once.
For practical purposes, computing the Cox PH loss over a mini-batch is usually fine, as long as the batch contains several uncensored samples, because otherwise the outer sum in the partial likelihood function would be over an empty set. Here, we implement the sum over the risk set by multiplying the exponential of the predictions (as a row vector) by a squared boolean matrix that contains each sample's risk set as its rows. The sum over the risk set for each sample is then equivalent to a row-wise summation.
def _make_riskset(time: np.ndarray) -> np.ndarray:
"""Compute mask that represents each sample's risk set.
Parameters
----------
time : np.ndarray, shape=(n_samples,)
Observed event time sorted in descending order.
Returns
-------
risk_set : np.ndarray, shape=(n_samples, n_samples)
Boolean matrix where the `i`-th row denotes the
risk set of the `i`-th instance, i.e. the indices `j`
for which the observer time `y_j >= y_i`.
"""
assert time.ndim == 1, "expected 1D array"
# sort in descending order
o = np.argsort(-time, kind="mergesort")
n_samples = len(time)
risk_set = np.zeros((n_samples, n_samples), dtype=np.bool_)
for i_org, i_sort in enumerate(o):
ti = time[i_sort]
k = i_org
while k < n_samples and ti == time[o[k]]:
k += 1
risk_set[i_sort, o[:k]] = True
return risk_set
class InputFunction:
"""Callable input function that computes the risk set for each batch.
Parameters
----------
images : np.ndarray, shape=(n_samples, height, width)
Image data.
time : np.ndarray, shape=(n_samples,)
Observed time.
event : np.ndarray, shape=(n_samples,)
Event indicator.
batch_size : int, optional, default=64
Number of samples per batch.
drop_last : int, optional, default=False
Whether to drop the last incomplete batch.
shuffle : bool, optional, default=False
Whether to shuffle data.
seed : int, optional, default=89
Random number seed.
"""
def __init__(self,
images: np.ndarray,
time: np.ndarray,
event: np.ndarray,
batch_size: int = 64,
drop_last: bool = False,
shuffle: bool = False,
seed: int = 89) -> None:
if images.ndim == 3:
images = images[..., np.newaxis]
self.images = images
self.time = time
self.event = event
self.batch_size = batch_size
self.drop_last = drop_last
self.shuffle = shuffle
self.seed = seed
def size(self) -> int:
"""Total number of samples."""
return self.images.shape[0]
def steps_per_epoch(self) -> int:
"""Number of batches for one epoch."""
return int(np.floor(self.size() / self.batch_size))
def _get_data_batch(self, index: np.ndarray) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
"""Compute risk set for samples in batch."""
time = self.time[index]
event = self.event[index]
images = self.images[index]
labels = {
"label_event": event.astype(np.int32),
"label_time": time.astype(np.float32),
"label_riskset": _make_riskset(time)
}
return images, labels
def _iter_data(self) -> Iterable[Tuple[np.ndarray, Dict[str, np.ndarray]]]:
"""Generator that yields one batch at a time."""
index = np.arange(self.size())
rnd = np.random.RandomState(self.seed)
if self.shuffle:
rnd.shuffle(index)
for b in range(self.steps_per_epoch()):
start = b * self.batch_size
idx = index[start:(start + self.batch_size)]
yield self._get_data_batch(idx)
if not self.drop_last:
start = self.steps_per_epoch() * self.batch_size
idx = index[start:]
yield self._get_data_batch(idx)
def _get_shapes(self) -> Tuple[tf.TensorShape, Dict[str, tf.TensorShape]]:
"""Return shapes of data returned by `self._iter_data`."""
batch_size = self.batch_size if self.drop_last else None
h, w, c = self.images.shape[1:]
images = tf.TensorShape([batch_size, h, w, c])
labels = {k: tf.TensorShape((batch_size,))
for k in ("label_event", "label_time")}
labels["label_riskset"] = tf.TensorShape((batch_size, batch_size))
return images, labels
def _get_dtypes(self) -> Tuple[tf.DType, Dict[str, tf.DType]]:
"""Return dtypes of data returned by `self._iter_data`."""
labels = {"label_event": tf.int32,
"label_time": tf.float32,
"label_riskset": tf.bool}
return tf.float32, labels
def _make_dataset(self) -> tf.data.Dataset:
"""Create dataset from generator."""
ds = tf.data.Dataset.from_generator(
self._iter_data,
self._get_dtypes(),
self._get_shapes()
)
return ds
def __call__(self) -> tf.data.Dataset:
return self._make_dataset()
def safe_normalize(x: tf.Tensor) -> tf.Tensor:
"""Normalize risk scores to avoid exp underflowing.
Note that only risk scores relative to each other matter.
If minimum risk score is negative, we shift scores so minimum
is at zero.
"""
x_min = tf.reduce_min(x, axis=0)
c = tf.zeros_like(x_min)
norm = tf.where(x_min < 0, -x_min, c)
return x + norm
def logsumexp_masked(risk_scores: tf.Tensor,
mask: tf.Tensor,
axis: int = 0,
keepdims: Optional[bool] = None) -> tf.Tensor:
"""Compute logsumexp across `axis` for entries where `mask` is true."""
risk_scores.shape.assert_same_rank(mask.shape)
with tf.name_scope("logsumexp_masked"):
mask_f = tf.cast(mask, risk_scores.dtype)
risk_scores_masked = tf.math.multiply(risk_scores, mask_f)
# for numerical stability, substract the maximum value
# before taking the exponential
amax = tf.reduce_max(risk_scores_masked, axis=axis, keepdims=True)
risk_scores_shift = risk_scores_masked - amax
exp_masked = tf.math.multiply(tf.exp(risk_scores_shift), mask_f)
exp_sum = tf.reduce_sum(exp_masked, axis=axis, keepdims=True)
output = amax + tf.math.log(exp_sum)
if not keepdims:
output = tf.squeeze(output, axis=axis)
return output
class CoxPHLoss(tf.keras.losses.Loss):
"""Negative partial log-likelihood of Cox's proportional hazards model."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def call(self,
y_true: Sequence[tf.Tensor],
y_pred: tf.Tensor) -> tf.Tensor:
"""Compute loss.
Parameters
----------
y_true : list|tuple of tf.Tensor
The first element holds a binary vector where 1
indicates an event 0 censoring.
The second element holds the riskset, a
boolean matrix where the `i`-th row denotes the
risk set of the `i`-th instance, i.e. the indices `j`
for which the observer time `y_j >= y_i`.
Both must be rank 2 tensors.
y_pred : tf.Tensor
The predicted outputs. Must be a rank 2 tensor.
Returns
-------
loss : tf.Tensor
Loss for each instance in the batch.
"""
event, riskset = y_true
predictions = y_pred
pred_shape = predictions.shape
if pred_shape.ndims != 2:
raise ValueError("Rank mismatch: Rank of predictions (received %s) should "
"be 2." % pred_shape.ndims)
if pred_shape[1] is None:
raise ValueError("Last dimension of predictions must be known.")
if pred_shape[1] != 1:
raise ValueError("Dimension mismatch: Last dimension of predictions "
"(received %s) must be 1." % pred_shape[1])
if event.shape.ndims != pred_shape.ndims:
raise ValueError("Rank mismatch: Rank of predictions (received %s) should "
"equal rank of event (received %s)" % (
pred_shape.ndims, event.shape.ndims))
if riskset.shape.ndims != 2:
raise ValueError("Rank mismatch: Rank of riskset (received %s) should "
"be 2." % riskset.shape.ndims)
event = tf.cast(event, predictions.dtype)
predictions = safe_normalize(predictions)
with tf.name_scope("assertions"):
assertions = (
tf.debugging.assert_less_equal(event, 1.),
tf.debugging.assert_greater_equal(event, 0.),
tf.debugging.assert_type(riskset, tf.bool)
)
# move batch dimension to the end so predictions get broadcast
# row-wise when multiplying by riskset
pred_t = tf.transpose(predictions)
# compute log of sum over risk set for each row
rr = logsumexp_masked(pred_t, riskset, axis=1, keepdims=True)
assert rr.shape.as_list() == predictions.shape.as_list()
losses = tf.math.multiply(event, rr - predictions)
return losses
To monitor the training process, we would like to compute the concordance index with respect to a separate validation set. Similar to the Cox PH loss, the concordance index needs access to predicted risk scores and ground truth of all samples in the validation data. While we had to opt for computing the Cox PH loss over a mini-batch, I would not recommend this for the validation data. For small batch sizes and/or high amount of censoring, the estimated concordance index would be quite volatile, which makes it very hard to interpret. In addition, the validation data is usually considerably smaller than the training data, therefore we can collect predictions for the whole validation data and compute the concordance index accurately.
class CindexMetric:
"""Computes concordance index across one epoch."""
def reset_states(self) -> None:
"""Clear the buffer of collected values."""
self._data = {
"label_time": [],
"label_event": [],
"prediction": []
}
def update_state(self, y_true: Dict[str, tf.Tensor], y_pred: tf.Tensor) -> None:
"""Collect observed time, event indicator and predictions for a batch.
Parameters
----------
y_true : dict
Must have two items:
`label_time`, a tensor containing observed time for one batch,
and `label_event`, a tensor containing event indicator for one batch.
y_pred : tf.Tensor
Tensor containing predicted risk score for one batch.
"""
self._data["label_time"].append(y_true["label_time"].numpy())
self._data["label_event"].append(y_true["label_event"].numpy())
self._data["prediction"].append(tf.squeeze(y_pred).numpy())
def result(self) -> Dict[str, float]:
"""Computes the concordance index across collected values.
Returns
----------
metrics : dict
Computed metrics.
"""
data = {}
for k, v in self._data.items():
data[k] = np.concatenate(v)
results = concordance_index_censored(
data["label_event"] == 1,
data["label_time"],
data["prediction"])
result_data = {}
names = ("cindex", "concordant", "discordant", "tied_risk")
for k, v in zip(names, results):
result_data[k] = v
return result_data
Finally, after many considerations, we can create a convolutional neural network (CNN) to learn a high-level representation from MNIST digits such that we can estimate each image's survival function. The CNN follows the LeNet architecture where the last linear has one output unit that corresponds to the predicted risk score. The predicted risk score, together with the binary event indicator and risk set, are the input to the Cox PH loss.
import tensorflow.compat.v2.summary as summary
from tensorflow.python.ops import summary_ops_v2
class TrainAndEvaluateModel:
def __init__(self, model, model_dir, train_dataset, eval_dataset,
learning_rate, num_epochs):
self.num_epochs = num_epochs
self.model_dir = model_dir
self.model = model
self.train_ds = train_dataset
self.val_ds = eval_dataset
self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
self.loss_fn = CoxPHLoss()
self.train_loss_metric = tf.keras.metrics.Mean(name="train_loss")
self.val_loss_metric = tf.keras.metrics.Mean(name="val_loss")
self.val_cindex_metric = CindexMetric()
@tf.function
def train_one_step(self, x, y_event, y_riskset):
y_event = tf.expand_dims(y_event, axis=1)
with tf.GradientTape() as tape:
logits = self.model(x, training=True)
train_loss = self.loss_fn(y_true=[y_event, y_riskset], y_pred=logits)
with tf.name_scope("gradients"):
grads = tape.gradient(train_loss, self.model.trainable_weights)
self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
return train_loss, logits
def train_and_evaluate(self):
ckpt = tf.train.Checkpoint(
step=tf.Variable(0, dtype=tf.int64),
optimizer=self.optimizer,
model=self.model)
ckpt_manager = tf.train.CheckpointManager(
ckpt, str(self.model_dir), max_to_keep=2)
if ckpt_manager.latest_checkpoint:
ckpt.restore(ckpt_manager.latest_checkpoint)
print(f"Latest checkpoint restored from {ckpt_manager.latest_checkpoint}.")
train_summary_writer = summary.create_file_writer(
str(self.model_dir / "train"))
val_summary_writer = summary.create_file_writer(
str(self.model_dir / "valid"))
for epoch in range(self.num_epochs):
with train_summary_writer.as_default():
self.train_one_epoch(ckpt.step)
# Run a validation loop at the end of each epoch.
with val_summary_writer.as_default():
self.evaluate(ckpt.step)
save_path = ckpt_manager.save()
print(f"Saved checkpoint for step {ckpt.step.numpy()}: {save_path}")
def train_one_epoch(self, step_counter):
for x, y in self.train_ds:
train_loss, logits = self.train_one_step(
x, y["label_event"], y["label_riskset"])
step = int(step_counter)
if step == 0:
# see https://stackoverflow.com/questions/58843269/display-graph-using-tensorflow-v2-0-in-tensorboard
func = self.train_one_step.get_concrete_function(
x, y["label_event"], y["label_riskset"])
summary_ops_v2.graph(func.graph, step=0)
# Update training metric.
self.train_loss_metric.update_state(train_loss)
# Log every 200 batches.
if step % 200 == 0:
# Display metrics
mean_loss = self.train_loss_metric.result()
print(f"step {step}: mean loss = {mean_loss:.4f}")
# save summaries
summary.scalar("loss", mean_loss, step=step_counter)
# Reset training metrics
self.train_loss_metric.reset_states()
step_counter.assign_add(1)
@tf.function
def evaluate_one_step(self, x, y_event, y_riskset):
y_event = tf.expand_dims(y_event, axis=1)
val_logits = self.model(x, training=False)
val_loss = self.loss_fn(y_true=[y_event, y_riskset], y_pred=val_logits)
return val_loss, val_logits
def evaluate(self, step_counter):
self.val_cindex_metric.reset_states()
for x_val, y_val in self.val_ds:
val_loss, val_logits = self.evaluate_one_step(
x_val, y_val["label_event"], y_val["label_riskset"])
# Update val metrics
self.val_loss_metric.update_state(val_loss)
self.val_cindex_metric.update_state(y_val, val_logits)
val_loss = self.val_loss_metric.result()
summary.scalar("loss",
val_loss,
step=step_counter)
self.val_loss_metric.reset_states()
val_cindex = self.val_cindex_metric.result()
for key, value in val_cindex.items():
summary.scalar(key, value, step=step_counter)
print(f"Validation: loss = {val_loss:.4f}, cindex = {val_cindex['cindex']:.4f}")
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(6, kernel_size=(5, 5), activation='relu', name='conv_1'),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
tf.keras.layers.Conv2D(16, (5, 5), activation='relu', name='conv_2'),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(120, activation='relu', name='dense_1'),
tf.keras.layers.Dense(84, activation='relu', name='dense_2'),
tf.keras.layers.Dense(1, activation='linear', name='dense_3')
])
train_fn = InputFunction(x_train, time_train, event_train,
drop_last=True,
shuffle=True)
eval_fn = InputFunction(x_test, time_test, event_test)
trainer = TrainAndEvaluateModel(
model=model,
model_dir=Path("ckpts-mnist-cnn"),
train_dataset=train_fn(),
eval_dataset=eval_fn(),
learning_rate=0.0001,
num_epochs=15,
)
To obverse training, we can start TensorBoard.
# Load the TensorBoard notebook extension.
%load_ext tensorboard
%tensorboard --logdir ckpts-mnist-cnn
Let the training begin…
trainer.train_and_evaluate()
step 0: mean loss = 2.3711 step 200: mean loss = 2.0204 step 400: mean loss = 1.9728 step 600: mean loss = 1.9706 step 800: mean loss = 1.9514 Validation: loss = 1.9304, cindex = 0.6784 step 1000: mean loss = 1.9386 step 1200: mean loss = 1.9526 step 1400: mean loss = 1.9553 step 1600: mean loss = 1.9484 step 1800: mean loss = 1.9391 Validation: loss = 1.9233, cindex = 0.6831 step 2000: mean loss = 1.9381 step 2200: mean loss = 1.9446 step 2400: mean loss = 1.9407 step 2600: mean loss = 1.9428 step 2800: mean loss = 1.9339 Validation: loss = 1.9204, cindex = 0.6851 step 3000: mean loss = 1.9415 step 3200: mean loss = 1.9331 step 3400: mean loss = 1.9472 step 3600: mean loss = 1.9308 Validation: loss = 1.9186, cindex = 0.6862 step 3800: mean loss = 1.9243 step 4000: mean loss = 1.9305 step 4200: mean loss = 1.9552 step 4400: mean loss = 1.9306 step 4600: mean loss = 1.9286 Validation: loss = 1.9167, cindex = 0.6877 step 4800: mean loss = 1.9209 step 5000: mean loss = 1.9323 step 5200: mean loss = 1.9410 step 5400: mean loss = 1.9314 step 5600: mean loss = 1.9245 Validation: loss = 1.9152, cindex = 0.6886 step 5800: mean loss = 1.9296 step 6000: mean loss = 1.9273 step 6200: mean loss = 1.9388 step 6400: mean loss = 1.9274 Validation: loss = 1.9149, cindex = 0.6888 step 6600: mean loss = 1.9178 step 6800: mean loss = 1.9243 step 7000: mean loss = 1.9421 step 7200: mean loss = 1.9323 step 7400: mean loss = 1.9226 Validation: loss = 1.9137, cindex = 0.6898 step 7600: mean loss = 1.9103 step 7800: mean loss = 1.9368 step 8000: mean loss = 1.9336 step 8200: mean loss = 1.9334 step 8400: mean loss = 1.9064 Validation: loss = 1.9138, cindex = 0.6894 step 8600: mean loss = 1.9265 step 8800: mean loss = 1.9184 step 9000: mean loss = 1.9380 step 9200: mean loss = 1.9223 Validation: loss = 1.9133, cindex = 0.6900 step 9400: mean loss = 1.9153 step 9600: mean loss = 1.9147 step 9800: mean loss = 1.9377 step 10000: mean loss = 1.9291 step 10200: mean loss = 1.9240 Validation: loss = 1.9133, cindex = 0.6900 step 10400: mean loss = 1.9003 step 10600: mean loss = 1.9353 step 10800: mean loss = 1.9215 step 11000: mean loss = 1.9303 step 11200: mean loss = 1.9078 Validation: loss = 1.9129, cindex = 0.6904 step 11400: mean loss = 1.9216 step 11600: mean loss = 1.9189 step 11800: mean loss = 1.9335 step 12000: mean loss = 1.9158 Validation: loss = 1.9127, cindex = 0.6905 step 12200: mean loss = 1.9108 step 12400: mean loss = 1.9192 step 12600: mean loss = 1.9245 step 12800: mean loss = 1.9312 step 13000: mean loss = 1.9165 Validation: loss = 1.9127, cindex = 0.6906 step 13200: mean loss = 1.9004 step 13400: mean loss = 1.9269 step 13600: mean loss = 1.9200 step 13800: mean loss = 1.9296 step 14000: mean loss = 1.9080 Validation: loss = 1.9125, cindex = 0.6908 Saved checkpoint for step 14055: ckpts-mnist-cnn/ckpt-1
We can make a couple of observations:
For inference, things are much easier, we just pass a batch of images and record the predicted risk score. To estimate individual survival functions, we need to estimate the baseline hazard function $h_0$, which can be done analogous to the linear Cox PH model by using Breslow's estimator.
from sklearn.model_selection import train_test_split
from sksurv.linear_model.coxph import BreslowEstimator
class Predictor:
def __init__(self, model, model_dir):
self.model = model
self.model_dir = model_dir
def predict(self, dataset):
ckpt = tf.train.Checkpoint(
step=tf.Variable(0, dtype=tf.int64),
optimizer=tf.keras.optimizers.Adam(),
model=self.model)
ckpt_manager = tf.train.CheckpointManager(
ckpt, str(self.model_dir), max_to_keep=2)
if ckpt_manager.latest_checkpoint:
ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
print(f"Latest checkpoint restored from {ckpt_manager.latest_checkpoint}.")
risk_scores = []
for batch in dataset:
pred = self.model(batch, training=False)
risk_scores.append(pred.numpy())
return np.row_stack(risk_scores)
train_pred_fn = tf.data.Dataset.from_tensor_slices(x_train[..., np.newaxis]).batch(64)
predictor = Predictor(model, trainer.model_dir)
train_predictions = predictor.predict(train_pred_fn)
breslow = BreslowEstimator().fit(train_predictions, event_train, time_train)
Latest checkpoint restored from ckpts-mnist-cnn/ckpt-1.
Once fitted, we can use Breslow's estimator to obtain estimated survival functions for images in the test data. We randomly draw three sample images for each digit and plot their predicted survival function.
sample = train_test_split(x_test, y_test, event_test, time_test,
test_size=30, stratify=y_test, random_state=89)
x_sample, y_sample, event_sample, time_sample = sample[1::2]
sample_pred_ds = tf.data.Dataset.from_tensor_slices(
x_sample[..., np.newaxis]).batch(64)
sample_predictions = predictor.predict(sample_pred_ds)
sample_surv_fn = breslow.get_survival_function(sample_predictions)
plt.figure(figsize=(6, 4.5))
for surv_fn, class_label in zip(sample_surv_fn, y_sample):
risk_group = risk_score_assignment.loc[class_label, "risk_group"]
plt.step(surv_fn.x, surv_fn.y, where="post",
color=f"C{class_label}", linestyle=styles[risk_group])
plt.ylim(0, 1)
plt.ylabel("Probability of survival $P(T > t)$")
plt.xlabel("Time $t$")
plt.grid()
Latest checkpoint restored from ckpts-mnist-cnn/ckpt-1.
Solid lines correspond to images that belong to risk group 0 (with lowest risk), which the model was able to learn. Samples from the group with the highest risk are shown as dotted lines. Their predicted survival functions have the steepest descent, confirming that the model correctly identified different risk groups from images.
We successfully built, trained, and evaluated a convolutional neural network for survival analysis on MNIST. While MNIST is obviously not a clinical dataset, the exact same approach can be used for clinical data. For instance, Mobadersany et al. used the same approach to predict overall survival of patients diagnosed with brain tumors from microscopic images, and Zhu et al. applied CNNs to predict survival of lung cancer patients from pathological images.