Demo: Probabilistic neural network training for denoising of synthetic 2D data

This notebook demonstrates training a probabilistic CARE model for a 2D denoising task, using the provided synthetic training data.
Note that a model intended for actual use should be trained on more (and more representative) data and with more training time.

More documentation is available at http://csbdeep.bioimagecomputing.com/doc/.

In [1]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from tifffile import imread
from csbdeep.utils import download_and_extract_zip_file, axes_dict, plot_some, plot_history
from csbdeep.utils.tf import limit_gpu_memory
from csbdeep.io import load_training_data
from csbdeep.models import Config, CARE
Using TensorFlow backend.

The TensorFlow backend uses all available GPU memory by default; hence it can be useful to limit it:

In [2]:
# limit_gpu_memory(fraction=1/2)

Training data

Download and read the provided training data, using 10% of it as validation data.

In [3]:
download_and_extract_zip_file(
    url       = 'http://csbdeep.bioimagecomputing.com/example_data/synthetic_disks.zip',
    targetdir = 'data',
)
Files missing, downloading... extracting... done.

data:
- synthetic_disks
- synthetic_disks/data.npz
In [4]:
(X,Y), (X_val,Y_val), axes = load_training_data('data/synthetic_disks/data.npz', validation_split=0.1, verbose=True)

c = axes_dict(axes)['C']
n_channel_in, n_channel_out = X.shape[c], Y.shape[c]
number of training images:	 180
number of validation images:	 20
image size (2D):		 (128, 128)
axes:				 SYXC
channels in / out:		 1 / 1
In [5]:
plt.figure(figsize=(12,5))
plot_some(X_val[:5],Y_val[:5])
plt.suptitle('5 example validation patches (top row: source, bottom row: target)');

CARE model

Before we construct the actual CARE model, we have to define its configuration via a Config object, which includes

  • parameters of the underlying neural network,
  • the learning rate,
  • the number of parameter updates per epoch,
  • the loss function, and
  • whether the model is probabilistic or not.

The defaults should be sensible in many cases, so a change should only be necessary if the training process fails.

For a probabilistic model, we have to explicitly set probabilistic=True.


Important: Note that for this notebook we use a very small number of update steps per epoch for immediate feedback, whereas this number should be increased considerably (e.g. train_steps_per_epoch=400) to obtain a well-trained model.
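Any of these settings can be overridden via keyword arguments of Config. Following the suggestion above, a configuration for a proper (longer) training run might look like this sketch:

config_full = Config(axes, n_channel_in, n_channel_out, probabilistic=True, train_steps_per_epoch=400)

For this demo, however, we deliberately keep the number of update steps per epoch small: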

In [6]:
config = Config(axes, n_channel_in, n_channel_out, probabilistic=True, train_steps_per_epoch=30)
print(config)
vars(config)
Config(axes='YXC', n_channel_in=1, n_channel_out=1, n_dim=2, probabilistic=True, train_batch_size=16, train_checkpoint='weights_best.h5', train_epochs=100, train_learning_rate=0.0004, train_loss='laplace', train_reduce_lr={'factor': 0.5, 'patience': 10, 'min_delta': 0}, train_steps_per_epoch=30, train_tensorboard=True, unet_input_shape=(None, None, 1), unet_kern_size=5, unet_last_activation='linear', unet_n_depth=2, unet_n_first=32, unet_residual=True)
Out[6]:
{'axes': 'YXC',
 'n_channel_in': 1,
 'n_channel_out': 1,
 'n_dim': 2,
 'probabilistic': True,
 'train_batch_size': 16,
 'train_checkpoint': 'weights_best.h5',
 'train_epochs': 100,
 'train_learning_rate': 0.0004,
 'train_loss': 'laplace',
 'train_reduce_lr': {'factor': 0.5, 'min_delta': 0, 'patience': 10},
 'train_steps_per_epoch': 30,
 'train_tensorboard': True,
 'unet_input_shape': (None, None, 1),
 'unet_kern_size': 5,
 'unet_last_activation': 'linear',
 'unet_n_depth': 2,
 'unet_n_first': 32,
 'unet_residual': True}

We now create a CARE model with the chosen configuration:

In [7]:
model = CARE(config, 'my_model', basedir='models')

Training

Training the model will likely take some time. We recommend monitoring the progress with TensorBoard (example below), which allows you to inspect the losses during training. Furthermore, you can look at the predictions for some of the validation images, which can be helpful for recognizing problems early on.

You can start TensorBoard from the current working directory with the command tensorboard --logdir=. (the trailing dot denotes the current directory), then connect to http://localhost:6006/ with your browser.

In [8]:
history = model.train(X,Y, validation_data=(X_val,Y_val))
Epoch 1/100
30/30 [==============================] - 5s 171ms/step - loss: 0.3453 - mse: 0.0354 - mae: 0.1487 - val_loss: 0.1616 - val_mse: 0.0480 - val_mae: 0.1857
Epoch 2/100
30/30 [==============================] - 2s 83ms/step - loss: -0.1067 - mse: 0.0271 - mae: 0.1285 - val_loss: -0.4445 - val_mse: 0.0185 - val_mae: 0.1070
Epoch 3/100
30/30 [==============================] - 3s 84ms/step - loss: -0.5339 - mse: 0.0169 - mae: 0.1023 - val_loss: -0.6982 - val_mse: 0.0128 - val_mae: 0.0886
Epoch 4/100
30/30 [==============================] - 3s 92ms/step - loss: -0.6778 - mse: 0.0129 - mae: 0.0876 - val_loss: -0.8560 - val_mse: 0.0096 - val_mae: 0.0764
Epoch 5/100
30/30 [==============================] - 3s 92ms/step - loss: -0.9354 - mse: 0.0093 - mae: 0.0711 - val_loss: -1.1605 - val_mse: 0.0069 - val_mae: 0.0592
Epoch 6/100
30/30 [==============================] - 3s 94ms/step - loss: -1.1104 - mse: 0.0082 - mae: 0.0633 - val_loss: -1.2587 - val_mse: 0.0071 - val_mae: 0.0564
Epoch 7/100
30/30 [==============================] - 3s 92ms/step - loss: -1.0999 - mse: 0.0083 - mae: 0.0624 - val_loss: -1.3353 - val_mse: 0.0063 - val_mae: 0.0518
Epoch 8/100
30/30 [==============================] - 3s 93ms/step - loss: -1.2246 - mse: 0.0077 - mae: 0.0593 - val_loss: -1.2563 - val_mse: 0.0066 - val_mae: 0.0572
Epoch 9/100
30/30 [==============================] - 3s 91ms/step - loss: -1.2947 - mse: 0.0072 - mae: 0.0564 - val_loss: -1.5226 - val_mse: 0.0056 - val_mae: 0.0465
Epoch 10/100
30/30 [==============================] - 3s 93ms/step - loss: -1.2939 - mse: 0.0071 - mae: 0.0560 - val_loss: -1.2213 - val_mse: 0.0067 - val_mae: 0.0572
Epoch 11/100
30/30 [==============================] - 3s 93ms/step - loss: -1.2553 - mse: 0.0075 - mae: 0.0561 - val_loss: -1.6160 - val_mse: 0.0059 - val_mae: 0.0467
Epoch 12/100
30/30 [==============================] - 3s 91ms/step - loss: -1.4072 - mse: 0.0064 - mae: 0.0523 - val_loss: -1.7939 - val_mse: 0.0051 - val_mae: 0.0416
Epoch 13/100
30/30 [==============================] - 3s 93ms/step - loss: -1.5931 - mse: 0.0057 - mae: 0.0469 - val_loss: -1.8215 - val_mse: 0.0049 - val_mae: 0.0410
Epoch 14/100
30/30 [==============================] - 3s 92ms/step - loss: -1.4318 - mse: 0.0062 - mae: 0.0506 - val_loss: -1.6916 - val_mse: 0.0060 - val_mae: 0.0445
Epoch 15/100
30/30 [==============================] - 3s 93ms/step - loss: -1.5131 - mse: 0.0059 - mae: 0.0489 - val_loss: -1.8370 - val_mse: 0.0052 - val_mae: 0.0412
Epoch 16/100
30/30 [==============================] - 3s 93ms/step - loss: -1.5136 - mse: 0.0057 - mae: 0.0485 - val_loss: -1.7499 - val_mse: 0.0045 - val_mae: 0.0412
Epoch 17/100
30/30 [==============================] - 3s 92ms/step - loss: -1.5067 - mse: 0.0054 - mae: 0.0476 - val_loss: -1.6025 - val_mse: 0.0045 - val_mae: 0.0429
Epoch 18/100
30/30 [==============================] - 3s 93ms/step - loss: -1.5981 - mse: 0.0052 - mae: 0.0452 - val_loss: -0.9527 - val_mse: 0.0056 - val_mae: 0.0598
Epoch 19/100
30/30 [==============================] - 3s 91ms/step - loss: -1.4761 - mse: 0.0058 - mae: 0.0489 - val_loss: -1.7263 - val_mse: 0.0042 - val_mae: 0.0406
Epoch 20/100
30/30 [==============================] - 3s 93ms/step - loss: -1.5495 - mse: 0.0051 - mae: 0.0459 - val_loss: -1.4908 - val_mse: 0.0047 - val_mae: 0.0478
Epoch 21/100
30/30 [==============================] - 3s 91ms/step - loss: -1.7012 - mse: 0.0047 - mae: 0.0424 - val_loss: -1.6858 - val_mse: 0.0043 - val_mae: 0.0407
Epoch 22/100
30/30 [==============================] - 3s 93ms/step - loss: -1.8789 - mse: 0.0046 - mae: 0.0398 - val_loss: -1.7954 - val_mse: 0.0039 - val_mae: 0.0386
Epoch 23/100
30/30 [==============================] - 3s 91ms/step - loss: -1.7942 - mse: 0.0046 - mae: 0.0406 - val_loss: -1.7726 - val_mse: 0.0040 - val_mae: 0.0388
Epoch 24/100
30/30 [==============================] - 3s 93ms/step - loss: -1.6729 - mse: 0.0048 - mae: 0.0433 - val_loss: -1.5874 - val_mse: 0.0042 - val_mae: 0.0390
Epoch 25/100
30/30 [==============================] - 3s 92ms/step - loss: -1.6964 - mse: 0.0047 - mae: 0.0423 - val_loss: -2.1271 - val_mse: 0.0036 - val_mae: 0.0333
Epoch 26/100
30/30 [==============================] - 3s 92ms/step - loss: -1.9733 - mse: 0.0043 - mae: 0.0377 - val_loss: -2.1992 - val_mse: 0.0036 - val_mae: 0.0326
Epoch 27/100
30/30 [==============================] - 3s 90ms/step - loss: -1.9415 - mse: 0.0042 - mae: 0.0381 - val_loss: -1.4590 - val_mse: 0.0040 - val_mae: 0.0444
Epoch 28/100
30/30 [==============================] - 3s 93ms/step - loss: -1.7322 - mse: 0.0045 - mae: 0.0413 - val_loss: -2.1533 - val_mse: 0.0039 - val_mae: 0.0344
Epoch 29/100
30/30 [==============================] - 3s 92ms/step - loss: -1.9347 - mse: 0.0044 - mae: 0.0382 - val_loss: -1.6743 - val_mse: 0.0038 - val_mae: 0.0402
Epoch 30/100
30/30 [==============================] - 3s 91ms/step - loss: -1.8676 - mse: 0.0043 - mae: 0.0386 - val_loss: -1.4458 - val_mse: 0.0039 - val_mae: 0.0401
Epoch 31/100
30/30 [==============================] - 3s 91ms/step - loss: -1.9600 - mse: 0.0042 - mae: 0.0374 - val_loss: -1.7336 - val_mse: 0.0038 - val_mae: 0.0394
Epoch 32/100
30/30 [==============================] - 3s 94ms/step - loss: -1.8056 - mse: 0.0043 - mae: 0.0396 - val_loss: -1.6303 - val_mse: 0.0038 - val_mae: 0.0404
Epoch 33/100
30/30 [==============================] - 3s 92ms/step - loss: -1.9657 - mse: 0.0042 - mae: 0.0376 - val_loss: -1.4007 - val_mse: 0.0038 - val_mae: 0.0405
Epoch 34/100
30/30 [==============================] - 3s 93ms/step - loss: -1.7242 - mse: 0.0044 - mae: 0.0409 - val_loss: -1.7893 - val_mse: 0.0037 - val_mae: 0.0378
Epoch 35/100
30/30 [==============================] - 3s 91ms/step - loss: -2.0174 - mse: 0.0041 - mae: 0.0367 - val_loss: -2.3093 - val_mse: 0.0035 - val_mae: 0.0313
Epoch 36/100
30/30 [==============================] - 3s 92ms/step - loss: -1.9246 - mse: 0.0041 - mae: 0.0377 - val_loss: -1.8342 - val_mse: 0.0037 - val_mae: 0.0377
Epoch 37/100
30/30 [==============================] - 3s 93ms/step - loss: -2.0086 - mse: 0.0040 - mae: 0.0364 - val_loss: -2.0423 - val_mse: 0.0037 - val_mae: 0.0347
Epoch 38/100
30/30 [==============================] - 3s 91ms/step - loss: -2.0941 - mse: 0.0040 - mae: 0.0357 - val_loss: -2.1635 - val_mse: 0.0033 - val_mae: 0.0323
Epoch 39/100
30/30 [==============================] - 3s 93ms/step - loss: -1.9030 - mse: 0.0041 - mae: 0.0382 - val_loss: -1.6946 - val_mse: 0.0039 - val_mae: 0.0387
Epoch 40/100
30/30 [==============================] - 3s 91ms/step - loss: -1.9785 - mse: 0.0042 - mae: 0.0375 - val_loss: -1.9651 - val_mse: 0.0035 - val_mae: 0.0347
Epoch 41/100
30/30 [==============================] - 3s 91ms/step - loss: -2.1459 - mse: 0.0038 - mae: 0.0346 - val_loss: -2.0064 - val_mse: 0.0033 - val_mae: 0.0339
Epoch 42/100
30/30 [==============================] - 3s 93ms/step - loss: -2.0188 - mse: 0.0039 - mae: 0.0360 - val_loss: -2.0854 - val_mse: 0.0034 - val_mae: 0.0338
Epoch 43/100
30/30 [==============================] - 3s 91ms/step - loss: -2.0330 - mse: 0.0040 - mae: 0.0359 - val_loss: -2.0147 - val_mse: 0.0033 - val_mae: 0.0339
Epoch 44/100
30/30 [==============================] - 3s 92ms/step - loss: -2.1613 - mse: 0.0038 - mae: 0.0344 - val_loss: -2.4779 - val_mse: 0.0031 - val_mae: 0.0289
Epoch 45/100
30/30 [==============================] - 3s 93ms/step - loss: -2.0509 - mse: 0.0038 - mae: 0.0354 - val_loss: -2.1950 - val_mse: 0.0034 - val_mae: 0.0312
Epoch 46/100
30/30 [==============================] - 3s 92ms/step - loss: -1.9307 - mse: 0.0040 - mae: 0.0370 - val_loss: -2.2390 - val_mse: 0.0031 - val_mae: 0.0306
Epoch 47/100
30/30 [==============================] - 3s 93ms/step - loss: -2.0542 - mse: 0.0038 - mae: 0.0353 - val_loss: -2.0041 - val_mse: 0.0032 - val_mae: 0.0339
Epoch 48/100
30/30 [==============================] - 3s 91ms/step - loss: -2.1745 - mse: 0.0038 - mae: 0.0340 - val_loss: -2.1843 - val_mse: 0.0032 - val_mae: 0.0318
Epoch 49/100
30/30 [==============================] - 3s 92ms/step - loss: -2.2407 - mse: 0.0036 - mae: 0.0331 - val_loss: -2.0264 - val_mse: 0.0032 - val_mae: 0.0328
Epoch 50/100
30/30 [==============================] - 3s 91ms/step - loss: -2.2526 - mse: 0.0036 - mae: 0.0326 - val_loss: -2.1618 - val_mse: 0.0032 - val_mae: 0.0320
Epoch 51/100
30/30 [==============================] - 3s 92ms/step - loss: -2.2659 - mse: 0.0036 - mae: 0.0324 - val_loss: -2.5495 - val_mse: 0.0030 - val_mae: 0.0282
Epoch 52/100
30/30 [==============================] - 3s 91ms/step - loss: -1.7608 - mse: 0.0042 - mae: 0.0391 - val_loss: -2.1892 - val_mse: 0.0037 - val_mae: 0.0325
Epoch 53/100
30/30 [==============================] - 3s 93ms/step - loss: -1.9775 - mse: 0.0043 - mae: 0.0372 - val_loss: -2.1562 - val_mse: 0.0033 - val_mae: 0.0319
Epoch 54/100
30/30 [==============================] - 3s 92ms/step - loss: -2.0836 - mse: 0.0038 - mae: 0.0347 - val_loss: -2.1415 - val_mse: 0.0033 - val_mae: 0.0326
Epoch 55/100
30/30 [==============================] - 3s 90ms/step - loss: -1.9328 - mse: 0.0038 - mae: 0.0366 - val_loss: -2.3708 - val_mse: 0.0031 - val_mae: 0.0293
Epoch 56/100
30/30 [==============================] - 3s 92ms/step - loss: -2.0131 - mse: 0.0038 - mae: 0.0352 - val_loss: -2.0476 - val_mse: 0.0032 - val_mae: 0.0328
Epoch 57/100
30/30 [==============================] - 3s 92ms/step - loss: -1.8980 - mse: 0.0040 - mae: 0.0371 - val_loss: -2.1340 - val_mse: 0.0033 - val_mae: 0.0316
Epoch 58/100
30/30 [==============================] - 3s 91ms/step - loss: -1.9146 - mse: 0.0039 - mae: 0.0366 - val_loss: -2.1744 - val_mse: 0.0033 - val_mae: 0.0321
Epoch 59/100
30/30 [==============================] - 3s 90ms/step - loss: -2.2226 - mse: 0.0036 - mae: 0.0329 - val_loss: -2.3691 - val_mse: 0.0031 - val_mae: 0.0298
Epoch 60/100
30/30 [==============================] - 3s 92ms/step - loss: -2.0925 - mse: 0.0036 - mae: 0.0341 - val_loss: -1.9231 - val_mse: 0.0033 - val_mae: 0.0336
Epoch 61/100
30/30 [==============================] - 3s 91ms/step - loss: -2.1000 - mse: 0.0035 - mae: 0.0335 - val_loss: -2.4947 - val_mse: 0.0032 - val_mae: 0.0288

Epoch 00061: ReduceLROnPlateau reducing learning rate to 0.00019999999494757503.
Epoch 62/100
30/30 [==============================] - 3s 94ms/step - loss: -2.4604 - mse: 0.0035 - mae: 0.0307 - val_loss: -2.6371 - val_mse: 0.0031 - val_mae: 0.0282
Epoch 63/100
30/30 [==============================] - 3s 91ms/step - loss: -2.5082 - mse: 0.0035 - mae: 0.0309 - val_loss: -2.6796 - val_mse: 0.0030 - val_mae: 0.0278
Epoch 64/100
30/30 [==============================] - 3s 93ms/step - loss: -2.4813 - mse: 0.0034 - mae: 0.0306 - val_loss: -2.6315 - val_mse: 0.0030 - val_mae: 0.0279
Epoch 65/100
30/30 [==============================] - 3s 93ms/step - loss: -2.4764 - mse: 0.0034 - mae: 0.0307 - val_loss: -2.5102 - val_mse: 0.0029 - val_mae: 0.0282
Epoch 66/100
30/30 [==============================] - 3s 91ms/step - loss: -2.4258 - mse: 0.0034 - mae: 0.0309 - val_loss: -2.4340 - val_mse: 0.0030 - val_mae: 0.0286
Epoch 67/100
30/30 [==============================] - 3s 94ms/step - loss: -2.4416 - mse: 0.0035 - mae: 0.0310 - val_loss: -2.7346 - val_mse: 0.0030 - val_mae: 0.0274
Epoch 68/100
30/30 [==============================] - 3s 92ms/step - loss: -2.1523 - mse: 0.0037 - mae: 0.0336 - val_loss: -2.2170 - val_mse: 0.0031 - val_mae: 0.0306
Epoch 69/100
30/30 [==============================] - 3s 93ms/step - loss: -2.3574 - mse: 0.0035 - mae: 0.0317 - val_loss: -2.7385 - val_mse: 0.0031 - val_mae: 0.0276
Epoch 70/100
30/30 [==============================] - 3s 91ms/step - loss: -2.1810 - mse: 0.0036 - mae: 0.0331 - val_loss: -2.4213 - val_mse: 0.0031 - val_mae: 0.0288
Epoch 71/100
30/30 [==============================] - 3s 91ms/step - loss: -2.3522 - mse: 0.0036 - mae: 0.0320 - val_loss: -2.3553 - val_mse: 0.0031 - val_mae: 0.0298
Epoch 72/100
30/30 [==============================] - 3s 92ms/step - loss: -2.3519 - mse: 0.0035 - mae: 0.0318 - val_loss: -2.7409 - val_mse: 0.0030 - val_mae: 0.0271
Epoch 73/100
30/30 [==============================] - 3s 93ms/step - loss: -2.5508 - mse: 0.0034 - mae: 0.0302 - val_loss: -2.2927 - val_mse: 0.0029 - val_mae: 0.0288
Epoch 74/100
30/30 [==============================] - 3s 92ms/step - loss: -2.2971 - mse: 0.0034 - mae: 0.0317 - val_loss: -2.5308 - val_mse: 0.0029 - val_mae: 0.0274
Epoch 75/100
30/30 [==============================] - 3s 92ms/step - loss: -2.4585 - mse: 0.0034 - mae: 0.0306 - val_loss: -2.6902 - val_mse: 0.0029 - val_mae: 0.0271
Epoch 76/100
30/30 [==============================] - 3s 93ms/step - loss: -2.3462 - mse: 0.0034 - mae: 0.0314 - val_loss: -2.3474 - val_mse: 0.0030 - val_mae: 0.0282
Epoch 77/100
30/30 [==============================] - 3s 92ms/step - loss: -2.4090 - mse: 0.0034 - mae: 0.0308 - val_loss: -2.8157 - val_mse: 0.0029 - val_mae: 0.0266
Epoch 78/100
30/30 [==============================] - 3s 93ms/step - loss: -2.5741 - mse: 0.0034 - mae: 0.0299 - val_loss: -2.7968 - val_mse: 0.0029 - val_mae: 0.0266
Epoch 79/100
30/30 [==============================] - 3s 93ms/step - loss: -2.3696 - mse: 0.0035 - mae: 0.0316 - val_loss: -1.9409 - val_mse: 0.0033 - val_mae: 0.0348
Epoch 80/100
30/30 [==============================] - 3s 91ms/step - loss: -2.3380 - mse: 0.0035 - mae: 0.0316 - val_loss: -2.3417 - val_mse: 0.0029 - val_mae: 0.0289
Epoch 81/100
30/30 [==============================] - 3s 91ms/step - loss: -2.5185 - mse: 0.0033 - mae: 0.0300 - val_loss: -2.7727 - val_mse: 0.0029 - val_mae: 0.0264
Epoch 82/100
30/30 [==============================] - 3s 93ms/step - loss: -2.5583 - mse: 0.0033 - mae: 0.0296 - val_loss: -2.0936 - val_mse: 0.0029 - val_mae: 0.0306
Epoch 83/100
30/30 [==============================] - 3s 91ms/step - loss: -2.4964 - mse: 0.0034 - mae: 0.0300 - val_loss: -2.6966 - val_mse: 0.0028 - val_mae: 0.0267
Epoch 84/100
30/30 [==============================] - 3s 93ms/step - loss: -2.5968 - mse: 0.0033 - mae: 0.0294 - val_loss: -2.8247 - val_mse: 0.0029 - val_mae: 0.0262
Epoch 85/100
30/30 [==============================] - 3s 93ms/step - loss: -2.4245 - mse: 0.0034 - mae: 0.0306 - val_loss: -2.8501 - val_mse: 0.0029 - val_mae: 0.0262
Epoch 86/100
30/30 [==============================] - 3s 92ms/step - loss: -2.5836 - mse: 0.0033 - mae: 0.0296 - val_loss: -2.8578 - val_mse: 0.0029 - val_mae: 0.0260
Epoch 87/100
30/30 [==============================] - 3s 92ms/step - loss: -2.5017 - mse: 0.0034 - mae: 0.0303 - val_loss: -1.7773 - val_mse: 0.0033 - val_mae: 0.0335
Epoch 88/100
30/30 [==============================] - 3s 92ms/step - loss: -2.4787 - mse: 0.0034 - mae: 0.0305 - val_loss: -2.7992 - val_mse: 0.0029 - val_mae: 0.0265
Epoch 89/100
30/30 [==============================] - 3s 92ms/step - loss: -2.6583 - mse: 0.0032 - mae: 0.0289 - val_loss: -2.8498 - val_mse: 0.0030 - val_mae: 0.0267
Epoch 90/100
30/30 [==============================] - 3s 93ms/step - loss: -2.6253 - mse: 0.0033 - mae: 0.0294 - val_loss: -2.6858 - val_mse: 0.0028 - val_mae: 0.0266
Epoch 91/100
30/30 [==============================] - 3s 93ms/step - loss: -2.6744 - mse: 0.0033 - mae: 0.0289 - val_loss: -2.7216 - val_mse: 0.0028 - val_mae: 0.0266
Epoch 92/100
30/30 [==============================] - 3s 92ms/step - loss: -2.3575 - mse: 0.0034 - mae: 0.0317 - val_loss: -2.1316 - val_mse: 0.0033 - val_mae: 0.0313
Epoch 93/100
30/30 [==============================] - 3s 93ms/step - loss: -2.4122 - mse: 0.0035 - mae: 0.0312 - val_loss: -2.2811 - val_mse: 0.0029 - val_mae: 0.0297
Epoch 94/100
30/30 [==============================] - 3s 91ms/step - loss: -2.2937 - mse: 0.0034 - mae: 0.0316 - val_loss: -2.2637 - val_mse: 0.0029 - val_mae: 0.0296
Epoch 95/100
30/30 [==============================] - 3s 94ms/step - loss: -2.6105 - mse: 0.0034 - mae: 0.0296 - val_loss: -2.8713 - val_mse: 0.0028 - val_mae: 0.0261
Epoch 96/100
30/30 [==============================] - 3s 92ms/step - loss: -2.5759 - mse: 0.0033 - mae: 0.0295 - val_loss: -2.5002 - val_mse: 0.0028 - val_mae: 0.0276
Epoch 97/100
30/30 [==============================] - 3s 93ms/step - loss: -2.4719 - mse: 0.0033 - mae: 0.0300 - val_loss: -2.8232 - val_mse: 0.0028 - val_mae: 0.0258
Epoch 98/100
30/30 [==============================] - 3s 92ms/step - loss: -2.2927 - mse: 0.0034 - mae: 0.0318 - val_loss: -2.5178 - val_mse: 0.0029 - val_mae: 0.0279
Epoch 99/100
30/30 [==============================] - 3s 91ms/step - loss: -2.3909 - mse: 0.0034 - mae: 0.0308 - val_loss: -2.4594 - val_mse: 0.0029 - val_mae: 0.0283
Epoch 100/100
30/30 [==============================] - 3s 92ms/step - loss: -2.6732 - mse: 0.0032 - mae: 0.0286 - val_loss: -2.7694 - val_mse: 0.0028 - val_mae: 0.0263

Loading network weights from 'weights_best.h5'.

Plot final training history (available in TensorBoard during training):

In [9]:
print(sorted(list(history.history.keys())))
plt.figure(figsize=(16,5))
plot_history(history,['loss','val_loss'],['mse','val_mse','mae','val_mae']);
['loss', 'lr', 'mae', 'mse', 'val_loss', 'val_mae', 'val_mse']
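Since history behaves like a standard Keras History object (its keys are printed above), individual curves can also be inspected directly, e.g. to find the epoch with the lowest validation loss (a minimal sketch):

val_loss = history.history['val_loss']
best_epoch = int(np.argmin(val_loss)) + 1  # epochs are numbered from 1 in the log above
print('lowest val_loss: %.4f (epoch %d)' % (min(val_loss), best_epoch))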

Evaluation

Example results for validation images.

In [10]:
plt.figure(figsize=(12,10))
_P = model.keras_model.predict(X_val[:5])   # raw network output: mean and scale concatenated along the channel axis
_P_mean  = _P[...,:(_P.shape[-1]//2)]       # first half of the channels: predicted Laplace mean
_P_scale = _P[...,(_P.shape[-1]//2):]       # second half of the channels: predicted Laplace scale
plot_some(X_val[:5],Y_val[:5],_P_mean,_P_scale,pmax=99.5)
plt.suptitle('5 example validation patches\n'      
             'first row: input (source),  '        
             'second row: target (ground truth),  '
             'third row: predicted Laplace mean,  '
             'fourth row: predicted Laplace scale');
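As an aside: instead of calling the underlying Keras model directly and splitting the output channels by hand as above, csbdeep also provides a higher-level prediction interface for probabilistic models. The following sketch assumes a CARE.predict_probabilistic(img, axes) method that returns an object with mean() and scale() accessors; please consult the documentation linked at the top of this notebook for the exact interface:

img = X_val[0, ..., 0]                             # a single validation image with axes YX
restored = model.predict_probabilistic(img, 'YX')  # assumed signature, see csbdeep docs
mean, scale = restored.mean(), restored.scale()    # per-pixel Laplace mean and scale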