In case of problems or questions, please first check the list of Frequently Asked Questions (FAQ).
Please shutdown all other training/prediction notebooks before running this notebook (as those might occupy the GPU memory otherwise).
The main purpose of StarDist is to detect all object instances in an input image, e.g. all cell nuclei in a fluorescence microscopy image as shown here.
This notebook demonstrates how StarDist can additionally classify each found object instance into a fixed number of different object classes (e.g. cell types, phenotypes, etc.). We will refer to this approach as multi-class in the following.
To use multi-class StarDist, one has to provide for every training input image X
and associated label instance mask Y
an additional class dictionary cls_dict
that maps instance ids to a discrete set of class labels label_id -> (1, ..., n_classes).
Difference to "normal" StarDist training:
Set the n_classes
variable to the number of object classes in the Config
object, and additionally set the classes
variable in model.train
to a list of class dictionaries for every training image/label pair.
In the following we demonstrate this workflow for the case of a synthetic dataset consisting of 2 cell phenotypes of different texture.
NOTE: Although this example uses 2D images, the demonstrated functionality also works for 3D StarDist.
from __future__ import print_function, unicode_literals, absolute_import, division
import sys
import numpy as np
import matplotlib
matplotlib.rcParams["image.interpolation"] = 'none'
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from glob import glob
import json
from tqdm import tqdm
from tifffile import imread
from csbdeep.utils import Path, normalize
from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available
from stardist.matching import matching, matching_dataset
from stardist.models import Config2D, StarDist2D, StarDistData2D
from stardist.utils import mask_to_categorical
from stardist.plot import render_label
np.random.seed(0)
lbl_cmap = random_label_cmap()
lbl_cmap_classes = matplotlib.cm.tab20
def plot_img_label(img, lbl, cls_dict, n_classes=2, img_title="image", lbl_title="label", cls_title="classes", **kwargs):
    """Plot input image, instance label mask, and per-object class map side by side.

    img      : 2D input image
    lbl      : instance label mask (0 = background)
    cls_dict : maps instance label ids -> class ids (1..n_classes), as consumed
               by stardist's mask_to_categorical
    Returns the three matplotlib axes (image, label, classes).
    """
    # one-hot class stack: channel 0 = background, channels 1..n_classes = object classes
    c = mask_to_categorical(lbl, n_classes=n_classes, classes=cls_dict)
    # collapse the one-hot stack into a single class-id image (0 = background)
    res = np.zeros(lbl.shape, np.uint16)
    for i in range(1,c.shape[-1]):
        m = c[...,i]>0
        res[m] = i
    # colorize class ids; force background pixels to opaque black
    class_img = lbl_cmap_classes(res)
    class_img[...,:3][res==0] = 0
    class_img[...,-1][res==0] = 1
    fig, (ai,al,ac) = plt.subplots(1,3, figsize=(17,7), gridspec_kw=dict(width_ratios=(1.,1,1)))
    im = ai.imshow(img, cmap='gray')
    #fig.colorbar(im, ax = ai)
    ai.set_title(img_title)
    # instance labels rendered over the (normalized, dimmed) input image
    al.imshow(render_label(lbl, .8*normalize(img, clip=True), normalize_img=False, alpha_boundary=.8,cmap=lbl_cmap))
    al.set_title(lbl_title)
    # NOTE(review): this imshow appears to be fully covered by the render_label
    # overlay on the next line — possibly redundant; confirm before removing
    ac.imshow(class_img)
    ac.imshow(render_label(res, .8*normalize(img, clip=True), normalize_img=False, alpha_boundary=.8, cmap=lbl_cmap_classes))
    ac.set_title(cls_title)
    plt.tight_layout()
    for a in ai,al,ac:
        a.axis("off")
    return ai,al,ac
# set the number of object classes (classes are numbered 1..n_classes, 0 = background)
n_classes = 2
# generate synthetic samples of 2D images, label masks, and class dicts
def generate_sample(n=256):
    """Generate one synthetic 2D sample for multi-class StarDist training.

    Creates blob-like "cells" via a watershed on a randomized distance map,
    randomly splits them into two phenotypes that receive different texture
    scales, and simulates Poisson shot noise plus additive Gaussian noise.

    Parameters
    ----------
    n : int
        Side length of the square image (default: 256).

    Returns
    -------
    x : (n, n) float ndarray
        Noisy synthetic input image.
    y : (n, n) int ndarray
        Instance label mask with consecutive ids 1..y.max() (0 = background).
    cls : dict
        Maps each instance id in y to its class (1 or 2).
    """
    # NOTE: the scipy.ndimage.morphology and ndi.filters namespaces are
    # deprecated and removed in recent SciPy — import from scipy.ndimage directly.
    from scipy.ndimage import distance_transform_edt
    from skimage.segmentation import watershed, relabel_sequential
    from skimage.morphology import disk, binary_closing
    from scipy import ndimage as ndi
    # random number of seed points with unique marker ids
    m = np.random.randint(10,30)
    center = tuple(np.random.randint(0,n,(2,m)))
    markers = np.zeros((n,n),np.uint16)
    markers[center] = np.random.permutation(np.arange(1,m+1))
    # distance from markers plus a smooth random field -> irregular cell shapes
    dist = distance_transform_edt(1-1*(markers>0)) + 4*ndi.zoom(np.random.uniform(0,1,(n//16,n//16)), (16,16), order=1)
    dist *= np.random.uniform(1,2)
    y = watershed(dist, markers, mask=np.exp(-.1*dist)>.1)
    # smooth each object's outline and relabel to consecutive ids 1..y.max()
    y = relabel_sequential(np.max(np.stack([i*binary_closing(y==i,disk(3)) for i in np.unique(y[y>0])],axis=0),axis=0))[0]
    # randomly split the instance ids into the two phenotype classes
    ind = np.arange(1,y.max()+1)
    np.random.shuffle(ind)
    c1, c2 = ind[:len(ind)//2], ind[len(ind)//2:]
    m1 = np.isin(y,c1)
    m2 = np.isin(y,c2)
    x = ndi.gaussian_filter((y>0).astype(np.float32),2)
    # different texture scales distinguish the two classes
    noise1 = ndi.zoom(np.random.uniform(0,1,(n//4,n//4)), (4,4), order=3)
    noise2 = ndi.zoom(np.random.uniform(0,1,(n//16,n//16)), (16,16), order=3)
    x[m1] = .2*(1+np.sin(y[m1]))+noise1[m1]
    x[m2] = .2*(1+np.sin(y[m2]))+noise2[m2]
    x = ndi.gaussian_filter(x,1)+2*ndi.gaussian_filter(x,20)
    # simulate shot noise (Poisson) and additive Gaussian readout noise
    gain = 100
    x = np.random.poisson((1+gain*x).astype(int))/gain
    x += .08*np.random.normal(0,1,x.shape)
    # map only ids that actually occur in y (the original keyed on range(1,m+1),
    # which could add bogus entries when fewer than m objects survive)
    cls = dict((i,1 if i in c1 else 2) for i in range(1,y.max()+1))
    return x, y, cls
# plot an example
np.random.seed(42)  # fixed seed for a reproducible example
x,y,cls_dict = generate_sample()
ax = plot_img_label(x,y,cls_dict, n_classes=n_classes)
for a in ax: a.axis("off");
# show the cls_dict for this example. {label_id -> class_number}
print(cls_dict)
{1: 2, 2: 1, 3: 1, 4: 2, 5: 1, 6: 2, 7: 1, 8: 1, 9: 1, 10: 2, 11: 1, 12: 1, 13: 2, 14: 2, 15: 2, 16: 2}
Generate a synthetic training set of 100 random images, label masks, and class dictionaries.
# draw 100 synthetic (image, label-mask, class-dict) samples and unzip into parallel tuples
samples = [generate_sample() for _ in tqdm(range(100))]
X, Y, C = zip(*samples)
assert len(X) == len(Y) == len(C)
100%|██████████| 100/100 [00:06<00:00, 15.49it/s]
# number of image channels (1 for single-channel 2D images, else size of last axis)
n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]
Normalize images and fill small label holes.
axis_norm = (0,1)   # normalize channels independently
# axis_norm = (0,1,2) # normalize channels jointly
if n_channel > 1:
    print("Normalizing image channels %s." % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))
    sys.stdout.flush()
# percentile-based intensity normalization (1% / 99.8%) and closing of small holes in the label masks
X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]
Y = [fill_label_holes(y) for y in tqdm(Y)]
100%|██████████| 100/100 [00:00<00:00, 1003.02it/s] 100%|██████████| 100/100 [00:00<00:00, 646.88it/s]
Split into train and validation datasets.
# split off 15% (at least one image) as validation data, reproducibly
assert len(X) > 1, "not enough training data"
rng = np.random.RandomState(42)
perm = rng.permutation(len(X))
n_val = max(1, int(round(0.15 * len(perm))))
ind_train = perm[:-n_val]
ind_val = perm[-n_val:]
X_val, Y_val, C_val = [X[i] for i in ind_val], [Y[i] for i in ind_val], [C[i] for i in ind_val]
X_trn, Y_trn, C_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train], [C[i] for i in ind_train]
print('number of images: %3d' % len(X))
print('- training: %3d' % len(X_trn))
print('- validation: %3d' % len(X_val))
number of images: 100 - training: 85 - validation: 15
Training data consists of pairs of input image and label instances.
# pick one training example and display image / labels / classes
i = min(8, len(X)-1)
img, lbl, cls = X[i], Y[i], C[i]
assert img.ndim in (2,3)
# keep 2D or RGB images as-is, otherwise show only the first channel
img = img if (img.ndim==2 or img.shape[-1]==3) else img[...,0]
plot_img_label(img,lbl, cls)
None;
A StarDist2D
model is specified via a Config2D
object.
# show all documented configuration options of Config2D
print(Config2D.__doc__)
Configuration for a :class:`StarDist2D` model. Parameters ---------- axes : str or None Axes of the input images. n_rays : int Number of radial directions for the star-convex polygon. Recommended to use a power of 2 (default: 32). n_channel_in : int Number of channels of given input image (default: 1). grid : (int,int) Subsampling factors (must be powers of 2) for each of the axes. Model will predict on a subsampled grid for increased efficiency and larger field of view. n_classes : None or int Number of object classes to use for multi-class predection (use None to disable) backbone : str Name of the neural network architecture to be used as backbone. kwargs : dict Overwrite (or add) configuration attributes (see below). Attributes ---------- unet_n_depth : int Number of U-Net resolution levels (down/up-sampling layers). unet_kernel_size : (int,int) Convolution kernel size for all (U-Net) convolution layers. unet_n_filter_base : int Number of convolution kernels (feature channels) for first U-Net layer. Doubled after each down-sampling layer. unet_pool : (int,int) Maxpooling size for all (U-Net) convolution layers. net_conv_after_unet : int Number of filters of the extra convolution layer after U-Net (0 to disable). unet_* : * Additional parameters for U-net backbone. train_shape_completion : bool Train model to predict complete shapes for partially visible objects at image boundary. train_completion_crop : int If 'train_shape_completion' is set to True, specify number of pixels to crop at boundary of training patches. Should be chosen based on (largest) object sizes. train_patch_size : (int,int) Size of patches to be cropped from provided training images. train_background_reg : float Regularizer to encourage distance predictions on background regions to be 0. train_foreground_only : float Fraction (0..1) of patches that will only be sampled from regions that contain foreground pixels. 
train_sample_cache : bool Activate caching of valid patch regions for all training images (disable to save memory for large datasets) train_dist_loss : str Training loss for star-convex polygon distances ('mse' or 'mae'). train_loss_weights : tuple of float Weights for losses relating to (probability, distance) train_epochs : int Number of training epochs. train_steps_per_epoch : int Number of parameter update steps per epoch. train_learning_rate : float Learning rate for training. train_batch_size : int Batch size for training. train_n_val_patches : int Number of patches to be extracted from validation images (``None`` = one patch per image). train_tensorboard : bool Enable TensorBoard for monitoring training progress. train_reduce_lr : dict Parameter :class:`dict` of ReduceLROnPlateau_ callback; set to ``None`` to disable. use_gpu : bool Indicate that the data generator should use OpenCL to do computations on the GPU. .. _ReduceLROnPlateau: https://keras.io/api/callbacks/reduce_lr_on_plateau/
Set the n_classes
variable to the number of object classes in the Config
object, and additionally set the classes
variable in model.train
to a list of class dictionaries for every training image/label pair.
# 32 is a good default choice
# number of radial directions for the star-convex polygons; 32 is a good default
n_rays = 32
# Use OpenCL-based computations for data generator during training (requires 'gputools')
use_gpu = True and gputools_available()
# Predict on subsampled grid for increased efficiency and larger field of view
grid = (2,2)
conf = Config2D(
    n_rays=n_rays,
    grid=grid,
    use_gpu=use_gpu,
    n_channel_in=n_channel,
    n_classes=n_classes,  # set the number of object classes
)
print(conf)
vars(conf)
Config2D(axes='YXC', backbone='unet', grid=(2, 2), n_channel_in=1, n_channel_out=33, n_classes=2, n_dim=2, n_rays=32, net_conv_after_unet=128, net_input_shape=(None, None, 1), net_mask_shape=(None, None, 1), train_background_reg=0.0001, train_batch_size=4, train_checkpoint='weights_best.h5', train_checkpoint_epoch='weights_now.h5', train_checkpoint_last='weights_last.h5', train_class_weights=(1, 1, 1), train_completion_crop=32, train_dist_loss='mae', train_epochs=400, train_foreground_only=0.9, train_learning_rate=0.0003, train_loss_weights=(1, 0.2, 1), train_n_val_patches=None, train_patch_size=(256, 256), train_reduce_lr={'factor': 0.5, 'patience': 40, 'min_delta': 0}, train_sample_cache=True, train_shape_completion=False, train_steps_per_epoch=100, train_tensorboard=True, unet_activation='relu', unet_batch_norm=False, unet_dropout=0.0, unet_kernel_size=(3, 3), unet_last_activation='relu', unet_n_conv_per_depth=2, unet_n_depth=3, unet_n_filter_base=32, unet_pool=(2, 2), unet_prefix='', use_gpu=True)
{'n_dim': 2, 'axes': 'YXC', 'n_channel_in': 1, 'n_channel_out': 33, 'train_checkpoint': 'weights_best.h5', 'train_checkpoint_last': 'weights_last.h5', 'train_checkpoint_epoch': 'weights_now.h5', 'n_rays': 32, 'grid': (2, 2), 'backbone': 'unet', 'n_classes': 2, 'unet_n_depth': 3, 'unet_kernel_size': (3, 3), 'unet_n_filter_base': 32, 'unet_n_conv_per_depth': 2, 'unet_pool': (2, 2), 'unet_activation': 'relu', 'unet_last_activation': 'relu', 'unet_batch_norm': False, 'unet_dropout': 0.0, 'unet_prefix': '', 'net_conv_after_unet': 128, 'net_input_shape': (None, None, 1), 'net_mask_shape': (None, None, 1), 'train_shape_completion': False, 'train_completion_crop': 32, 'train_patch_size': (256, 256), 'train_background_reg': 0.0001, 'train_foreground_only': 0.9, 'train_sample_cache': True, 'train_dist_loss': 'mae', 'train_loss_weights': (1, 0.2, 1), 'train_class_weights': (1, 1, 1), 'train_epochs': 400, 'train_steps_per_epoch': 100, 'train_learning_rate': 0.0003, 'train_batch_size': 4, 'train_n_val_patches': None, 'train_tensorboard': True, 'train_reduce_lr': {'factor': 0.5, 'patience': 40, 'min_delta': 0}, 'use_gpu': True}
if use_gpu:
    from csbdeep.utils.tf import limit_gpu_memory
    # let TensorFlow allocate GPU memory on demand instead of grabbing it all at once
    limit_gpu_memory(None, allow_growth=True)
    # alternatively, adjust as necessary: limit GPU memory to be used by TensorFlow to leave some to OpenCL-based computations
    # limit_gpu_memory(0.8)
Note: The trained StarDist2D
model will not predict completed shapes for partially visible objects at the image boundary if train_shape_completion=False
(which is the default option).
# create a new multi-class StarDist model (files saved under models/stardist_multiclass)
model = StarDist2D(conf, name='stardist_multiclass', basedir='models')
Using default values: prob_thresh=0.5, nms_thresh=0.4.
Check if the neural network has a large enough field of view to see up to the boundary of most objects.
# compare the median object size against the network's field of view
median_size = calculate_extents(list(Y), np.median)
fov = np.array(model._axes_tile_overlap('YX'))
print(f"median object size: {median_size}")
print(f"network field of view : {fov}")
if any(median_size > fov):
    print("WARNING: median object size larger than field of view of the neural network.")
median object size: [25.25 26. ] network field of view : [94 92]
You can define a function/callable that applies augmentation to each batch of the data generator.
We here use an augmenter
that applies random rotations, flips, and intensity changes, which are typically sensible for (2D) microscopy images (but you can disable augmentation by setting augmenter = None
).
def random_fliprot(img, mask):
    """Apply one random transpose plus random axis flips, identically to img and mask.

    img may carry extra trailing (e.g. channel) axes beyond mask's spatial axes;
    only the shared spatial axes are permuted/flipped so both stay aligned.
    """
    assert img.ndim >= mask.ndim
    spatial = tuple(range(mask.ndim))
    order = tuple(np.random.permutation(spatial))
    # permute spatial axes, keeping img's extra trailing axes in place
    img = img.transpose(order + tuple(range(mask.ndim, img.ndim)))
    mask = mask.transpose(order)
    # flip each spatial axis independently with probability 1/2
    for axis in spatial:
        if np.random.rand() <= 0.5:
            continue
        img = np.flip(img, axis=axis)
        mask = np.flip(mask, axis=axis)
    return img, mask
def random_intensity_change(img):
    """Randomly rescale (gain in [0.6, 2)) and shift (offset in [-0.2, 0.2)) intensities."""
    gain = np.random.uniform(0.6, 2)
    offset = np.random.uniform(-0.2, 0.2)
    return img * gain + offset
def augmenter(x, y):
    """Augmentation of a single input/label image pair.
    x is an input image
    y is the corresponding ground-truth label image
    """
    # geometric augmentation is applied identically to image and labels
    x, y = random_fliprot(x, y)
    # photometric augmentation affects the image only
    x = random_intensity_change(x)
    # add some gaussian noise with a random (small) standard deviation
    sigma = 0.02 * np.random.uniform(0, 1)
    return x + sigma * np.random.normal(0, 1, x.shape), y
# plot some augmented examples
img, lbl, cls = X[0], Y[0], C[0]
plot_img_label(img,lbl,cls, n_classes=n_classes)
for _ in range(3):
    # the class dict is unchanged: augmentations do not alter instance ids
    img_aug, lbl_aug = augmenter(img,lbl)
    plot_img_label(img_aug,lbl_aug,cls, img_title="image augmented", lbl_title="label augmented", n_classes=n_classes)
None;
We recommend to monitor the progress during training with TensorBoard. You can start it in the shell from the current working directory like this:
$ tensorboard --logdir=.
Then connect to http://localhost:6006/ with your browser.
# multi-class training: class dictionaries are passed via the 'classes' argument
model.train(X_trn,Y_trn, classes=C_trn, validation_data=(X_val,Y_val,C_val), augmenter=augmenter,
            epochs=200) # 200 epochs seem to be enough for synthetic demo dataset
Epoch 1/200 100/100 [==============================] - 15s 147ms/step - loss: 2.6658 - prob_loss: 0.3290 - dist_loss: 8.5931 - prob_class_loss: 0.6182 - prob_kld: 0.2444 - dist_relevant_mae: 8.5925 - dist_relevant_mse: 113.9773 - dist_dist_iou_metric: 0.1735 - val_loss: 1.8868 - val_prob_loss: 0.2487 - val_dist_loss: 5.5821 - val_prob_class_loss: 0.5217 - val_prob_kld: 0.1594 - val_dist_relevant_mae: 5.5810 - val_dist_relevant_mse: 51.2115 - val_dist_dist_iou_metric: 0.4210 Epoch 2/200 100/100 [==============================] - 4s 38ms/step - loss: 1.8076 - prob_loss: 0.2243 - dist_loss: 5.4264 - prob_class_loss: 0.4980 - prob_kld: 0.1398 - dist_relevant_mae: 5.4253 - dist_relevant_mse: 47.6120 - dist_dist_iou_metric: 0.4107 - val_loss: 1.7921 - val_prob_loss: 0.1890 - val_dist_loss: 5.9058 - val_prob_class_loss: 0.4220 - val_prob_kld: 0.0997 - val_dist_relevant_mae: 5.9045 - val_dist_relevant_mse: 58.7034 - val_dist_dist_iou_metric: 0.3477 Epoch 3/200 100/100 [==============================] - 4s 38ms/step - loss: 1.6342 - prob_loss: 0.1717 - dist_loss: 5.4628 - prob_class_loss: 0.3699 - prob_kld: 0.0867 - dist_relevant_mae: 5.4616 - dist_relevant_mse: 48.0222 - dist_dist_iou_metric: 0.4054 - val_loss: 1.4823 - val_prob_loss: 0.1484 - val_dist_loss: 5.3118 - val_prob_class_loss: 0.2715 - val_prob_kld: 0.0591 - val_dist_relevant_mae: 5.3107 - val_dist_relevant_mse: 46.3029 - val_dist_dist_iou_metric: 0.4302 Epoch 4/200 100/100 [==============================] - 4s 38ms/step - loss: 1.3868 - prob_loss: 0.1312 - dist_loss: 5.0995 - prob_class_loss: 0.2357 - prob_kld: 0.0467 - dist_relevant_mae: 5.0983 - dist_relevant_mse: 42.6184 - dist_dist_iou_metric: 0.4424 - val_loss: 1.2444 - val_prob_loss: 0.1257 - val_dist_loss: 4.6132 - val_prob_class_loss: 0.1961 - val_prob_kld: 0.0364 - val_dist_relevant_mae: 4.6122 - val_dist_relevant_mse: 37.3011 - val_dist_dist_iou_metric: 0.5016 Epoch 5/200 100/100 [==============================] - 4s 37ms/step - loss: 1.0489 - 
prob_loss: 0.1120 - dist_loss: 3.7995 - prob_class_loss: 0.1770 - prob_kld: 0.0277 - dist_relevant_mae: 3.7985 - dist_relevant_mse: 26.9945 - dist_dist_iou_metric: 0.5586 - val_loss: 0.9511 - val_prob_loss: 0.1129 - val_dist_loss: 3.3440 - val_prob_class_loss: 0.1694 - val_prob_kld: 0.0236 - val_dist_relevant_mae: 3.3430 - val_dist_relevant_mse: 22.3795 - val_dist_dist_iou_metric: 0.6379 ... Loading network weights from 'weights_best.h5'.
<tensorflow.python.keras.callbacks.History at 0x7fe827dc4810>
While the default values for the probability and non-maximum suppression thresholds already yield good results in many cases, we still recommend to adapt the thresholds to your data. The optimized threshold values are saved to disk and will be automatically loaded with the model.
# optimize probability/NMS thresholds on the validation data (saved with the model)
model.optimize_thresholds(X_val, Y_val)
NMS threshold = 0.3: 75%|███████▌ | 15/20 [00:01<00:00, 10.81it/s, 0.732 -> 0.951] NMS threshold = 0.4: 75%|███████▌ | 15/20 [00:00<00:00, 16.34it/s, 0.732 -> 0.954] NMS threshold = 0.5: 75%|███████▌ | 15/20 [00:00<00:00, 16.13it/s, 0.732 -> 0.954]
Using optimized values: prob_thresh=0.732477, nms_thresh=0.4. Saving to 'thresholds.json'.
{'prob': 0.7324770065459176, 'nms': 0.4}
Predict for a single example image first.
i = 8
# predict instances for one validation image (tiled automatically if large)
label, res = model.predict_instances(X_val[i], n_tiles=model._guess_n_tiles(X_val[i]))
# the class object ids are stored in the 'results' dict and correspond to the label ids in increasing order
def class_from_res(res):
    """Build a {label_id -> class_id} dict from the 'class_id' entry of a
    prediction results dict; label ids start at 1 and follow list order."""
    return {label_id: class_id for label_id, class_id in enumerate(res['class_id'], start=1)}
# predicted {label_id -> class_id} mapping for the example image
print(class_from_res(res))
{1: 1, 2: 1, 3: 2, 4: 2, 5: 1, 6: 1, 7: 1, 8: 2, 9: 1, 10: 2, 11: 1, 12: 2, 13: 1, 14: 1, 15: 1, 16: 2, 17: 2, 18: 2, 19: 1, 20: 2, 21: 2, 22: 1, 23: 2, 24: 2, 25: 2, 26: 1, 27: 2, 28: 1}
# ground truth vs. prediction for the example image
plot_img_label(X_val[i], Y_val[i], C_val[i], lbl_title="GT")
plot_img_label(X_val[i], label, class_from_res(res), lbl_title="Pred");
Besides the losses and metrics during training, we can also quantitatively evaluate the actual detection/segmentation performance on the validation data by considering objects in the ground truth to be correctly matched if there are predicted objects with overlap (here intersection over union (IoU)) beyond a chosen IoU threshold $\tau$.
The corresponding matching statistics (average overlap, accuracy, recall, precision, etc.) are typically of greater practical relevance than the losses/metrics computed during training (but harder to formulate as a loss function). The value of $\tau$ can be between 0 (even slightly overlapping objects count as correctly predicted) and 1 (only pixel-perfectly overlapping objects count) and which $\tau$ to use depends on the needed segmentation precision/application.
Please see help(matching)
for definitions of the abbreviations used in the evaluation below and see the Wikipedia page on Sensitivity and specificity for further details.
First predict the labels for all validation images:
# predict instances (and class results) for all validation images
Y_val_pred, res_val_pred = tuple(zip(*[model.predict_instances(x, n_tiles=model._guess_n_tiles(x), show_tile_progress=False)
                                       for x in tqdm(X_val[:])]))
100%|██████████| 15/15 [00:00<00:00, 36.30it/s]
Plot another GT/prediction example
# another GT vs. prediction example
i = 10
plot_img_label(X_val[i],Y_val[i], C_val[i], lbl_title="label GT")
plot_img_label(X_val[i],Y_val_pred[i], class_from_res(res_val_pred[i]), lbl_title="label Pred");
Choose several IoU thresholds $\tau$ that might be of interest and for each compute matching statistics for the validation data.
# IoU thresholds to evaluate; compute matching statistics over the validation set for each
taus = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
stats = [matching_dataset(Y_val, Y_val_pred, thresh=t, show_progress=False) for t in tqdm(taus)]
100%|██████████| 9/9 [00:00<00:00, 38.59it/s]
Example: Print all available matching statistics for $\tau=0.5$
# all available matching statistics at IoU threshold 0.5
stats[taus.index(0.5)]
DatasetMatching(criterion='iou', thresh=0.5, fp=2, tp=278, fn=10, precision=0.9928571428571429, recall=0.9652777777777778, accuracy=0.9586206896551724, f1=0.9788732394366197, n_true=288, n_pred=280, mean_true_score=0.8914246526029375, mean_matched_score=0.9234902875886547, panoptic_quality=0.9039799294001619, by_image=False)
Plot the matching statistics and the number of true/false positives/negatives as a function of the IoU threshold $\tau$.
fig, (ax1,ax2) = plt.subplots(1,2, figsize=(15,5))
# left: quality metrics as a function of the IoU threshold
for m in ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'):
    ax1.plot(taus, [s._asdict()[m] for s in stats], '.-', lw=2, label=m)
ax1.set_xlabel(r'IoU threshold $\tau$')
ax1.set_ylabel('Metric value')
ax1.grid()
ax1.legend()
# right: counts of true/false positives and false negatives
for m in ('fp', 'tp', 'fn'):
    ax2.plot(taus, [s._asdict()[m] for s in stats], '.-', lw=2, label=m)
ax2.set_xlabel(r'IoU threshold $\tau$')
ax2.set_ylabel('Number #')
ax2.grid()
ax2.legend();