#!/usr/bin/env python
# coding: utf-8
#
#
# # Demo: Apply trained CARE model for combined denoising and upsampling of synthetic 3D data
#
# This notebook demonstrates applying a CARE model for a combined denoising and upsampling task, assuming that training was already completed via [2_training.ipynb](2_training.ipynb). The trained model is assumed to be located in the folder `models` with the name `my_model`.
#
# More documentation is available at http://csbdeep.bioimagecomputing.com/doc/.
# In[1]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
# IPython-only (get_ipython is not imported here; IPython provides it):
# render matplotlib figures inline in the notebook output.
get_ipython().run_line_magic('matplotlib', 'inline')
# IPython-only: draw the inline figures in high-resolution ("retina") format.
get_ipython().run_line_magic('config', "InlineBackend.figure_format = 'retina'")
from tifffile import imread
from csbdeep.utils import Path, download_and_extract_zip_file, plot_some
from csbdeep.io import save_tiff_imagej_compatible
from csbdeep.models import UpsamplingCARE
#
#
# # Download example data
#
# The example data (also for testing) should have been downloaded in [1_datagen.ipynb](1_datagen.ipynb).
# Just in case, we will download it here again if it's not already present.
# In[2]:
# Fetch the synthetic upsampling example data and unpack it below ./data
# (per the notes above, this is only needed if it is not already present).
download_and_extract_zip_file(
    url='http://csbdeep.bioimagecomputing.com/example_data/synthetic_upsampling.zip',
    targetdir='data',
)
#
#
# # Raw 3D image stack with low axial resolution
#
# We plot XY and ZY slices of the stack and define the image axes and subsampling factor, which will be needed later for prediction.
# In[3]:
def _show_slice(img, title, **imshow_kwargs):
    """Show a 2D image in a new 12x12 figure (magma colormap, axes hidden)."""
    plt.figure(figsize=(12, 12))
    plt.imshow(img, cmap='magma', **imshow_kwargs)
    plt.title(title)
    plt.axis('off')


# Low axial resolution test stack. Its axes are ordered Z, Y, X, and Z was
# subsampled by a factor of 4 (matching the `sub_4` in the file path).
x = imread('data/synthetic_upsampling/test_stacks_sub_4/stack_low_sub_4_03.tif')
axes = 'ZYX'
subsample = 4

for label, value in (
    ('image size =', x.shape),
    ('image axes =', axes),
    ('subsample factor =', subsample),
):
    print(label, value)

# XY slice at Z index 20, then ZY slice at X index 20. The ZY view is
# stretched vertically by `subsample` to compensate for the coarser Z step.
_show_slice(x[20], 'XY slice')
_show_slice(x[..., 20], 'ZY slice', aspect=subsample)
None;
#
#
# # Upsampling CARE model
#
# Load trained model (located in base directory `models` with name `my_model`) from disk.
# The configuration was saved during training and is automatically loaded when `UpsamplingCARE` is initialized with `config=None`.
# In[4]:
# Load the trained model `my_model` from the `models` base directory.
# Passing config=None makes UpsamplingCARE read back the configuration that
# was saved during training instead of creating a new model.
model = UpsamplingCARE(
    name='my_model',
    basedir='models',
    config=None,
)
# ## Apply CARE network to raw image
#
# Predict the restored image. If memory runs out, the image is successively split into more, smaller tiles.
# Because this is a relatively large image stack, we set `n_tiles` explicitly from the start instead of relying on that fallback.
#
# **Important:** You need to supply the subsampling factor, which must be the same as used during [training data generation](1_datagen.ipynb).
#
# **Note**: *Out of memory* problems during `model.predict` can also indicate that the GPU is used by another process. In particular, shut down the training notebook before running the prediction (you may need to restart this notebook).
# In[5]:
# Run the prediction under IPython's %%time cell magic so its run time is
# reported. The string argument is the original notebook cell body, which:
#   - calls model.predict(x, axes, subsample, n_tiles=(2,2,2)): it restores the
#     raw stack `x` (axes `axes`) with upsampling factor `subsample`, splitting
#     the input into 2 tiles along each of the three axes to limit memory use;
#   - binds the result to `restored` in the notebook namespace (used below);
#   - prints the input and output shapes, followed by a blank line.
get_ipython().run_cell_magic('time', '', "\nrestored = model.predict(x, axes, subsample, n_tiles=(2,2,2))\n\nprint('input size =', x.shape)\nprint('output size =', restored.shape)\nprint()\n")
# ## Save reconstructed image
#
# Save the reconstructed image stack as an ImageJ-compatible TIFF image, i.e. the image can be opened in ImageJ/Fiji with correct axes semantics.
# In[6]:
# Write the restored stack into ./results (created if missing) as an
# ImageJ/Fiji-compatible TIFF, passing the axis order `axes` along so Fiji
# can interpret the dimensions correctly.
results_dir = Path('results')
results_dir.mkdir(exist_ok=True)
restored_file = 'results/%s_restored_stack_low_sub_4_03.tif' % model.name
save_tiff_imagej_compatible(restored_file, restored, axes)
#
#
# # Upsampled (and denoised) image via CARE network
# In[7]:
def _show_slice(img, title, **imshow_kwargs):
    """Show a 2D image in a new 12x12 figure (magma colormap, axes hidden)."""
    plt.figure(figsize=(12, 12))
    plt.imshow(img, cmap='magma', **imshow_kwargs)
    plt.title(title)
    plt.axis('off')


# XY slice at the upsampled Z position that corresponds to raw slice 20,
# then ZY slice at X index 20. Unlike the raw-stack view, no aspect
# stretching is applied here, since Z has been upsampled by `subsample`.
_show_slice(restored[subsample * 20], 'XY slice')
_show_slice(restored[..., 20], 'ZY slice')
None;