Most machine learning algorithms have been developed to perform classification or regression. However, in clinical research we often want to estimate the time to an event, such as death or recurrence of cancer, which leads to a special type of learning task that is distinct from classification and regression. This task is termed survival analysis, but is also referred to as time-to-event analysis or reliability analysis. Many machine learning algorithms have been adapted to perform survival analysis: Support Vector Machines, Random Forest, or Boosting. Only recently has survival analysis entered the era of deep learning, which is the focus of this post.
You will learn how to train a convolutional neural network to predict time to a (generated) event from MNIST images, using a loss function specific to survival analysis. The first part will cover some basic terms and quantities used in survival analysis (feel free to skip this part if you are already familiar with them). In the second part, we will generate synthetic survival data from MNIST images and visualize it. In the third part, we will briefly revisit the most popular survival model of them all and learn how it can be used as a loss function for training a neural network. Finally, we put all the pieces together, train a convolutional neural network on MNIST, and predict survival functions on the test data.
Please make sure you have the following packages installed. All are available via PyPI or Anaconda.
You can also run this notebook in Google Colaboratory and install scikit-survival using the command below.
!pip install scikit-survival
from typing import Any, Dict, Iterable, Tuple, Optional, Union
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
from sksurv.nonparametric import kaplan_meier_estimator
from sksurv.metrics import concordance_index_censored
import tensorflow as tf
from tensorflow.keras.datasets import mnist
print("Using Tensorflow:", tf.__version__)
tf.logging.set_verbosity(tf.logging.WARN)
Using Tensorflow: 1.15.2
from distutils.version import LooseVersion
assert LooseVersion(tf.__version__) < LooseVersion("2.0.0"), \
"This notebook requires TensorFlow 1.X."
The objective in survival analysis is to establish a connection between covariates and the time of an event. The name survival analysis originates from clinical research, where predicting the time to death, i.e., survival, is often the main objective. Survival analysis is a type of regression problem (one wants to predict a continuous value), but with a twist. It differs from traditional regression by the fact that parts of the training data can only be partially observed – they are censored.
As an example, consider a clinical study that has been carried out over a one-year period, as in the figure below.
Patient A was lost to follow-up after three months with no recorded event, patient B experienced an event four and a half months after enrollment, patient C withdrew from the study two months after enrollment, and patient E did not experience any event before the study ended. Consequently, the exact time of an event could only be recorded for patients B and D; their records are uncensored. For the remaining patients it is unknown whether they did or did not experience an event after termination of the study. The only valid information that is available for patients A, C, and E is that they were event-free up to their last follow-up. Therefore, their records are censored.
Formally, each patient record consists of the time $t>0$ when an event occurred or the time $c>0$ of censoring. Since censoring and experiencing an event are mutually exclusive, it is common to define an event indicator $\delta = I(t \leq c) \in \{0, 1\}$ and the observable survival time $y>0$. The observable time $y$ of a right-censored time of event is defined as
$$ y = \min(t, c) = \begin{cases} t & \text{if } \delta = 1 , \\ c & \text{if } \delta = 0 . \end{cases} $$
Consequently, survival analysis demands models that take partially observed, i.e., censored, event times into account.
Typically, the survival time is modelled as a continuous non-negative random variable $T$, from which basic quantities for time-to-event analysis can be derived, most importantly, the survival function and the hazard function.
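For reference, these are the standard definitions for a continuous survival time $T$ with density $f$ and cumulative distribution function $F$ (the symbols $f$ and $F$ are introduced here). The survival function is the probability of remaining event-free beyond $t$. The hazard function is the instantaneous event rate at $t$ among subjects who are still event-free:

$$ S(t) = P(T > t) = 1 - F(t), \qquad h(t) = \lim_{\Delta t \to 0} \frac{P(t \leq T < t + \Delta t \mid T \geq t)}{\Delta t} = \frac{f(t)}{S(t)} = -\frac{d}{dt} \log S(t) . $$

For example, the exponential distribution used later has $S(t) = \exp(-\lambda t)$, and therefore a constant hazard $h(t) = \lambda$.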
Alternative names for the hazard function are conditional failure rate, conditional mortality rate, or instantaneous failure rate. In contrast to the survival function, which describes the absence of an event, the hazard function provides information about the occurrence of an event.
To start off, we are using images from the MNIST dataset and will synthetically generate survival times based on the digit each image represents. We associate a survival time (or risk score) with each class of the ten digits in MNIST. First, we randomly assign each class label to one of four overall risk groups, such that some digits will correspond to better and others to worse survival. Next, we generate risk scores that indicate how big the risk of experiencing an event is, relative to each other.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype(np.float32) / 255.
x_test = x_test.astype(np.float32) / 255.
y_train = y_train.astype(np.int32)
y_test = y_test.astype(np.int32)
y = np.concatenate((y_train, y_test))
def make_risk_score_for_groups(y: np.ndarray,
                               n_groups: int = 4,
                               seed: int = 89) -> Tuple[pd.DataFrame, np.ndarray]:
    rnd = np.random.RandomState(seed)
    # assign class labels `y` to one of `n_groups` risk groups
    classes = np.unique(y)
    group_assignment = {}
    group_members = {}
    groups = rnd.randint(n_groups, size=classes.shape)
    for label, group in zip(classes, groups):
        group_assignment[label] = group
        group_members.setdefault(group, []).append(label)
    # assign risk score to each class label in `y`
    risk_per_class = {}
    for label in classes:
        group_idx = group_assignment[label]
        group = group_members[group_idx]
        label_idx = group.index(label)
        group_size = len(group)
        # allow risk scores in each group to vary slightly
        risk_score = np.sqrt(group_idx + 1e-4) * 1.75
        risk_score -= (label_idx - (group_size // 2)) / 25.
        risk_per_class[label] = risk_score
    assignment = pd.concat((
        pd.Series(risk_per_class, name="risk_score"),
        pd.Series(group_assignment, name="risk_group")
    ), axis=1).rename_axis("class_label")
    risk_scores = np.array([risk_per_class[yy] for yy in y])
    return assignment, risk_scores
risk_score_assignment, risk_scores = make_risk_score_for_groups(y)
risk_score_assignment.round(3)
| class_label | risk_score | risk_group |
|---|---|---|
| 0 | 3.071 | 3 |
| 1 | 2.555 | 2 |
| 2 | 0.058 | 0 |
| 3 | 1.790 | 1 |
| 4 | 2.515 | 2 |
| 5 | 3.031 | 3 |
| 6 | 1.750 | 1 |
| 7 | 2.475 | 2 |
| 8 | 0.018 | 0 |
| 9 | 2.435 | 2 |
We can see that class labels 2 and 8 belong to risk group 0, which has the lowest risk (close to zero). Risk group 1 corresponds to a risk score of about 1.75, risk group 2 to about 2.5, and risk group 3 is the group with the highest risk score of about 3.
To generate survival times from risk scores, we are going to follow the protocol of Bender et al. We choose the exponential distribution for the survival time. Its probability density function is $f(t\,|\,\lambda) = \lambda \exp(-\lambda t)$, where $\lambda > 0$ is a rate parameter that is the inverse of the expectation: $E(T) = \frac{1}{\lambda}$. The exponential distribution results in a relatively simple time-to-event model without memory, because the hazard rate is constant: $h(t) = \lambda$. For more complex cases, refer to the paper by Bender et al.
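Here, we set the baseline hazard to $\lambda_0 = 1/365$ and each sample's rate to $\lambda = \lambda_0 \exp(\text{risk score})$. The mean survival time is therefore 365 days for a risk score of zero and shorter for higher risk scores. Finally, we censor survival times by drawing a single time of censoring from a uniform distribution between the smallest survival time and the 55% quantile of all survival times. Every survival time exceeding it is censored, which gives approximately the desired amount of 45% censoring. The generated survival data comprises an observed time and a boolean event indicator for each MNIST image.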
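As a quick sanity check of this sampling scheme, the sketch below uses inverse transform sampling, which is also what `gen_censored_time` does. If $U \sim \text{Uniform}(0, 1)$, then $-\log(U) / \lambda$ follows an exponential distribution with rate $\lambda$. The helper name `sample_exponential` and the chosen risk scores are illustrative assumptions. The sketch only checks that the empirical mean matches $1/\lambda = 365 / \exp(\text{risk score})$.

```python
import numpy as np


def sample_exponential(rate: np.ndarray, rng: np.random.RandomState) -> np.ndarray:
    """Inverse transform sampling: -log(U) / rate ~ Exponential(rate) for U in (0, 1]."""
    # 1 - U lies in (0, 1], which avoids log(0) for the rare draw U == 0
    u = 1.0 - rng.uniform(low=0.0, high=1.0, size=rate.shape)
    return -np.log(u) / rate


rng = np.random.RandomState(0)
baseline_hazard = 1.0 / 365.0  # lambda_0: mean survival of 365 days at risk score 0
n_draws = 200_000

empirical_mean = {}
expected_mean = {}
for risk_score in (0.0, 1.75, 3.0):
    rate = np.full(n_draws, baseline_hazard * np.exp(risk_score))
    t = sample_exponential(rate, rng)
    empirical_mean[risk_score] = t.mean()
    expected_mean[risk_score] = 1.0 / rate[0]
    print(f"risk score {risk_score:.2f}: empirical mean {t.mean():7.1f}, "
          f"expected {expected_mean[risk_score]:7.1f}")
```

Higher risk scores shrink the mean survival time by a factor of $\exp(\text{risk score})$. For example, a risk score of 3 gives a mean of roughly 365 / e³ ≈ 18 days.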
class SurvivalTimeGenerator:
    def __init__(self,
                 num_samples: int,
                 mean_survival_time: float,
                 prob_censored: float) -> None:
        self.num_samples = num_samples
        self.mean_survival_time = mean_survival_time
        self.prob_censored = prob_censored
    def gen_censored_time(self,
                          risk_score: np.ndarray,
                          seed: int = 89) -> Tuple[np.ndarray, np.ndarray]:
        rnd = np.random.RandomState(seed)
        # generate survival time by inverse transform sampling from an
        # exponential distribution with rate lambda_0 * exp(risk_score)
        baseline_hazard = 1. / self.mean_survival_time
        rate = baseline_hazard * np.exp(risk_score)
        u = rnd.uniform(low=0, high=1, size=risk_score.shape[0])
        t = -np.log(u) / rate
        # generate time of censoring: a single time drawn uniformly between the
        # smallest survival time and the (1 - prob_censored) quantile
        qt = np.quantile(t, 1.0 - self.prob_censored)
        c = rnd.uniform(low=t.min(), high=qt)
        # apply censoring: every survival time beyond c is censored at c
        observed_event = t <= c
        observed_time = np.where(observed_event, t, c)
        return observed_time, observed_event
surv_gen = SurvivalTimeGenerator(
num_samples=y.shape[0],
mean_survival_time=365.,
prob_censored=.45
)
time, event = surv_gen.gen_censored_time(risk_scores)
time_train = time[:y_train.shape[0]]
event_train = event[:y_train.shape[0]]
time_test = time[y_train.shape[0]:]
event_test = event[y_train.shape[0]:]
print("%.2f%% samples are right censored in training data." % (np.sum(~event_train) * 100. / len(event_train)))
print("%.2f%% samples are right censored in test data." % (np.sum(~event_test) * 100. / len(event_test)))
46.19% samples are right censored in training data.
46.33% samples are right censored in test data.
We can use the generated censored data and estimate the survival function $S(t)$ to see what the risk scores actually mean in terms of survival. We stratify the training data by class label, and estimate the corresponding survival function using the non-parametric Kaplan-Meier estimator.
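For intuition, the Kaplan-Meier (product-limit) estimate multiplies, at each distinct event time $t_k$, the fraction of subjects at risk who survive past $t_k$: $\hat S(t) = \prod_{t_k \leq t} (1 - d_k / n_k)$, with $d_k$ events among $n_k$ subjects at risk. The following minimal NumPy sketch illustrates the formula. The function name `kaplan_meier` is hypothetical and not the scikit-survival API. The sketch reports the estimate only at event times, so its output grid may differ from that of `kaplan_meier_estimator`.

```python
import numpy as np


def kaplan_meier(event, time):
    """Product-limit estimate of S(t) at each distinct time with at least one event."""
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    event_times = np.unique(time[event])
    surv = np.empty(len(event_times))
    s = 1.0
    for k, t in enumerate(event_times):
        n_at_risk = np.sum(time >= t)            # n_k: still event-free just before t
        n_events = np.sum((time == t) & event)   # d_k: events observed exactly at t
        s *= 1.0 - n_events / n_at_risk
        surv[k] = s
    return event_times, surv


# the subject censored at t = 2 leaves the risk set without lowering S(t)
times, surv = kaplan_meier([True, False, True, True], [1.0, 2.0, 3.0, 4.0])
print(times, surv)
```

In this toy example the estimate drops to 0.75 at t = 1, to 0.375 at t = 3 (two subjects at risk, one event), and to 0 at t = 4.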
styles = ('-', '--', '-.', ':')
plt.figure(figsize=(6, 4.5))
for row in risk_score_assignment.itertuples():
    mask = y_train == row.Index
    coord_x, coord_y = kaplan_meier_estimator(event_train[mask], time_train[mask])
    ls = styles[row.risk_group]
    plt.step(coord_x, coord_y, where="post", label=f"Class {row.Index}", linestyle=ls)
plt.ylim(0, 1)
plt.ylabel("Probability of survival $P(T > t)$")
plt.xlabel("Time $t$")
plt.grid()
plt.legend()
Classes 0 and 5 (dotted lines) correspond to risk group 3, which has the highest risk score. The corresponding survival functions drop most quickly, which is exactly what we wanted. On the other end of the spectrum are classes 2 and 8 (solid lines) belonging to risk group 0 with the lowest risk.
One important aspect of survival analysis is that both the training data and the test data are subject to censoring, because we are unable to observe the exact time of an event no matter how the data was split. Therefore, performance measures need to account for censoring. The most widely used performance measure is Harrell's concordance index. Given a set of (predicted) risk scores and observed times, it checks whether the ordering by risk scores is concordant with the ordering by actual survival time. While Harrell's concordance index is widely used, it has its flaws, in particular when data is highly censored. Please refer to my previous post on evaluating survival models for more details.
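To make the pairwise idea concrete, here is a naive $O(n^2)$ NumPy sketch of Harrell's concordance index. A pair $(i, j)$ is comparable if subject $i$ experienced an event and $y_i < y_j$. It is concordant if $i$ also has the higher risk score, and tied risk scores count as half. This is a simplified illustration with the hypothetical name `harrell_cindex`. It skips pairs with tied times, which scikit-survival's `concordance_index_censored` may handle differently.

```python
import numpy as np


def harrell_cindex(event, time, risk) -> float:
    """Naive O(n^2) sketch of Harrell's concordance index (tied times skipped)."""
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    risk = np.asarray(risk, dtype=float)
    concordant = 0.0
    comparable = 0
    for i in np.flatnonzero(event):          # only subjects with an event anchor a pair
        later = time > time[i]               # subjects known to be event-free longer than i
        comparable += int(np.sum(later))
        concordant += np.sum(risk[i] > risk[later])
        concordant += 0.5 * np.sum(risk[i] == risk[later])
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


event = np.array([True, True, False, True])
time = np.array([2.0, 5.0, 6.0, 9.0])
risk = np.array([0.9, 0.1, 0.3, 0.2])
print(harrell_cindex(event, time, risk))
```

In this example, subject 0 is ranked correctly against all three later subjects. Subject 1 has a lower risk than both subjects observed after it. This gives 3 concordant out of 5 comparable pairs, i.e., 0.6. The censored subject 2 only ever appears as the later member of a pair.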
We can take the risk scores from which we generated survival times to check how well a model would perform if it knew the actual risk scores.
cindex = concordance_index_censored(event_test, time_test, risk_scores[y_train.shape[0]:])
print(f"Concordance index on test data with actual risk scores: {cindex[0]:.3f}")
Concordance index on test data with actual risk scores: 0.705
Surprisingly, we do not obtain a perfect result of 1.0. The reason is that the generated survival times are random draws whose distribution depends on the risk scores, rather than deterministic functions of the risk scores. Therefore, we should not expect any model trained on this data to exceed this performance value by much.
By far the most widely used model to learn from censored survival data is Cox's proportional hazards model. It models the hazard function $h(t_i)$ of the $i$-th subject, conditional on the feature vector $\mathbf{x}_i \in \mathbb{R}^p$, as the product of an unspecified baseline hazard function $h_0$ (more on that later) and an exponential function of the linear model $\mathbf{x}_i^\top \mathbf{\beta}$: $$ h(t | x_{i1}, \ldots, x_{ip}) = h_0(t) \exp \left( \sum_{j=1}^p x_{ij} \beta_j \right) \Leftrightarrow \log \frac{h(t | \mathbf{x}_i)}{h_0 (t)} = \mathbf{x}_i^\top \mathbf{\beta} , $$ where $\mathbf{\beta} \in \mathbb{R}^p$ are the coefficients associated with each of the $p$ features, and no intercept term is included in the model. The key is that the hazard function is split into two parts: the baseline hazard function $h_0$ only depends on the time $t$, whereas the exponential is independent of time and only depends on the covariates $\mathbf{x}_i$.
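The name proportional hazards follows directly from this factorization. For two subjects $i$ and $k$, the baseline hazard cancels in the ratio of their hazards, so the ratio is constant over time:

$$ \frac{h(t \mid \mathbf{x}_i)}{h(t \mid \mathbf{x}_k)} = \frac{h_0(t) \exp(\mathbf{x}_i^\top \mathbf{\beta})}{h_0(t) \exp(\mathbf{x}_k^\top \mathbf{\beta})} = \exp\left( (\mathbf{x}_i - \mathbf{x}_k)^\top \mathbf{\beta} \right) . $$

For instance, suppose $\mathbf{x}_i$ and $\mathbf{x}_k$ differ only in feature $j$, by one unit. Then subject $i$'s hazard is $\exp(\beta_j)$ times that of subject $k$ at every time point $t$.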
Cox's proportional hazards model is fitted by maximizing the partial likelihood function, which is based on the probability that the $i$-th individual experiences an event at time $y_i$, given that there is one event at time point $y_i$. As we will see, by specifying the hazard function as above, the baseline hazard function $h_0$ can be eliminated and does not need to be defined for finding the coefficients $\mathbf{\beta}$. Let $\mathcal{R}_i = \{ j\,|\,y_j \geq y_i \}$ be the risk set, i.e., the set of subjects who remained event-free shortly before time point $y_i$, and let $I(\cdot)$ be the indicator function. Then we have
$$ \begin{split} &P(\text{subject experiences event at $y_i$} \mid \text{one event at $y_i$}) \\ =& \frac{P(\text{subject experiences event at $y_i$} \mid \text{event-free up to $y_i$})} {P (\text{one event at $y_i$} \mid \text{event-free up to $y_i$})} \\ =& \frac{h(y_i | \mathbf{x}_i)}{ \sum_{j=1}^n I(y_j \geq y_i) h(y_i | \mathbf{x}_j) } \\ =& \frac{h_0(y_i) \exp(\mathbf{x}_i^\top \mathbf{\beta})} { \sum_{j=1}^n I(y_j \geq y_i) h_0(y_i) \exp(\mathbf{x}_j^\top \mathbf{\beta}) } \\ =& \frac{\exp( \mathbf{x}_i^\top \mathbf{\beta})}{\sum_{j \in \mathcal{R}_i} \exp( \mathbf{x}_j^\top \mathbf{\beta})} . \end{split} $$
Because every hazard in the denominator is evaluated at the same time $y_i$, the baseline hazard $h_0(y_i)$ cancels. By multiplying this conditional probability over all patients who experienced an event and taking the logarithm, we obtain the log partial likelihood, which is maximized to estimate the coefficients:
$$ \log\,PL(\mathbf{\beta}) = \sum_{i=1}^n \delta_i \left[ \mathbf{x}_i^\top \mathbf{\beta} - \log \left( \sum_{j \in \mathcal{R}_i} \exp( \mathbf{x}_j^\top \mathbf{\beta}) \right) \right] , \qquad \widehat{\mathbf{\beta}} = \arg\max_{\mathbf{\beta}}~ \log\,PL(\mathbf{\beta}) . $$
Cox's proportional hazards model as described above is a linear model, i.e., the predicted risk score is a linear combination of features. However, the model can easily be extended to the non-linear case by replacing the linear predictor with the output of a neural network with parameters $\mathbf{\Theta}$.
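Written out, the resulting training loss is the negative log partial likelihood. Here $g_{\mathbf{\Theta}}(\mathbf{x})$ is a symbol introduced for the network's predicted risk score. The `coxph_loss` function implemented further down computes this loss over a mini-batch, up to the normalization applied by its `reduction` argument:

$$ \mathcal{L}(\mathbf{\Theta}) = -\sum_{i=1}^n \delta_i \left[ g_{\mathbf{\Theta}}(\mathbf{x}_i) - \log \left( \sum_{j \in \mathcal{R}_i} \exp\left( g_{\mathbf{\Theta}}(\mathbf{x}_j) \right) \right) \right] . $$

Setting $g_{\mathbf{\Theta}}(\mathbf{x}) = \mathbf{x}^\top \mathbf{\beta}$ recovers the linear model, and minimizing $\mathcal{L}$ is equivalent to maximizing $\log PL$.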
This was realized early on: the idea was originally proposed by Faraggi and Simon back in 1995. Faraggi and Simon explore multilayer perceptrons, but the same loss can be used in combination with more advanced architectures such as convolutional neural networks or recurrent neural networks. Therefore, it is natural to also use the same loss function in the era of deep learning. However, this transition is not as easy as it may seem and comes with some caveats, both for training and for evaluation.
When implementing the Cox PH loss function, the problematic part is the inner sum over the risk set: $\sum_{j \in \mathcal{R}_i} \exp( \mathbf{x}_j^\top \mathbf{\beta})$. Note that the risk set is defined as $\mathcal{R}_i = \{ j\,|\,y_j \geq y_i \}$, which implies an ordering according to observed times $y_i$ and leads to quadratic complexity if implemented naively. Ideally, we want to sort the data once in descending order by survival time and then incrementally update the inner sum, which computes the loss in linear time (ignoring the time for sorting).
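Here is a minimal NumPy sketch of this idea, assuming a single complete batch given as arrays `event`, `time`, and `pred` (hypothetical names). After sorting by time in descending order, `np.logaddexp.accumulate` yields the log of each risk-set sum in a single pass, without overflowing. Tied times need care, because all subjects tied at $y_i$ belong to $\mathcal{R}_i$. The sketch therefore reads the running sum at the last position of each tie group. The quadratic version is included only to check the result.

```python
import numpy as np


def cox_nll_naive(event: np.ndarray, time: np.ndarray, pred: np.ndarray) -> float:
    """Negative log partial likelihood in O(n^2): one explicit sum per risk set."""
    loss = 0.0
    for i in np.flatnonzero(event):
        in_risk_set = time >= time[i]
        loss -= pred[i] - np.log(np.sum(np.exp(pred[in_risk_set])))
    return float(loss)


def cox_nll_sorted(event: np.ndarray, time: np.ndarray, pred: np.ndarray) -> float:
    """Same loss after a single sort: risk-set sums become running sums."""
    order = np.argsort(-time, kind="mergesort")  # descending by observed time
    t_sorted, p_sorted, e_sorted = time[order], pred[order], event[order]
    # log_cum[k] = log(sum(exp(p_sorted[:k + 1]))), accumulated stably
    log_cum = np.logaddexp.accumulate(p_sorted)
    # index of the last sample tied with t_sorted[k] (-t_sorted is ascending)
    last_tied = np.searchsorted(-t_sorted, -t_sorted, side="right") - 1
    log_risk = log_cum[last_tied]
    return float(-np.sum(p_sorted[e_sorted] - log_risk[e_sorted]))


rng = np.random.RandomState(0)
n = 500
time = rng.randint(1, 50, size=n).astype(float)  # integer times produce many ties
event = rng.uniform(size=n) < 0.6
pred = rng.normal(size=n)

loss_naive = cox_nll_naive(event, time, pred)
loss_sorted = cox_nll_sorted(event, time, pred)
print(f"naive: {loss_naive:.6f}, sorted: {loss_sorted:.6f}")
```

The accumulation itself is linear. The tie lookup uses a binary search, so the overall cost stays at $O(n \log n)$, the same order as the sort.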
Another problem is that the risk set for the subject with the smallest uncensored survival time spans (nearly) the whole dataset. This is usually impractical, because we may not be able to keep the whole dataset in GPU memory. If we use mini-batches instead, as is the norm, (i) we cannot compute the exact loss, because we may not have access to all samples in the risk set, and (ii) we need to sort each mini-batch by observed time, instead of sorting the whole data once.
For practical purposes, computing the Cox PH loss over a mini-batch is usually fine, as long as the batch contains several uncensored samples, because otherwise the outer sum in the partial likelihood function would be over an empty set. Here, we implement the sum over the risk set by multiplying the exponential of the predictions (as a row vector) by a square boolean matrix that contains each sample's risk set as its rows. The sum over the risk set for each sample is then equivalent to a row-wise summation.
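As a small illustration, the boolean risk-set matrix can also be built with NumPy broadcasting instead of the explicit loop in `_make_riskset` below. Both should produce the same matrix; the broadcasting version is only a sketch, and the name `make_riskset_broadcast` is hypothetical. The row-wise sums over the risk sets then become a single matrix-vector product with the exponentiated predictions.

```python
import numpy as np


def make_riskset_broadcast(time: np.ndarray) -> np.ndarray:
    """riskset[i, j] is True if time[j] >= time[i], i.e., sample j is in the risk set of i."""
    return time[np.newaxis, :] >= time[:, np.newaxis]


time = np.array([5.0, 2.0, 7.0, 2.0])
pred = np.array([0.3, -1.2, 0.8, 0.1])

riskset = make_riskset_broadcast(time)
# row-wise sum of exp(pred), restricted to each sample's risk set
riskset_sums = riskset.astype(float) @ np.exp(pred)
print(riskset.astype(int))
```

Summing raw exponentials like this can overflow for large predictions. That is why `logsumexp_masked` below subtracts the maximum before exponentiating.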
def _make_riskset(time: np.ndarray) -> np.ndarray:
    """Compute mask that represents each sample's risk set.
    Parameters
    ----------
    time : np.ndarray, shape=(n_samples,)
        Observed event time.
    Returns
    -------
    risk_set : np.ndarray, shape=(n_samples, n_samples)
        Boolean matrix where the `i`-th row denotes the
        risk set of the `i`-th instance, i.e. the indices `j`
        for which the observed time `y_j >= y_i`.
    """
    assert time.ndim == 1, "expected 1D array"
    # sort in descending order
    o = np.argsort(-time, kind="mergesort")
    n_samples = len(time)
    risk_set = np.zeros((n_samples, n_samples), dtype=np.bool_)
    for i_org, i_sort in enumerate(o):
        ti = time[i_sort]
        k = i_org
        # extend the risk set to all samples tied with time ti
        while k < n_samples and ti == time[o[k]]:
            k += 1
        risk_set[i_sort, o[:k]] = True
    return risk_set
class InputFunction:
    """Callable input function that computes the risk set for each batch.
    Parameters
    ----------
    images : np.ndarray, shape=(n_samples, height, width)
        Image data.
    time : np.ndarray, shape=(n_samples,)
        Observed time.
    event : np.ndarray, shape=(n_samples,)
        Event indicator.
    num_epochs : int, optional, default=1
        Number of epochs.
    batch_size : int, optional, default=64
        Number of samples per batch.
    drop_last : bool, optional, default=False
        Whether to drop the last incomplete batch.
    shuffle : bool, optional, default=False
        Whether to shuffle data.
    seed : int, optional, default=89
        Random number seed.
    """
    def __init__(self,
                 images: np.ndarray,
                 time: np.ndarray,
                 event: np.ndarray,
                 num_epochs: int = 1,
                 batch_size: int = 64,
                 drop_last: bool = False,
                 shuffle: bool = False,
                 seed: int = 89) -> None:
        if images.ndim == 3:
            images = images[..., np.newaxis]
        self.images = images
        self.time = time
        self.event = event
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle = shuffle
        self.seed = seed
    def size(self) -> int:
        """Total number of samples."""
        return self.images.shape[0]
    def steps_per_epoch(self) -> int:
        """Number of complete batches for one epoch."""
        return int(np.floor(self.size() / self.batch_size))
    def _get_data_batch(self, index: np.ndarray) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
        """Compute risk set for samples in batch."""
        time = self.time[index]
        event = self.event[index]
        images = self.images[index]
        labels = {
            "label_event": event.astype(np.int32),
            "label_time": time.astype(np.float32),
            "label_riskset": _make_riskset(time)
        }
        return images, labels
    def _iter_data(self) -> Iterable[Tuple[np.ndarray, Dict[str, np.ndarray]]]:
        """Generator that yields one batch at a time."""
        index = np.arange(self.size())
        rnd = np.random.RandomState(self.seed)
        for _ in range(self.num_epochs):
            if self.shuffle:
                rnd.shuffle(index)
            for b in range(self.steps_per_epoch()):
                start = b * self.batch_size
                idx = index[start:(start + self.batch_size)]
                yield self._get_data_batch(idx)
            if not self.drop_last:
                start = self.steps_per_epoch() * self.batch_size
                idx = index[start:]
                # skip the remainder if the data splits evenly into batches
                if len(idx) > 0:
                    yield self._get_data_batch(idx)
    def _get_shapes(self) -> Tuple[tf.TensorShape, Dict[str, tf.TensorShape]]:
        """Return shapes of data returned by `self._iter_data`."""
        batch_size = self.batch_size if self.drop_last else None
        h, w, c = self.images.shape[1:]
        images = tf.TensorShape([batch_size, h, w, c])
        labels = {k: tf.TensorShape((batch_size,))
                  for k in ("label_event", "label_time")}
        labels["label_riskset"] = tf.TensorShape((batch_size, batch_size))
        return images, labels
    def _get_dtypes(self) -> Tuple[tf.DType, Dict[str, tf.DType]]:
        """Return dtypes of data returned by `self._iter_data`."""
        labels = {"label_event": tf.int32,
                  "label_time": tf.float32,
                  "label_riskset": tf.bool}
        return tf.float32, labels
    def _make_dataset(self) -> tf.data.Dataset:
        """Create dataset from generator."""
        ds = tf.data.Dataset.from_generator(
            self._iter_data,
            self._get_dtypes(),
            self._get_shapes()
        )
        return ds
    def __call__(self) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]:
        ds = self._make_dataset()
        next_x, next_y = ds.make_one_shot_iterator().get_next()
        return next_x, next_y
def safe_normalize(x: tf.Tensor) -> tf.Tensor:
    """Normalize risk scores to avoid exp underflowing.
    Note that only risk scores relative to each other matter.
    If minimum risk score is negative, we shift scores so minimum
    is at zero.
    """
    x_min = tf.reduce_min(x, axis=0)
    c = tf.zeros_like(x_min)
    norm = tf.where(x_min < 0, -x_min, c)
    return x + norm
def logsumexp_masked(risk_scores: tf.Tensor,
                     mask: tf.Tensor,
                     axis: int = 0,
                     keepdims: Optional[bool] = None) -> tf.Tensor:
    """Compute logsumexp across `axis` for entries where `mask` is true."""
    risk_scores.get_shape().assert_same_rank(mask.get_shape())
    with tf.name_scope("logsumexp_masked", values=[risk_scores, mask]):
        mask_f = tf.cast(mask, risk_scores.dtype)
        risk_scores_masked = tf.multiply(risk_scores, mask_f)
        # for numerical stability, subtract the maximum value
        # before taking the exponential
        amax = tf.reduce_max(risk_scores_masked, axis=axis, keepdims=True)
        risk_scores_shift = risk_scores_masked - amax
        exp_masked = tf.multiply(tf.exp(risk_scores_shift), mask_f)
        exp_sum = tf.reduce_sum(exp_masked, axis=axis, keepdims=True)
        output = amax + tf.log(exp_sum)
        if not keepdims:
            output = tf.squeeze(output, axis=axis)
    return output
def coxph_loss(event: tf.Tensor,
               riskset: tf.Tensor,
               predictions: tf.Tensor,
               weights: Union[tf.Tensor, float] = 1.0,
               scope: Optional[str] = None,
               loss_collection: str = tf.GraphKeys.LOSSES,
               reduction: str = tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS):
    """Negative partial log-likelihood of Cox's proportional
    hazards model.
    Parameters
    ----------
    event : tf.Tensor
        Binary vector where 1 indicates an event and 0 censoring.
    riskset : tf.Tensor
        Boolean matrix where the `i`-th row denotes the
        risk set of the `i`-th instance, i.e. the indices `j`
        for which the observed time `y_j >= y_i`.
    predictions : tf.Tensor
        The predicted outputs. Must be a rank 2 tensor.
    weights : tf.Tensor|float
        Weight of loss. Either a scalar or a Tensor containing
        weights of each instance.
    scope : str|None
        The scope for the operations performed in computing the loss.
    loss_collection : str
        Collection to which the loss will be added.
    reduction : str
        Type of reduction to apply to loss.
    Returns
    -------
    loss : tf.Tensor
        Scalar loss.
    """
    pred_shape = predictions.get_shape()
    if pred_shape.ndims != 2:
        raise ValueError("Rank mismatch: Rank of predictions (received %s) should "
                         "be 2." % pred_shape.ndims)
    if pred_shape[1].value is None:
        raise ValueError("Last dimension of predictions must be known.")
    if pred_shape[1].value != 1:
        raise ValueError("Dimension mismatch: Last dimension of predictions "
                         "(received %s) must be 1." % pred_shape[1].value)
    if event.get_shape().ndims != pred_shape.ndims:
        raise ValueError("Rank mismatch: Rank of predictions (received %s) should "
                         "equal rank of event (received %s)" % (
                             pred_shape.ndims, event.get_shape().ndims))
    if riskset.get_shape().ndims != 2:
        raise ValueError("Rank mismatch: Rank of riskset (received %s) should "
                         "be 2." % riskset.get_shape().ndims)
    with tf.name_scope(scope, 'coxph_loss', [predictions, event, riskset]) as scope:
        event = tf.cast(event, predictions.dtype)
        predictions = safe_normalize(predictions)
        with tf.name_scope('assertions', values=[event, riskset]):
            assertions = (
                tf.assert_less_equal(event, 1.),
                tf.assert_greater_equal(event, 0.),
                tf.assert_type(riskset, tf.bool)
            )
        with tf.control_dependencies(assertions):
            # move batch dimension to the end so predictions get broadcast
            # row-wise when multiplying by riskset
            pred_t = tf.transpose(predictions)
            # compute log of sum over risk set for each row
            rr = logsumexp_masked(pred_t, riskset, axis=1, keepdims=True)
            assert rr.get_shape().as_list() == predictions.get_shape().as_list()
            # only samples with an event contribute to the loss
            losses = tf.multiply(event, rr - predictions)
        loss = tf.losses.compute_weighted_loss(
            losses,
            weights=weights,
            scope=scope,
            loss_collection=loss_collection,
            reduction=reduction)
    return loss
To monitor the training process, we would like to compute the concordance index with respect to a separate validation set. Similar to the Cox PH loss, the concordance index needs access to predicted risk scores and ground truth of all samples in the validation data. While we had to opt for computing the Cox PH loss over a mini-batch, I would not recommend this for the validation data. For small batch sizes and/or a high amount of censoring, the estimated concordance index would be quite volatile, which makes it very hard to interpret. In addition, the validation data is usually considerably smaller than the training data. Therefore, we can collect predictions for the whole validation data and compute the concordance index accurately.
from tensorflow.core.framework import summary_pb2
def _to_scalar_protobuf(name: str, value: Any) -> summary_pb2.Summary:
    value = float(value)
    buf = summary_pb2.Summary(value=[summary_pb2.Summary.Value(
        tag=name, simple_value=value)])
    return buf
class EvalCindexHook(tf.train.SessionRunHook):
    """Computes concordance index across one epoch.
    Collects ground truth and predicted risk score
    until session ends, i.e., `OutOfRangeError` is raised.
    The concordance index is computed across collected
    values and written to protocol buffer to display
    in TensorBoard.
    Parameters
    ----------
    label_time : tf.Tensor
        Tensor containing observed time for one batch.
    label_event_indicator : tf.Tensor
        Tensor containing event indicator for one batch.
    prediction : tf.Tensor
        Tensor containing predicted risk score for one batch.
    output_dir : str
        Directory to which summaries for TensorBoard are written.
    """
    def __init__(self,
                 label_time: tf.Tensor,
                 label_event_indicator: tf.Tensor,
                 prediction: tf.Tensor,
                 output_dir: str) -> None:
        self._label_time = label_time
        self._label_event_indicator = label_event_indicator
        self._prediction = prediction
        self._writer = tf.summary.FileWriterCache.get(output_dir)
    def begin(self) -> None:
        self._global_step_tensor = tf.train.get_or_create_global_step()
        if self._global_step_tensor is None:
            raise RuntimeError("Global step should be created.")
        self._data = {
            "label_time": [],
            "label_event": [],
            "prediction": []
        }
        self._next_step = None
    def before_run(self,
                   run_context: tf.train.SessionRunContext) -> tf.train.SessionRunArgs:
        fetches = {
            "global_step": self._global_step_tensor,
            "label_time": self._label_time,
            "label_event": self._label_event_indicator,
            "prediction": self._prediction
        }
        return tf.train.SessionRunArgs(fetches=fetches)
    def after_run(self,
                  run_context: tf.train.SessionRunContext,
                  run_values: tf.train.SessionRunValues) -> None:
        global_step = run_values.results["global_step"]
        if self._next_step is None:
            self._writer.add_session_log(
                tf.SessionLog(status=tf.SessionLog.START), global_step)
        # accumulate ground truth and predictions of this batch as 1D arrays
        # (reshape keeps batches of size 1 concatenable, unlike squeeze)
        for k, v in self._data.items():
            v.append(run_values.results[k].reshape(-1))
        self._next_step = global_step + 1
    def _log_and_write(self, global_step: int, results: Dict[str, float]) -> None:
        msg = [f"global_step = {global_step}"]
        for k, v in results.items():
            msg.append(f"{k} = {v:.3f}")
            buf = _to_scalar_protobuf(f"metrics/{k}", v)
            self._writer.add_summary(buf, global_step=global_step)
        tf.logging.info(", ".join(msg))
    def end(self, session: tf.Session) -> None:
        if self._next_step is None:
            return
        data = {}
        for k, v in self._data.items():
            data[k] = np.concatenate(v)
        results = concordance_index_censored(
            data["label_event"] == 1,
            data["label_time"],
            data["prediction"])
        data = {}
        names = ("cindex", "concordant", "discordant", "tied_risk")
        for k, v in zip(names, results):
            data[k] = v
        self._log_and_write(self._next_step - 1, data)
        del self._data
        self._writer.flush()
Finally, after many considerations, we can create a convolutional neural network (CNN) to learn a high-level representation from MNIST digits such that we can estimate each image's survival function. The CNN follows the LeNet architecture, where the last linear layer has one output unit that corresponds to the predicted risk score. The predicted risk score, the binary event indicator, and the risk set are the inputs to the Cox PH loss.
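To see how the network reduces a 28 x 28 image to a single risk score, the following plain-Python sketch traces the feature-map sizes and parameter counts layer by layer. It assumes the Keras defaults that `model_fn` relies on: `padding='valid'` and stride 1 for `Conv2D`, and pool strides equal to `pool_size`. The helper names are hypothetical. Calling `model.summary()` on the Keras model should report the same numbers.

```python
def conv_out(size: int, kernel: int, stride: int = 1) -> int:
    """Spatial output size of a convolution with 'valid' padding."""
    return (size - kernel) // stride + 1


def pool_out(size: int, pool: int) -> int:
    """Spatial output size of max pooling with stride == pool size and 'valid' padding."""
    return (size - pool) // pool + 1


side = 28                                  # MNIST images: 28 x 28 x 1
params = {}

params["conv_1"] = 5 * 5 * 1 * 6 + 6       # kernel weights + biases
side = conv_out(side, 5)                   # 24 x 24 x 6
side = pool_out(side, 2)                   # 12 x 12 x 6

params["conv_2"] = 5 * 5 * 6 * 16 + 16
side = conv_out(side, 5)                   # 8 x 8 x 16
side = pool_out(side, 2)                   # 4 x 4 x 16

flat = side * side * 16                    # Flatten: 256 features
params["dense_1"] = flat * 120 + 120
params["dense_2"] = 120 * 84 + 84
params["dense_3"] = 84 * 1 + 1             # single output unit: the risk score

print(f"flattened features: {flat}")
for name, count in params.items():
    print(f"{name}: {count} parameters")
print(f"total: {sum(params.values())} parameters")
```

Most of the roughly 44k parameters sit in `dense_1`, which connects the 256 flattened features to 120 units.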
def model_fn(features, labels, mode, params):
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(6, kernel_size=(5, 5), activation='relu', name='conv_1'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Conv2D(16, (5, 5), activation='relu', name='conv_2'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(120, activation='relu', name='dense_1'),
        tf.keras.layers.Dense(84, activation='relu', name='dense_2'),
        tf.keras.layers.Dense(1, activation='linear', name='dense_3')
    ])
    risk_score = model(features, training=is_training)
    loss = None
    train_op = None
    predictions = None
    evaluation_hooks = None
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {"risk_score": risk_score}
    else:
        loss = coxph_loss(
            event=tf.expand_dims(labels["label_event"], axis=1),
            riskset=labels["label_riskset"],
            predictions=risk_score
        )
        if is_training:
            optim = tf.train.AdamOptimizer(learning_rate=params["learning_rate"])
            gs = tf.train.get_or_create_global_step()
            train_op = tf.contrib.layers.optimize_loss(loss, gs,
                                                       learning_rate=None,
                                                       optimizer=optim)
        else:
            evaluation_hooks = [EvalCindexHook(
                label_time=labels["label_time"],
                label_event_indicator=labels["label_event"],
                prediction=risk_score,
                output_dir=str(Path(params["model_dir"]) / "cindex"))]
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions=predictions,
        evaluation_hooks=evaluation_hooks
    )
train_spec = tf.estimator.TrainSpec(
InputFunction(x_train, time_train, event_train,
num_epochs=15,
drop_last=True,
shuffle=True)
)
eval_spec = tf.estimator.EvalSpec(
InputFunction(x_test, time_test, event_test),
steps=None,
start_delay_secs=10,
throttle_secs=10,
)
params = {
"learning_rate": 0.0001,
"model_dir": "ckpts-mnist-cnn",
}
config = tf.estimator.RunConfig(
model_dir=params["model_dir"],
save_checkpoints_steps=train_spec.input_fn.steps_per_epoch(),
)
estimator = tf.estimator.Estimator(model_fn, config=config, params=params)
To observe training, we can start TensorBoard (requires version 1.13 or later).
# Load the TensorBoard notebook extension.
%load_ext tensorboard
%tensorboard --logdir ckpts-mnist-cnn
Let the training begin…
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
({'loss': 1.9082849, 'global_step': 14055}, [])
We can make a couple of observations:
For inference, things are much easier: we just pass a batch of images and record the predicted risk scores. To estimate individual survival functions, we additionally need an estimate of the baseline hazard function $h_0$, which can be obtained analogously to the linear Cox PH model by using Breslow's estimator.
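Breslow's estimator is simple enough to write out by hand. Let $\eta_i$ denote the risk score predicted for training sample $i$, and let $d_k$ be the number of events at the distinct event time $t_k$. The cumulative baseline hazard $H_0(t) = \int_0^t h_0(s)\,ds$ is then estimated as $\hat{H}_0(t) = \sum_{t_k \le t} d_k \big/ \sum_{j:\, t_j \ge t_k} \exp(\eta_j)$, and the survival function of a new image with risk score $\eta$ is $\hat{S}(t \mid x) = \exp\big(-\hat{H}_0(t) \exp(\eta)\big)$. The NumPy sketch below implements these two textbook formulas for illustration only. The function names are hypothetical and not part of scikit-survival; in practice we rely on its `BreslowEstimator`.

```python
import numpy as np


def breslow_cumulative_baseline_hazard(risk_score, event, time):
    """Textbook Breslow estimate of the cumulative baseline hazard H_0(t).

    risk_score: predicted risk score (log partial hazard) eta_i per sample.
    event: boolean array, True if the event was observed, False if censored.
    time: observed time (event or censoring time) per sample.
    Returns the sorted distinct event times and H_0 evaluated at each of them.
    """
    risk_score = np.asarray(risk_score, dtype=float)
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    exp_risk = np.exp(risk_score)

    unique_times = np.unique(time[event])  # sorted distinct event times t_k
    increments = []
    for t in unique_times:
        n_events = np.sum(event & (time == t))  # d_k: number of events at t_k
        at_risk = np.sum(exp_risk[time >= t])   # sum of exp(eta_j) over the risk set
        increments.append(n_events / at_risk)
    return unique_times, np.cumsum(increments)


def survival_function(risk_score_new, cum_baseline_hazard):
    """S(t | x) = exp(-H_0(t) * exp(eta)); rows are samples, columns are event times."""
    exp_risk = np.exp(np.asarray(risk_score_new, dtype=float))
    return np.exp(-np.outer(exp_risk, cum_baseline_hazard))
```

With all risk scores equal to zero, the first function reduces to the Nelson-Aalen estimator, which makes it easy to check by hand on a handful of samples.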
from sklearn.model_selection import train_test_split
from sksurv.linear_model.coxph import BreslowEstimator
def make_pred_fn(images: np.ndarray,
                 batch_size: int = 64):
    # Add a channel axis if images have shape (n, height, width)
    if images.ndim == 3:
        images = images[..., np.newaxis]
    def _input_fn():
        # Features only: labels are not needed for prediction
        ds = tf.data.Dataset.from_tensor_slices(images)
        ds = ds.batch(batch_size)
        # Non-deprecated replacement for ds.make_one_shot_iterator()
        next_x = tf.compat.v1.data.make_one_shot_iterator(ds).get_next()
        return next_x, None
    return _input_fn
# Risk scores of all training images, used to fit Breslow's estimator
train_pred_fn = make_pred_fn(x_train)
train_predictions = np.array([float(pred["risk_score"])
                              for pred in estimator.predict(train_pred_fn)])
breslow = BreslowEstimator().fit(train_predictions, event_train, time_train)
Once fitted, we can use Breslow's estimator to obtain estimated survival functions for images in the test data. We randomly draw three sample images for each digit and plot their predicted survival functions.
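Here `train_test_split` serves only as a convenient stratified sampler. When several arrays are passed, it returns a train/test pair for each of them, in order, so the slice `[1::2]` keeps the held-out part of every array, and `stratify` preserves the class proportions in that held-out set. The sketch below demonstrates this on small synthetic arrays (random stand-ins, not MNIST) with 10 equally frequent labels. In that setting, an integer `test_size=30` yields exactly three samples per label.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-ins for the test arrays: 200 "images" with 10 balanced labels
rng = np.random.RandomState(0)
x = rng.rand(200, 28, 28)
y = np.repeat(np.arange(10), 20)
event = rng.rand(200) < 0.7
time = rng.exponential(size=200)

# Returns [x_train, x_test, y_train, y_test, event_train, event_test, ...];
# [1::2] keeps only the held-out ("test") part of each array
parts = train_test_split(x, y, event, time,
                         test_size=30, stratify=y, random_state=89)
x_s, y_s, event_s, time_s = parts[1::2]

# Number of sampled images per label
print(np.bincount(y_s, minlength=10))
```

Because all four arrays are split with the same shuffled indices, the sampled images, labels, event indicators and times stay aligned.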
sample = train_test_split(x_test, y_test, event_test, time_test,
                          test_size=30, stratify=y_test, random_state=89)
x_sample, y_sample, event_sample, time_sample = sample[1::2]
sample_pred_fn = make_pred_fn(x_sample)
sample_predictions = np.array([float(pred["risk_score"])
                               for pred in estimator.predict(sample_pred_fn)])
sample_surv_fn = breslow.get_survival_function(sample_predictions)
plt.figure(figsize=(6, 4.5))
for surv_fn, class_label in zip(sample_surv_fn, y_sample):
    risk_group = risk_score_assignment.loc[class_label, "risk_group"]
    plt.step(surv_fn.x, surv_fn.y, where="post",
             color=f"C{class_label}", linestyle=styles[risk_group])
plt.ylim(0, 1)
plt.ylabel("Probability of survival $P(T > t)$")
plt.xlabel("Time $t$")
plt.grid()
Solid lines correspond to images that belong to risk group 0 (lowest risk), a group the model was able to learn. Samples from the group with the highest risk are shown as dotted lines. Their predicted survival functions drop most steeply, which indicates that the model correctly identified the different risk groups from the images.
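The ordering of the curves follows from the proportional hazards assumption. Let $S_0(t) = \exp(-H_0(t))$ be the baseline survival function, where $H_0$ is the cumulative baseline hazard estimated by Breslow's estimator. The Cox model then gives $S(t \mid x) = S_0(t)^{\exp(\eta)}$ for an image with risk score $\eta$. Since $S_0(t) \le 1$, a larger risk score lowers the survival probability at every time point, so the predicted curves of one model never cross. The short NumPy sketch below checks this with a made-up baseline survival curve and made-up risk scores; neither comes from the trained model.

```python
import numpy as np

# Made-up baseline survival S_0(t) on a time grid (any decreasing curve in (0, 1] works)
t = np.linspace(0.0, 10.0, 101)
baseline_surv = np.exp(-0.1 * t)

# Made-up risk scores for three risk groups, from low to high
risk_scores = np.array([-1.0, 0.0, 1.5])

# Cox PH model: S(t | x) = S_0(t) ** exp(eta(x)); one row per risk score
surv = baseline_surv[np.newaxis, :] ** np.exp(risk_scores)[:, np.newaxis]

# For t > 0, a higher risk score gives a strictly lower survival probability
ordered = np.all(surv[2, 1:] < surv[1, 1:]) and np.all(surv[1, 1:] < surv[0, 1:])
print(ordered)
```

All curves start at 1 for $t = 0$ and separate afterwards, mirroring how the dotted high-risk curves fall below the solid low-risk ones in the plot above.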
We successfully built, trained, and evaluated a convolutional neural network for survival analysis on MNIST. While MNIST is obviously not a clinical dataset, the exact same approach can be used for clinical data. For instance, Mobadersany et al. used the same approach to predict overall survival of patients diagnosed with brain tumors from microscopic images, and Zhu et al. applied CNNs to predict survival of lung cancer patients from pathological images.