# Demo: Probabilistic CARE model for denoising of synthetic 2D data

This notebook demonstrates how to apply a probabilistic CARE model to a 2D denoising task, assuming that training has already been completed via `1_training.ipynb`.
The trained model is assumed to be located in the folder `models` under the name `my_model`.

More documentation is available at http://csbdeep.bioimagecomputing.com/doc/.

In [1]:
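from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# Helpers used below for downloading data, loading it, plotting, and saving results
from pathlib import Path
from csbdeep.utils import download_and_extract_zip_file, plot_some
from csbdeep.io import load_training_data, save_tiff_imagej_compatible
from csbdeep.models import CARE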

Using TensorFlow backend.


In [2]:
download_and_extract_zip_file(
    url       = 'http://csbdeep.bioimagecomputing.com/example_data/synthetic_disks.zip',
    targetdir = 'data',
)

Files found, nothing to download.

data:
- synthetic_disks
- synthetic_disks/data.npz


Load the validation images used during model training.

In [3]:
X_val, Y_val = load_training_data('data/synthetic_disks/data.npz', validation_split=0.1, verbose=True)[1]

number of training images:	 180
number of validation images:	 20
image size (2D):		 (128, 128)
axes:				 SYXC
channels in / out:		 1 / 1
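The counts above (180 training and 20 validation images, 200 in total) follow from `validation_split=0.1`. The sketch below reproduces this arithmetic with a hypothetical helper `split_validation` on dummy arrays in SYXC layout. That the validation images are taken from the end of the array is an assumption made for illustration; consult the `csbdeep.io.load_training_data` documentation for the exact behavior.

```python
import numpy as np

def split_validation(X, Y, validation_split=0.1):
    """Hypothetical helper: hold out the last fraction of samples for validation."""
    n_images = X.shape[0]
    n_val = int(round(n_images * validation_split))
    n_train = n_images - n_val
    # First n_train samples for training, remaining n_val samples for validation
    return (X[:n_train], Y[:n_train]), (X[n_train:], Y[n_train:])

# Dummy data with the same layout as data.npz: axes SYXC, 200 samples of 128x128, 1 channel
X = np.zeros((200, 128, 128, 1), dtype=np.float32)
Y = np.zeros((200, 128, 128, 1), dtype=np.float32)

(X_trn, Y_trn), (X_val, Y_val) = split_validation(X, Y, validation_split=0.1)
print(X_trn.shape[0], X_val.shape[0])  # 180 20
```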


We will apply the trained CARE model here to restore one validation image x (with associated ground truth y).

In [4]:
y = Y_val[2,...,0]
x = X_val[2,...,0]
axes = 'YX'
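The indexing `Y_val[2,...,0]` selects the third validation sample and its single channel, which turns an SYXC array into a 2D image with axes YX. A minimal NumPy illustration on a dummy array with the same shape as the validation data:

```python
import numpy as np

# Dummy validation array with axes SYXC: 20 samples, 128x128 pixels, 1 channel
X_val = np.random.default_rng(0).random((20, 128, 128, 1))

# Integer index 2 selects one sample (drops S), '...' keeps Y and X,
# and the trailing 0 selects the single channel (drops C)
x = X_val[2, ..., 0]
print(x.shape)  # (128, 128), i.e. axes 'YX'
```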


# Input image and associated ground truth

Plot the test image pair.

In [5]:
print('image size =', x.shape)
print('image axes =', axes)

plt.figure(figsize=(16,10))
plot_some(np.stack([x,y]), title_list=[['input','target (GT)']]);

image size = (128, 128)
image axes = YX


# CARE model

Load the trained model (located in the base directory `models` under the name `my_model`) from disk.
The configuration was saved during training and is automatically loaded when CARE is initialized with config=None.

In [6]:
model = CARE(config=None, name='my_model', basedir='models')

Loading network weights from 'weights_best.h5'.


## Typical CARE prediction

If you are only interested in a restored image, predict it just as in the non-probabilistic case.
For a probabilistic network, model.predict returns the expected restored image, i.e. the per-pixel mean of the predicted distributions.
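To illustrate what "expected restored image" means, the sketch below uses made-up per-pixel Laplace parameters (not the output of the model) and shows that averaging many samples per pixel recovers the location parameter, which is the mean of a Laplace distribution. It relies only on NumPy's `Generator.laplace`; the array sizes and parameter ranges are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical per-pixel Laplace parameters for a tiny 4x4 image
loc = rng.uniform(0.0, 1.0, size=(4, 4))     # location (= mean) of each pixel
scale = rng.uniform(0.05, 0.2, size=(4, 4))  # scale of each pixel

# Draw many samples per pixel; loc and scale broadcast over the leading axis
samples = rng.laplace(loc, scale, size=(20000, 4, 4))

# The per-pixel sample average approaches the Laplace mean, i.e. loc
err = np.abs(samples.mean(axis=0) - loc).max()
print(err)
```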

Note 1: Since the synthetic image is already normalized, we don't need to do additional normalization.

Note 2: Out-of-memory errors during model.predict often indicate that the GPU is being used by another process. In particular, shut down the training notebook before running the prediction (you may need to restart this notebook).
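For your own images that are not already normalized, you would pass a normalizer instead of `None`. The sketch below shows the idea of percentile-based normalization in plain NumPy; the function name `percentile_normalize` and its default percentiles are illustrative assumptions, not CSBDeep's exact implementation.

```python
import numpy as np

def percentile_normalize(img, pmin=2.0, pmax=99.8, eps=1e-20):
    """Hypothetical helper: linearly map the pmin/pmax percentiles of img to 0 and 1."""
    lo, hi = np.percentile(img, [pmin, pmax])
    # eps guards against division by zero for constant images
    return (img - lo) / (hi - lo + eps)

# Synthetic "raw" image with an arbitrary intensity offset and spread
raw = np.random.default_rng(0).normal(loc=500.0, scale=50.0, size=(128, 128))
x_norm = percentile_normalize(raw)
print(np.percentile(x_norm, [2.0, 99.8]))  # close to [0, 1]
```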

In [7]:
restored = model.predict(x, axes, normalizer=None)

In [8]:
plt.figure(figsize=(16,10))
plot_some(np.stack([x,restored]), title_list=[['input','expected restored image']]);


### Save restored image

Save the restored image as an ImageJ-compatible TIFF file, so that it can be opened in ImageJ/Fiji with the correct axes semantics.

In [9]:
Path('results').mkdir(exist_ok=True)
save_tiff_imagej_compatible('results/%s_validation_image.tif' % model.name, restored, axes)


# Probabilistic CARE prediction

We now predict the per-pixel Laplace distributions; model.predict_probabilistic returns an object for working with them.
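For reference, the density of a Laplace distribution with location $\mu$ and scale $b$, evaluated at a single pixel value $y$, is given below together with its mean. The symbols are our own notation and are not taken from CSBDeep.

```latex
p(y \mid \mu, b) = \frac{1}{2b} \exp\!\left(-\frac{|y - \mu|}{b}\right),
\qquad \mathbb{E}[y] = \mu
```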

In [10]:
restored_prob = model.predict_probabilistic(x, axes, normalizer=None)


Plot the mean and scale parameters of the per-pixel Laplace distributions.

In [11]:
plt.figure(figsize=(16,10))
plot_some(np.stack([restored_prob.mean(),restored_prob.scale()]), title_list=[['mean','scale']]);


Plot the variance and entropy of the per-pixel Laplace distributions.

In [12]:
plt.figure(figsize=(16,10))
plot_some(np.stack([restored_prob.var(),restored_prob.entropy()]), title_list=[['variance','entropy']]);


## Sampling restored images

Draw 50 samples from the distribution of restored images and plot the first 3.

In [13]:
# np.stack expects a sequence, so materialize the generator as a list first
samples = np.stack(list(restored_prob.sampling_generator(50)))

plt.figure(figsize=(16,5))
plot_some(samples[:3], pmin=0.1,pmax=99.9);
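The function `laplace_sampling_generator` below is a hypothetical stand-in for `sampling_generator`, written under the assumption that each sample draws every pixel independently from its own Laplace distribution. It also shows the `list()` wrapping before `np.stack`, since NumPy expects a sequence there and passing a generator is deprecated.

```python
import numpy as np

def laplace_sampling_generator(loc, scale, n, seed=None):
    """Hypothetical stand-in for sampling_generator: yield n independent
    per-pixel Laplace samples, each with the same shape as loc."""
    rng = np.random.default_rng(seed)
    for _ in range(n):
        yield rng.laplace(loc, scale)

# Made-up per-pixel parameters for a 128x128 image
loc = np.zeros((128, 128))
scale = np.full((128, 128), 0.1)

# Materialize the generator with list() before stacking along a new first axis
samples = np.stack(list(laplace_sampling_generator(loc, scale, 50, seed=0)))
print(samples.shape)  # (50, 128, 128)
```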


Make an animation of the 50 samples.

In [14]:
from matplotlib import animation
from IPython.display import HTML

fig = plt.figure(figsize=(8,8))
im = plt.imshow(samples[0], vmin=np.percentile(samples,0.1), vmax=np.percentile(samples,99.9), cmap='magma')
plt.close()

def updatefig(j):
    # Show the j-th sample in the existing image artist
    im.set_array(samples[j])
    return [im]

anim = animation.FuncAnimation(fig, updatefig, frames=len(samples), interval=100)
HTML(anim.to_jshtml())
