#!/usr/bin/env python
# coding: utf-8

# # Demo: Probabilistic neural network training for denoising of synthetic 2D data
#
# This notebook demonstrates training a probabilistic CARE model for a 2D denoising task,
# using provided synthetic training data.
# Note that training a neural network for actual use should be done on more
# (representative) data and with more training time.
#
# More documentation is available at http://csbdeep.bioimagecomputing.com/doc/.

# In[1]:

from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt

# IPython/Jupyter provides `get_ipython` as a builtin; it does not exist when this
# file is run with plain `python`, so the notebook magics are only applied if it is
# available (otherwise the script would stop here with a NameError).
try:
    _ipython = get_ipython()  # noqa: F821
except NameError:
    _ipython = None
if _ipython is not None:
    _ipython.run_line_magic('matplotlib', 'inline')
    _ipython.run_line_magic('config', "InlineBackend.figure_format = 'retina'")

from tifffile import imread
from csbdeep.utils import download_and_extract_zip_file, axes_dict, plot_some, plot_history
from csbdeep.utils.tf import limit_gpu_memory
from csbdeep.io import load_training_data
from csbdeep.models import Config, CARE

# The TensorFlow backend uses all available GPU memory by default, hence it can be
# useful to limit it:

# In[2]:

# limit_gpu_memory(fraction=1/2)


# # Training data
#
# Download and read the provided training data, using 10% of it as validation data.

# In[3]:

download_and_extract_zip_file(
    url='http://csbdeep.bioimagecomputing.com/example_data/synthetic_disks.zip',
    targetdir='data',
)

# In[4]:

(X, Y), (X_val, Y_val), axes = load_training_data('data/synthetic_disks/data.npz', validation_split=0.1, verbose=True)

# `c` is used as a position into the array shapes: the channel ('C') axis, from which
# the number of input and output channels is read.
c = axes_dict(axes)['C']
n_channel_in, n_channel_out = X.shape[c], Y.shape[c]

# In[5]:

plt.figure(figsize=(12, 5))
plot_some(X_val[:5], Y_val[:5])
plt.suptitle('5 example validation patches (top row: source, bottom row: target)')


# # CARE model
#
# Before we construct the actual CARE model, we have to define its configuration via a
# `Config` object, which includes
# * parameters of the underlying neural network,
# * the learning rate,
# * the number of parameter updates per epoch,
# * the loss function, and
# * whether the model is probabilistic or not.
#
# The defaults should be sensible in many cases, so a change should only be necessary
# if the training process fails.
#
# For a probabilistic model, we have to explicitly set `probabilistic=True`.
#
# ---
#
# Important: Note that for this notebook we use a very small number of update steps per
# epoch for immediate feedback, whereas this number should be increased considerably
# (e.g. `train_steps_per_epoch=400`) to obtain a well-trained model.

# In[6]:

config = Config(axes, n_channel_in, n_channel_out, probabilistic=True, train_steps_per_epoch=30)
print(config)
vars(config)  # only displayed when run as a notebook cell; a no-op in a plain script

# We now create a CARE model with the chosen configuration:

# In[7]:

model = CARE(config, 'my_model', basedir='models')


# # Training
#
# Training the model will likely take some time. We recommend monitoring the progress
# with [TensorBoard](https://www.tensorflow.org/tensorboard) (example below), which
# allows you to inspect the losses during training.
# Furthermore, you can look at the predictions for some of the validation images, which
# can be helpful to recognize problems early on.
#
# You can start TensorBoard from the current working directory with `tensorboard --logdir=.`
# Then connect to [http://localhost:6006/](http://localhost:6006/) with your browser.
#
# ![](http://csbdeep.bioimagecomputing.com/old/img/tensorboard_denoising2D_probabilistic.png)

# In[8]:

history = model.train(X, Y, validation_data=(X_val, Y_val))

# Plot final training history (available in TensorBoard during training):

# In[9]:

print(sorted(history.history.keys()))
plt.figure(figsize=(16, 5))
plot_history(history, ['loss', 'val_loss'], ['mse', 'val_mse', 'mae', 'val_mae'])


# # Evaluation
#
# Example results for validation images.

# In[10]:

plt.figure(figsize=(12, 10))
# The raw network output is split in half along its last axis: per the naming below
# (and the plot title), the first half holds the predicted Laplace mean and the
# second half the predicted Laplace scale.
_P = model.keras_model.predict(X_val[:5])
_P_mean = _P[..., :(_P.shape[-1] // 2)]
_P_scale = _P[..., (_P.shape[-1] // 2):]
plot_some(X_val[:5], Y_val[:5], _P_mean, _P_scale, pmax=99.5)
plt.suptitle('5 example validation patches\n'
             'first row: input (source), '
             'second row: target (ground truth), '
             'third row: predicted Laplace mean, '
             'fourth row: predicted Laplace scale')


# # Export model to be used with CSBDeep **Fiji** plugins and **KNIME** workflows
#
# See https://github.com/CSBDeep/CSBDeep_website/wiki/Your-Model-in-Fiji for details.

# In[11]:

model.export_TF()