#!/usr/bin/env python
# coding: utf-8
#
#
# # Demo: Neural network training for combined denoising and upsampling of synthetic 3D data
#
# This notebook demonstrates training a CARE model for a combined denoising and upsampling task, assuming that training data was already generated via [1_datagen.ipynb](1_datagen.ipynb) and has been saved to disk to the file ``data/my_training_data.npz``. Note that the training approach is exactly the same as in the standard CARE approach; what differs is the [training data generation](1_datagen.ipynb) and [prediction](3_prediction.ipynb).
#
# Note that training a neural network for actual use should be done on more (representative) data and with more training time.
#
# More documentation is available at http://csbdeep.bioimagecomputing.com/doc/.
# In[1]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
# Notebook-only display setup. `get_ipython` exists only inside an IPython/Jupyter
# session, so look it up defensively: when this exported script is started with a
# plain `python` interpreter (as the shebang suggests), skip the magics instead of
# aborting with a NameError before any work is done.
try:
    ipython_shell = get_ipython()
except NameError:
    ipython_shell = None
if ipython_shell is not None:
    # Same two magics as the notebook: inline matplotlib backend, 'retina' figure format.
    ipython_shell.run_line_magic('matplotlib', 'inline')
    ipython_shell.run_line_magic('config', "InlineBackend.figure_format = 'retina'")
from tifffile import imread
from csbdeep.utils import axes_dict, plot_some, plot_history
from csbdeep.utils.tf import limit_gpu_memory
from csbdeep.io import load_training_data
from csbdeep.models import Config, UpsamplingCARE
# The TensorFlow backend uses all available GPU memory by default, hence it can be useful to limit it:
# In[2]:
# limit_gpu_memory(fraction=1/2)
#
#
# # Training data
#
# Load training data generated via [1_datagen.ipynb](1_datagen.ipynb), use 10% as validation data.
# In[3]:
# Read the pre-generated patches from disk; validation_split=0.1 sets aside 10% as validation data.
(X, Y), (X_val, Y_val), axes = load_training_data(
    'data/my_training_data.npz',
    validation_split=0.1,
    verbose=True,
)
# Look up the position of the channel axis ('C') and read the channel counts
# of the source (X) and target (Y) arrays from their shapes.
c = axes_dict(axes)['C']
n_channel_in = X.shape[c]
n_channel_out = Y.shape[c]
# In[4]:
# Preview the first five validation pairs; index 0 is taken on each of the last two axes.
plt.figure(figsize=(12, 3))
_source_preview = X_val[:5, ..., 0, 0]
_target_preview = Y_val[:5, ..., 0, 0]
plot_some(_source_preview, _target_preview)
plt.suptitle('5 example validation patches (ZY slice, top row: source, bottom row: target)');
#
#
# # CARE model
#
# Before we construct the actual CARE model, we have to define its configuration via a `Config` object, which includes
# * parameters of the underlying neural network,
# * the learning rate,
# * the number of parameter updates per epoch,
# * the loss function, and
# * whether the model is probabilistic or not.
#
# The defaults should be sensible in many cases, so a change should only be necessary if the training process fails.
#
# ---
#
# Important: Note that for this notebook we use a very small number of update steps per epoch for immediate feedback, whereas this number should be increased considerably (e.g. `train_steps_per_epoch=400`, `train_batch_size=16`) to obtain a well-trained model.
# In[5]:
# Demo-sized schedule (25 update steps per epoch, batch size 4) chosen for immediate
# feedback; everything not listed here keeps the Config defaults.
config = Config(
    axes,
    n_channel_in,
    n_channel_out,
    train_steps_per_epoch=25,
    train_batch_size=4,
)
print(config)
# Bare expression: displayed as the cell result in a notebook, no visible output in a plain script.
vars(config)
# We now create an upsampling CARE model with the chosen configuration:
# In[6]:
# Build the upsampling CARE model from `config`. 'my_model' and basedir='models' are
# presumably the model name and the folder it is stored under -- TODO confirm against CSBDeep docs.
model = UpsamplingCARE(
    config,
    'my_model',
    basedir='models',
)
#
#
# # Training
#
# Training the model will likely take some time. We recommend monitoring the progress with [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard) (example below), which allows you to inspect the losses during training.
# Furthermore, you can look at the predictions for some of the validation images, which can be helpful to recognize problems early on.
#
# You can start TensorBoard from the current working directory with `tensorboard --logdir=.`
# Then connect to [http://localhost:6006/](http://localhost:6006/) with your browser.
#
# ![](http://csbdeep.bioimagecomputing.com/old/img/tensorboard_upsampling3D.png)
# In[7]:
# Train on the (X, Y) pairs, passing the held-out pairs as validation data;
# the returned object is kept so its recorded metrics can be plotted afterwards.
history = model.train(
    X,
    Y,
    validation_data=(X_val, Y_val),
)
# Plot final training history (available in TensorBoard during training):
# In[8]:
# List the recorded metric names in alphabetical order. Iterating the mapping yields
# its keys, so sorted() needs neither .keys() nor an intermediate list().
print(sorted(history.history))
# Two metric groups are handed to plot_history: the losses, and the MSE/MAE metrics
# with their validation counterparts.
plt.figure(figsize=(16,5))
plot_history(history,['loss','val_loss'],['mse','val_mse','mae','val_mae']);
#
#
# # Evaluation
#
# Example results for validation images.
# In[9]:
# Run the underlying Keras model on the first five validation patches and plot
# source, target and prediction together.
plt.figure(figsize=(12,4.5))
_P = model.keras_model.predict(X_val[:5])
if config.probabilistic:
    # Keep only the first half of the last axis (presumably the predicted means of a
    # probabilistic model, with the other half holding scale values -- TODO confirm).
    # This line must be indented under the `if`, otherwise the file does not compile.
    _P = _P[...,:(_P.shape[-1]//2)]
plot_some(X_val[:5,...,0,0],Y_val[:5,...,0,0],_P[...,0,0],pmax=99.5)
plt.suptitle('5 example validation patches (ZY slice)\n'
             'top row: input (source), '
             'middle row: target (ground truth), '
             'bottom row: predicted from source');
#
#
# # Export model to be used with CSBDeep **Fiji** plugins and **KNIME** workflows
#
# See https://github.com/CSBDeep/CSBDeep_website/wiki/Your-Model-in-Fiji for details.
# In[10]:
# Export the trained model via export_TF() so it can be used with the CSBDeep
# Fiji plugins and KNIME workflows (see the wiki page linked above for details).
model.export_TF()