Demo: Probabilistic CARE model for denoising of synthetic 2D data

This notebook demonstrates how to apply a probabilistic CARE model to a 2D denoising task, assuming that training has already been completed via 1_training.ipynb.
The trained model is assumed to be located in the folder models with the name my_model.

More documentation is available at

In [1]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from csbdeep.utils import Path, download_and_extract_zip_file, plot_some
from csbdeep.io import load_training_data, save_tiff_imagej_compatible
from csbdeep.models import CARE
Using TensorFlow backend.

Download example data

The example data should have been downloaded in 1_training.ipynb.
Just in case, we will download it here again if it's not already present.

In [2]:
download_and_extract_zip_file (
    url       = '',
    targetdir = 'data',
)
Files found, nothing to download.

- synthetic_disks
- synthetic_disks/data.npz
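The "download only if it is not already present" behaviour described above can be sketched as a simple existence check. The helper `needs_download` is hypothetical and not part of csbdeep; `download_and_extract_zip_file` may decide differently whether a download is needed. The sketch only checks whether every expected file exists below the target folder.

```python
from pathlib import Path
import tempfile

def needs_download(targetdir, expected_files):
    """Return True if any expected file is missing below targetdir.

    Hypothetical helper illustrating the 'skip if already present' idea;
    it is not the csbdeep implementation.
    """
    root = Path(targetdir)
    return not all((root / f).exists() for f in expected_files)

# Simulate an initially empty data folder in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    expected = ['synthetic_disks/data.npz']
    before = needs_download(tmp, expected)   # nothing there yet
    (Path(tmp) / 'synthetic_disks').mkdir()
    (Path(tmp) / 'synthetic_disks' / 'data.npz').touch()
    after = needs_download(tmp, expected)    # file is now present

print(before, after)  # True False
```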

Load the validation images used during model training.

In [3]:
X_val, Y_val = load_training_data('data/synthetic_disks/data.npz', validation_split=0.1, verbose=True)[1]
number of training images:	 180
number of validation images:	 20
image size (2D):		 (128, 128)
axes:				 SYXC
channels in / out:		 1 / 1
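With `validation_split=0.1`, 10% of the 200 image pairs are held out, which gives the 180 / 20 split printed above. The following is a minimal sketch of such a split, assuming the validation pairs are taken from the end of the arrays; `split_validation` is a hypothetical helper, and csbdeep's `load_training_data` may select the held-out pairs differently.

```python
import numpy as np

def split_validation(X, Y, validation_split):
    """Hold out a fraction of the image pairs for validation.

    Sketch only: assumes the held-out pairs come from the end of the arrays.
    """
    n_images = len(X)
    n_val = int(round(n_images * validation_split))
    n_train = n_images - n_val
    assert 0 < n_val and 0 < n_train, 'split leaves an empty set'
    return (X[:n_train], Y[:n_train]), (X[n_train:], Y[n_train:])

# 200 dummy image pairs with axes SYXC, matching the sizes printed above.
X_demo = np.zeros((200, 128, 128, 1), dtype=np.float32)
Y_demo = np.zeros_like(X_demo)
(X_tr, Y_tr), (X_va, Y_va) = split_validation(X_demo, Y_demo, 0.1)
print(len(X_tr), len(X_va))  # 180 20
```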

We will apply the trained CARE model here to restore one validation image x (with associated ground truth y).

In [4]:
y = Y_val[2,...,0]
x = X_val[2,...,0]
axes = 'YX'
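The validation arrays have axes SYXC (sample, Y, X, channel), so indexing with `[2, ..., 0]` selects sample 2 and channel 0 and keeps the two spatial axes, which is why the result has axes `'YX'`. A small stand-in array (random data, not the real validation set) illustrates this:

```python
import numpy as np

# Stand-in for X_val: 20 images with axes SYXC (sample, Y, X, channel).
X_val_demo = np.random.default_rng(0).random((20, 128, 128, 1), dtype=np.float32)

# [2, ..., 0] picks sample 2 and channel 0; the ellipsis keeps Y and X.
x_demo = X_val_demo[2, ..., 0]
print(x_demo.shape)  # (128, 128)

# The same selection written out explicitly.
assert np.array_equal(x_demo, X_val_demo[2, :, :, 0])
```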

Input image and associated ground truth

Plot the test image pair.

In [5]:
print('image size =', x.shape)
print('image axes =', axes)

plot_some(np.stack([x,y]), title_list=[['input','target (GT)']]);
image size = (128, 128)
image axes = YX

CARE model

Load the trained model (located in the base directory models with the name my_model) from disk.
The configuration was saved during training and is automatically loaded when CARE is initialized with config=None.

In [6]:
model = CARE(config=None, name='my_model', basedir='models')
Loading network weights from 'weights_best.h5'.

Typical CARE prediction

If you are only interested in a restored image, predict it as in the non-probabilistic case.
Note that for a probabilistic network the returned image is the expected value (per-pixel mean) of the predicted distributions.

Note 1: Since the synthetic image is already normalized, we don't need to do additional normalization.

Note 2: Out-of-memory errors during model.predict often indicate that the GPU is being used by another process. In particular, shut down the training notebook before running the prediction (you may need to restart this notebook).
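For real (non-synthetic) data, the input usually has to be normalized before prediction. A common choice is percentile-based normalization, which maps a low and a high intensity percentile to roughly 0 and 1. The sketch below uses plain NumPy; the function name and the percentile values are illustrative assumptions and are not guaranteed to match csbdeep's defaults.

```python
import numpy as np

def percentile_normalize(x, pmin=2.0, pmax=99.8, eps=1e-20):
    """Map the pmin/pmax percentiles of x to approximately 0 and 1.

    Illustrative sketch; percentile values are assumptions, not csbdeep defaults.
    """
    lo, hi = np.percentile(x, [pmin, pmax])
    return (x - lo) / (hi - lo + eps)

raw = np.random.default_rng(1).normal(loc=500.0, scale=50.0, size=(128, 128))
x_norm = percentile_normalize(raw)
print(np.percentile(x_norm, [2.0, 99.8]))  # approximately [0. 1.]
```

With real microscopy data one would typically pass a normalizer such as `csbdeep.data.PercentileNormalizer` to `model.predict` instead of `normalizer=None`.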

In [7]:
restored = model.predict(x, axes, normalizer=None)
In [8]:
plot_some(np.stack([x,restored]), title_list=[['input','expected restored image']]);

Save restored image

Save the restored image as an ImageJ-compatible TIFF file, i.e. one that can be opened in ImageJ/Fiji with the correct axes semantics.

In [9]:
Path('results').mkdir(exist_ok=True)  # make sure the output folder exists
save_tiff_imagej_compatible('results/%s_validation_image.tif' % model.name, restored, axes)

Probabilistic CARE prediction

We now predict the per-pixel Laplace distributions, which returns an object for working with them.

In [10]:
restored_prob = model.predict_probabilistic(x, axes, normalizer=None)

Plot the mean and scale parameters of the per-pixel Laplace distributions.

In [11]:
plot_some(np.stack([restored_prob.mean(),restored_prob.scale()]), title_list=[['mean','scale']]);
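For reference, the per-pixel model can be written out, assuming the standard Laplace parameterization with location $\mu_i$ (plotted as "mean") and scale $b_i$ (plotted as "scale") at pixel $i$:

```latex
p(y_i \mid \mu_i, b_i) = \frac{1}{2 b_i} \exp\!\left(-\frac{|y_i - \mu_i|}{b_i}\right),
\qquad b_i > 0,
\qquad \mathbb{E}[y_i] = \mu_i
```

Since the Laplace distribution is symmetric about $\mu_i$, its mean, median and mode all equal $\mu_i$, so the mean map should match the expected restored image returned by `model.predict`.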

Plot the variance and entropy of the per-pixel Laplace distributions.

In [12]:
plot_some(np.stack([restored_prob.var(),restored_prob.entropy()]), title_list=[['variance','entropy']]);
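Both quantities follow in closed form from the scale $b$: the variance is $2b^2$ and the differential entropy (natural log) is $1 + \ln(2b)$. The sketch below computes them with NumPy and checks the variance by Monte Carlo on a tiny 2x2 stand-in "image"; `laplace_var_entropy` is a hypothetical helper, and it is an assumption that `restored_prob.entropy()` uses the natural-log differential entropy.

```python
import numpy as np

def laplace_var_entropy(b):
    """Per-pixel variance and differential entropy of Laplace(mu, b).

    Closed forms: Var = 2 b^2, H = 1 + ln(2 b) (natural log).
    Neither depends on the location mu.
    """
    b = np.asarray(b, dtype=np.float64)
    return 2.0 * b**2, 1.0 + np.log(2.0 * b)

# Tiny 2x2 stand-ins for the mean and scale maps.
rng = np.random.default_rng(2)
mu = np.array([[0.0, 1.0], [2.0, 3.0]])
b = np.array([[0.1, 0.2], [0.5, 1.0]])
var, ent = laplace_var_entropy(b)

# Monte Carlo check: empirical per-pixel variance vs. closed form.
draws = rng.laplace(loc=mu, scale=b, size=(200_000,) + mu.shape)
print(np.round(draws.var(axis=0) / var, 2))  # ratios close to 1
```

Because scale, variance and entropy are all increasing functions of $b$, the three maps should show the same spatial pattern and differ only in contrast.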

Sampling restored images

Draw 50 samples from the distribution of the restored image and plot the first 3.

In [13]:
samples = np.stack(restored_prob.sampling_generator(50))

plot_some(samples[:3], pmin=0.1,pmax=99.9);
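Conceptually, each sample is an image whose pixels are drawn from their own Laplace distributions. A minimal NumPy sketch, assuming pixels are sampled independently (`sample_laplace_images` is hypothetical, and the constant maps below are stand-ins for `restored_prob.mean()` and `restored_prob.scale()`):

```python
import numpy as np

def sample_laplace_images(mu, b, n_samples, seed=None):
    """Draw n_samples images, each pixel independently from Laplace(mu, b).

    Returns an array of shape (n_samples,) + mu.shape.
    """
    rng = np.random.default_rng(seed)
    return rng.laplace(loc=mu, scale=b, size=(n_samples,) + np.shape(mu))

mu_map = np.full((128, 128), 0.5)    # stand-in for the mean map
b_map = np.full((128, 128), 0.05)    # stand-in for the scale map
demo_samples = sample_laplace_images(mu_map, b_map, 50, seed=0)
print(demo_samples.shape)  # (50, 128, 128)
```

Under the independence assumption, neighbouring pixels are uncorrelated, so individual samples look grainy rather than spatially coherent, while their per-pixel average converges to the mean map.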

Make an animation of the 50 samples.

In [14]:
from matplotlib import animation
from IPython.display import HTML

fig = plt.figure(figsize=(8,8))
im = plt.imshow(samples[0], vmin=np.percentile(samples,0.1), vmax=np.percentile(samples,99.9), cmap='magma')

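def updatefig(j):
    # show the j-th sample in the existing image artist
    im.set_array(samples[j])
    return [im]

anim = animation.FuncAnimation(fig, updatefig, frames=len(samples), interval=100)
plt.close(fig)  # suppress the static figure; only the animation is displayed
HTML(anim.to_jshtml())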
