You can either train the model yourself or load a pre-trained one from: ../models_weights/03.03,14_19,U-Net light, on RIM-ONE v3 256 px fold 0, SGD, high augm, CLAHE, log_dice loss/last_checkpoint.hdf5.
__versio
%load_ext autoreload
%autoreload 2
import os
import glob
from datetime import datetime
#import warnings
#warnings.simplefilter('ignore')
import scipy as sp
import scipy.ndimage
import numpy as np
import pandas as pd
import tensorflow as tf
import skimage
import skimage.exposure
import mahotas as mh
from sklearn.model_selection import KFold
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline
import h5py
from tqdm import tqdm_notebook
from IPython.display import display
from dual_IDG import DualImageDataGenerator
Using TensorFlow backend.
import keras
from keras.models import Sequential, Model
from keras.layers import Dense, Dropout, Activation, Flatten, BatchNormalization, \
Conv2D, MaxPooling2D, ZeroPadding2D, Input, Embedding, \
Lambda, UpSampling2D, Cropping2D, Concatenate
from keras.utils import np_utils
from keras.optimizers import SGD, Adam
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, ReduceLROnPlateau, CSVLogger
from keras.preprocessing.image import ImageDataGenerator
from keras import backend as K
# Record the library versions in the notebook output so results can be tied to an environment.
for _label, _version in (('Keras version:', keras.__version__),
                         ('TensorFlow version:', tf.__version__)):
    print(_label, _version)
Keras version: 2.3.1 TensorFlow version: 2.0.0
# The model below is channels-first: Input((3, rows, cols)) and Concatenate(axis=1) over channels.
K.set_image_data_format('channels_first')
def mean_IOU_gpu(X, Y):
    """Mean Intersection-over-Union (IOU) over a batch of binary masks.

    Each sample is flattened, clipped to [0, 1] and binarised at 0.5; the
    per-sample IoU values are then averaged. The first axis is assumed to index
    images (the original contract was shape (n_images, w, h)); every remaining
    axis is flattened. A sample with an empty union scores 0.
    """
    def _binarize(t):
        flat = K.clip(K.batch_flatten(t), 0., 1.)
        return K.cast(K.greater(flat, 0.5), 'float32')

    x_bin = _binarize(X)
    y_bin = _binarize(Y)
    overlap = K.sum(x_bin * y_bin, axis=1)
    combined = K.sum(K.maximum(x_bin, y_bin), axis=1)
    # An empty union implies an empty overlap; dividing by 1 then yields a score of 0.
    denominator = K.switch(K.equal(combined, 0), K.ones_like(combined), combined)
    return K.mean(overlap / K.cast(denominator, 'float32'))
def mean_IOU_gpu_loss(X, Y):
    """Loss form of ``mean_IOU_gpu``: the negated mean IoU, so lower is better."""
    score = mean_IOU_gpu(X, Y)
    return -score
def dice(y_true, y_pred):
    """Soft Dice coefficient averaged over a batch.

    Each sample is flattened and clipped to [0, 1]. The per-sample score is
    ``2 * sum(t * p) / (sum(t * t) + sum(p * p))``, and the batch mean is returned.

    A sample whose ground truth and prediction are both entirely zero has a zero
    denominator. It is scored 1 (perfect agreement) instead of yielding 0/0 = NaN,
    which would otherwise propagate into the batch mean, ``log_dice_loss`` and
    ``dice_metric``. Every other sample is scored exactly as before.
    """
    # Without K.clip, K.sum() behaves differently when compared to np.count_nonzero()
    y_true_f = K.clip(K.batch_flatten(y_true), 0., 1.)
    y_pred_f = K.clip(K.batch_flatten(y_pred), 0., 1.)
    intersection = 2 * K.sum(y_true_f * y_pred_f, axis=1)
    union = K.sum(y_true_f * y_true_f, axis=1) + K.sum(y_pred_f * y_pred_f, axis=1)
    is_empty = K.equal(union, 0.)
    # Divide by a safe denominator first: K.switch evaluates both branches, and a
    # NaN in the unselected branch would still produce NaN gradients.
    safe_union = K.switch(is_empty, K.ones_like(union), union)
    per_sample = K.switch(is_empty, K.ones_like(union), intersection / safe_union)
    return K.mean(per_sample)
def dice_loss(y_true, y_pred):
    """Negated soft Dice coefficient; minimising it maximises mask overlap."""
    overlap = dice(y_true, y_pred)
    return -overlap
def log_dice_loss(y_true, y_pred):
    """Negative natural log of the soft Dice coefficient (0 for a perfect match)."""
    dice_score = dice(y_true, y_pred)
    return -K.log(dice_score)
def dice_metric(y_true, y_pred):
    """An exact Dice score for binary tensors: both inputs are thresholded at 0.5 first."""
    def _to_binary(t):
        return K.cast(K.greater(t, 0.5), 'float32')

    return dice(_to_binary(y_true), _to_binary(y_pred))
def tf_to_th_encoding(X):
    """Convert channels-last to channels-first by moving axis 3 to position 1.

    For a 4-D batch: (n, rows, cols, channels) -> (n, channels, rows, cols).
    Returns a view of ``X``; no data is copied.
    """
    return np.moveaxis(X, 3, 1)
def th_to_tf_encoding(X):
    """Convert channels-first to channels-last by moving axis 1 to position 3.

    For a 4-D batch: (n, channels, rows, cols) -> (n, rows, cols, channels).
    Returns a view of ``X``; no data is copied.
    """
    return np.moveaxis(X, 1, 3)
# Alternative data source: a combined HDF5 file holding all datasets.
#h5f = h5py.File(os.path.join(os.path.dirname(os.getcwd()), 'data', 'hdf5_datasets', 'all_data.hdf5'), 'r')
# Open RIM-ONE v3 read-only from <parent of cwd>/data/hdf5_datasets.
# NOTE(review): the handle is never closed in this notebook.
h5f = h5py.File(os.path.join(os.path.dirname(os.getcwd()), 'data', 'hdf5_datasets', 'RIM_ONE_v3.hdf5'), 'r')
def get_unet_light(img_rows=256, img_cols=256):
    """Build a lightweight U-Net for binary segmentation.

    Input: channels-first tensor of shape (3, img_rows, img_cols).
    Output: (1, img_rows, img_cols) per-pixel sigmoid probabilities.

    The encoder has five stages of two 3x3 ReLU convolutions each, with 0.3
    dropout between them, using 32/64/64/64/64 filters. Stages are separated
    by 2x2 max-pooling. The decoder upsamples 2x and concatenates the matching
    encoder output along the channel axis (axis=1).

    img_rows and img_cols must be divisible by 16 (four 2x poolings).
    Otherwise the upsampled maps and the skip connections will not match.

    Layers are created in a fixed order. Keep that order so checkpoints saved
    from this architecture remain loadable.
    """
    inputs = Input((3, img_rows, img_cols))

    conv1 = Conv2D(32, kernel_size=3, activation='relu', padding='same')(inputs)
    conv1 = Dropout(0.3)(conv1)
    conv1 = Conv2D(32, kernel_size=3, activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(pool1)
    conv2 = Dropout(0.3)(conv2)
    conv2 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(pool2)
    conv3 = Dropout(0.3)(conv3)
    conv3 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(pool3)
    conv4 = Dropout(0.3)(conv4)
    conv4 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    # Bottleneck.
    conv5 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(pool4)
    conv5 = Dropout(0.3)(conv5)
    conv5 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(conv5)

    # Decoder: upsample, then concatenate the skip connection on the channel axis.
    up6 = Concatenate(axis=1)([UpSampling2D(size=(2, 2))(conv5), conv4])
    conv6 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(up6)
    conv6 = Dropout(0.3)(conv6)
    conv6 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(conv6)

    up7 = Concatenate(axis=1)([UpSampling2D(size=(2, 2))(conv6), conv3])
    conv7 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(up7)
    conv7 = Dropout(0.3)(conv7)
    conv7 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(conv7)

    up8 = Concatenate(axis=1)([UpSampling2D(size=(2, 2))(conv7), conv2])
    conv8 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(up8)
    conv8 = Dropout(0.3)(conv8)
    conv8 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(conv8)

    up9 = Concatenate(axis=1)([UpSampling2D(size=(2, 2))(conv8), conv1])
    conv9 = Conv2D(32, kernel_size=3, activation='relu', padding='same')(up9)
    conv9 = Dropout(0.3)(conv9)
    conv9 = Conv2D(32, kernel_size=3, activation='relu', padding='same')(conv9)

    # 1x1 convolution to a single sigmoid probability map.
    conv10 = Conv2D(1, kernel_size=1, activation='sigmoid', padding='same')(conv9)

    # Keras 2 functional API keywords; the Keras 1 `input=`/`output=` form is deprecated.
    model = Model(inputs=inputs, outputs=conv10)
    return model
# Build the 256x256 light U-Net and compile it with SGD + momentum, the
# -log(Dice) loss, and IoU / hard-Dice as monitored metrics.
model = get_unet_light(img_rows=256, img_cols=256)
model.compile(optimizer=SGD(lr=1e-3, momentum=0.95),
              loss=log_dice_loss,
              metrics=[mean_IOU_gpu, dice_metric])
model.summary()
Model: "model_1" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_1 (InputLayer) (None, 3, 256, 256) 0 __________________________________________________________________________________________________ conv2d_1 (Conv2D) (None, 32, 256, 256) 896 input_1[0][0] __________________________________________________________________________________________________ dropout_1 (Dropout) (None, 32, 256, 256) 0 conv2d_1[0][0] __________________________________________________________________________________________________ conv2d_2 (Conv2D) (None, 32, 256, 256) 9248 dropout_1[0][0] __________________________________________________________________________________________________ max_pooling2d_1 (MaxPooling2D) (None, 32, 128, 128) 0 conv2d_2[0][0] __________________________________________________________________________________________________ conv2d_3 (Conv2D) (None, 64, 128, 128) 18496 max_pooling2d_1[0][0] __________________________________________________________________________________________________ dropout_2 (Dropout) (None, 64, 128, 128) 0 conv2d_3[0][0] __________________________________________________________________________________________________ conv2d_4 (Conv2D) (None, 64, 128, 128) 36928 dropout_2[0][0] __________________________________________________________________________________________________ max_pooling2d_2 (MaxPooling2D) (None, 64, 64, 64) 0 conv2d_4[0][0] __________________________________________________________________________________________________ conv2d_5 (Conv2D) (None, 64, 64, 64) 36928 max_pooling2d_2[0][0] __________________________________________________________________________________________________ dropout_3 (Dropout) (None, 64, 64, 64) 0 conv2d_5[0][0] 
__________________________________________________________________________________________________ conv2d_6 (Conv2D) (None, 64, 64, 64) 36928 dropout_3[0][0] __________________________________________________________________________________________________ max_pooling2d_3 (MaxPooling2D) (None, 64, 32, 32) 0 conv2d_6[0][0] __________________________________________________________________________________________________ conv2d_7 (Conv2D) (None, 64, 32, 32) 36928 max_pooling2d_3[0][0] __________________________________________________________________________________________________ dropout_4 (Dropout) (None, 64, 32, 32) 0 conv2d_7[0][0] __________________________________________________________________________________________________ conv2d_8 (Conv2D) (None, 64, 32, 32) 36928 dropout_4[0][0] __________________________________________________________________________________________________ max_pooling2d_4 (MaxPooling2D) (None, 64, 16, 16) 0 conv2d_8[0][0] __________________________________________________________________________________________________ conv2d_9 (Conv2D) (None, 64, 16, 16) 36928 max_pooling2d_4[0][0] __________________________________________________________________________________________________ dropout_5 (Dropout) (None, 64, 16, 16) 0 conv2d_9[0][0] __________________________________________________________________________________________________ conv2d_10 (Conv2D) (None, 64, 16, 16) 36928 dropout_5[0][0] __________________________________________________________________________________________________ up_sampling2d_1 (UpSampling2D) (None, 64, 32, 32) 0 conv2d_10[0][0] __________________________________________________________________________________________________ concatenate_1 (Concatenate) (None, 128, 32, 32) 0 up_sampling2d_1[0][0] conv2d_8[0][0] __________________________________________________________________________________________________ conv2d_11 (Conv2D) (None, 64, 32, 32) 73792 concatenate_1[0][0] 
__________________________________________________________________________________________________ dropout_6 (Dropout) (None, 64, 32, 32) 0 conv2d_11[0][0] __________________________________________________________________________________________________ conv2d_12 (Conv2D) (None, 64, 32, 32) 36928 dropout_6[0][0] __________________________________________________________________________________________________ up_sampling2d_2 (UpSampling2D) (None, 64, 64, 64) 0 conv2d_12[0][0] __________________________________________________________________________________________________ concatenate_2 (Concatenate) (None, 128, 64, 64) 0 up_sampling2d_2[0][0] conv2d_6[0][0] __________________________________________________________________________________________________ conv2d_13 (Conv2D) (None, 64, 64, 64) 73792 concatenate_2[0][0] __________________________________________________________________________________________________ dropout_7 (Dropout) (None, 64, 64, 64) 0 conv2d_13[0][0] __________________________________________________________________________________________________ conv2d_14 (Conv2D) (None, 64, 64, 64) 36928 dropout_7[0][0] __________________________________________________________________________________________________ up_sampling2d_3 (UpSampling2D) (None, 64, 128, 128) 0 conv2d_14[0][0] __________________________________________________________________________________________________ concatenate_3 (Concatenate) (None, 128, 128, 128 0 up_sampling2d_3[0][0] conv2d_4[0][0] __________________________________________________________________________________________________ conv2d_15 (Conv2D) (None, 64, 128, 128) 73792 concatenate_3[0][0] __________________________________________________________________________________________________ dropout_8 (Dropout) (None, 64, 128, 128) 0 conv2d_15[0][0] __________________________________________________________________________________________________ conv2d_16 (Conv2D) (None, 64, 128, 128) 36928 dropout_8[0][0] 
__________________________________________________________________________________________________ up_sampling2d_4 (UpSampling2D) (None, 64, 256, 256) 0 conv2d_16[0][0] __________________________________________________________________________________________________ concatenate_4 (Concatenate) (None, 96, 256, 256) 0 up_sampling2d_4[0][0] conv2d_2[0][0] __________________________________________________________________________________________________ conv2d_17 (Conv2D) (None, 32, 256, 256) 27680 concatenate_4[0][0] __________________________________________________________________________________________________ dropout_9 (Dropout) (None, 32, 256, 256) 0 conv2d_17[0][0] __________________________________________________________________________________________________ conv2d_18 (Conv2D) (None, 32, 256, 256) 9248 dropout_9[0][0] __________________________________________________________________________________________________ conv2d_19 (Conv2D) (None, 1, 256, 256) 33 conv2d_18[0][0] ================================================================================================== Total params: 656,257 Trainable params: 656,257 Non-trainable params: 0 __________________________________________________________________________________________________
/home/artem/miniconda3/lib/python3.7/site-packages/ipykernel_launcher.py:50: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=Tensor("in..., outputs=Tensor("co...)`
Accessing the data and splitting it into training and validation sets:
# HDF5 dataset handles for the 256 px RIM-ONE v3 images and optic-disc masks.
# Per the repr below: images (159, 256, 256, 3) uint8, masks (159, 256, 256, 1) uint8.
X = h5f['RIM-ONE v3/256 px/images']
Y = h5f['RIM-ONE v3/256 px/disc']
X, Y
(<HDF5 dataset "images": shape (159, 256, 256, 3), type "|u1">, <HDF5 dataset "disc": shape (159, 256, 256, 1), type "|u1">)
# 5-fold split over image indices. KFold does not shuffle by default, so the folds
# are contiguous index ranges (see the output below) and fully deterministic.
# A random_state has no effect without shuffling, and newer scikit-learn versions
# raise an error when one is passed with shuffle=False, so none is given.
train_idx_cv, test_idx_cv = [], []
for _train_idx, _test_idx in KFold(n_splits=5).split(X):
    print(_train_idx, _test_idx)
    train_idx_cv.append(_train_idx)
    test_idx_cv.append(_test_idx)
[ 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158] [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31] [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158] [32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63] [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158] [64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95] [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158] 
[ 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127] [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127] [128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158]
# Alternative: a predefined train/test split stored in the HDF5 file.
#train_idx = h5f['RIM-ONE v3/train_idx_driu']
#test_idx = h5f['RIM-ONE v3/test_idx_driu']
# Use fold 0 of the cross-validation split; data_generator reads these module-level names.
train_idx = train_idx_cv[0]
test_idx = test_idx_cv[0]
len(X), len(train_idx), len(test_idx)
(159, 127, 32)
# Training-time augmentation: flips, rotation, shifts and zoom, with zero fill.
# rescale stays commented out because preprocess() already divides by 255.
# DualImageDataGenerator comes from the project-local dual_IDG module; it
# presumably applies the same random transform to an image and its mask (verify).
train_idg = DualImageDataGenerator(#rescale=1/255.0,
                                   #samplewise_center=True, samplewise_std_normalization=True,
                                   horizontal_flip=True, vertical_flip=True,
                                   rotation_range=50, width_shift_range=0.15, height_shift_range=0.15,
                                   zoom_range=(0.7, 1.3),
                                   fill_mode='constant', cval=0.0)
# Test-time generator constructed without any augmentation options.
test_idg = DualImageDataGenerator()
def preprocess(batch_X, batch_y, train_or_test='train'):
    """Scale, optionally transform, and contrast-equalise a channels-first batch.

    Both arrays are divided by 255 (presumably uint8 input; TODO confirm).
    The pair is then passed once through ``train_idg`` ('train') or
    ``test_idg`` ('test') as a single unshuffled batch; any other mode skips
    this step. Finally, CLAHE (``skimage.exposure.equalize_adapthist``) is
    applied to each image in channels-last layout, and the images are returned
    channels-first. Masks get no CLAHE.

    Returns:
        (images, masks) after the steps above.
    """
    images = batch_X / 255.0
    masks = batch_y / 255.0

    generator = None
    if train_or_test == 'train':
        generator = train_idg
    elif train_or_test == 'test':
        generator = test_idg
    if generator is not None:
        flow = generator.flow(images, masks, batch_size=len(images), shuffle=False)
        images, masks = next(flow)

    equalized = np.array([skimage.exposure.equalize_adapthist(image)
                          for image in th_to_tf_encoding(images)])
    return tf_to_th_encoding(equalized), masks
def data_generator(X, y, train_or_test='train', batch_size=3, return_orig=False, stationary=False):
    """Endlessly yield preprocessed channels-first batches drawn from ``X`` / ``y``.

    Args:
        X, y: indexable image / mask collections whose items are channels-last
            (e.g. the HDF5 datasets). They are indexed with the module-level
            ``train_idx`` / ``test_idx`` arrays.
        train_or_test: 'train' draws from ``train_idx`` at random (with
            replacement). 'test' draws from ``test_idx``. The mode is also
            forwarded to ``preprocess``.
        batch_size: number of samples per batch.
        return_orig: if True, also yield copies of the raw (unscaled,
            un-augmented) channels-first batches.
        stationary: 'test' only. Always yield ``test_idx[:batch_size]`` in
            order, i.e. the whole test fold if it is smaller than batch_size.

    Yields:
        ``(batch_X, batch_y)``, or ``(batch_X, batch_y, batch_X_orig, batch_Y_orig)``
        when ``return_orig`` is set.

    Raises:
        ValueError: on the first ``next()`` if ``train_or_test`` is neither
            'train' nor 'test'.
    """
    if train_or_test not in ('train', 'test'):
        raise ValueError("train_or_test must be 'train' or 'test', got {!r}".format(train_or_test))
    while True:
        if train_or_test == 'train':
            idx = np.random.choice(train_idx, size=batch_size)
        elif stationary:
            idx = test_idx[:batch_size]
        else:
            idx = np.random.choice(test_idx, size=batch_size)
        # np.array() already copies the per-item reads, so no extra .copy() is needed.
        batch_X = tf_to_th_encoding(np.array([X[i] for i in idx]))
        batch_y = tf_to_th_encoding(np.array([y[i] for i in idx]))
        if return_orig:
            batch_X_orig, batch_Y_orig = batch_X.copy(), batch_y.copy()
        batch_X, batch_y = preprocess(batch_X, batch_y, train_or_test)
        if return_orig:
            yield batch_X, batch_y, batch_X_orig, batch_Y_orig
        else:
            yield batch_X, batch_y
Testing the data generator (with augmentation) on a single training sample:
# Draw one augmented, CLAHE-processed training sample to sanity-check the pipeline.
gen = data_generator(X, Y, 'train', batch_size=1)
batch = next(gen)
batch[0].shape
(1, 3, 256, 256)
# Show the sample image (channel axis moved last for matplotlib), then its mask.
fig = plt.imshow(np.rollaxis(batch[0][0], 0, 3))
#plt.colorbar(mappable=fig)
plt.show()
plt.imshow(batch[1][0][0], cmap=plt.cm.Greys_r); plt.show()
# Human-readable description of this run; it is also the suffix of the weights folder name.
arch_name = "U-Net light, on RIM-ONE v3 256 px fold 0, SGD, high augm, CLAHE, log_dice loss"
# Timestamped output folder: <parent of cwd>/models_weights/<dd.mm,HH_MM>,<arch_name>/
# Hours and minutes are joined with '_' rather than ':' because ':' is not allowed
# in Windows paths. This also matches the folder format the loading cell reads.
weights_folder = os.path.join(os.path.dirname(os.getcwd()), 'models_weights',
                              '{},{}/'.format(datetime.now().strftime('%d.%m,%H_%M'), arch_name))
print(weights_folder)
def folder(folder_name):
    """Ensure directory ``folder_name`` exists (creating parents) and return it unchanged.

    ``exist_ok=True`` avoids the race between checking ``os.path.exists`` and
    creating the directory. A path that exists but is not a directory raises
    FileExistsError.
    """
    os.makedirs(folder_name, exist_ok=True)
    return folder_name
# Fixed validation batch. With stationary=True the generator returns test_idx[:batch_size]
# in order; batch_size=100 exceeds the test fold, so this is the whole fold (32 images,
# per the shapes printed below), preprocessed via the option-less test_idg.
X_valid, Y_valid = next(data_generator(X, Y, train_or_test='test', batch_size=100, stationary=True))
plt.imshow(np.rollaxis(X_valid[0], 0, 3)); plt.show()
print(X_valid.shape, Y_valid.shape)
(32, 3, 256, 256) (32, 1, 256, 256)
If a pre-trained model is to be used, first run the "Loading model" section below and then go to the "Comprehensive visual check" section, skipping this one.
# Train for 500 epochs of 99 single-image augmented batches, validating on the fixed
# (X_valid, Y_valid) batch.
# - CSVLogger writes per-epoch metrics to training_log.csv.
# - ModelCheckpoint overwrites last_checkpoint.hdf5 with the full model
#   (save_weights_only=False) whenever val_loss improves (save_best_only=True).
#   Despite its name, the file therefore holds the best-val_loss model.
history = model.fit_generator(data_generator(X, Y, train_or_test='train', batch_size=1),
                              steps_per_epoch=99,
                              max_queue_size=1,
                              validation_data=(X_valid, Y_valid),
                              #validation_data=data_generator(X, Y, train_or_test='test', batch_size=1),
                              #nb_val_samples=100,
                              epochs=500, verbose=1,
                              callbacks=[CSVLogger(os.path.join(folder(weights_folder), 'training_log.csv')),
                                         #ReduceLROnPlateau(monitor='val_loss', mode='min', factor=0.5, verbose=1, patience=40),
                                         ModelCheckpoint(os.path.join(folder(weights_folder),
                                                                      #'weights.ep-{epoch:02d}-val_mean_IOU-{val_mean_IOU_gpu:.2f}_val_loss_{val_loss:.2f}.hdf5',
                                                                      'last_checkpoint.hdf5'),
                                                         monitor='val_loss', mode='min', save_best_only=True,
                                                         save_weights_only=False, verbose=0)])
# Per-image evaluation on the fixed validation batch. X_valid / Y_valid came from
# data_generator(..., stationary=True), so row i corresponds to test_idx[i].
pred_iou, pred_dice = [], []
for i, img_no in enumerate(test_idx):
    print('image #{}'.format(img_no))
    batch_X = X_valid[i:i + 1]
    batch_y = Y_valid[i:i + 1]
    # Binarise the sigmoid output at 0.5 into a (rows, cols) float mask.
    pred = (model.predict(batch_X)[0, 0] > 0.5).astype(np.float64)
    corr = th_to_tf_encoding(batch_y)[0, ..., 0]
    # Optional post-processing (disabled): mean filtering.
    #pred = mh.mean_filter(pred, Bc=mh.disk(10)) > 0.5

    fig = plt.figure(figsize=(9, 4))
    ax = fig.add_subplot(1, 3, 1)
    ax.imshow(pred, cmap=plt.cm.Greys_r)
    ax.set_title('Predicted')
    ax = fig.add_subplot(1, 3, 2)
    ax.imshow(corr, cmap=plt.cm.Greys_r)
    ax.set_title('Correct')
    ax = fig.add_subplot(1, 3, 3)
    # Show the preprocessed (CLAHE) image that was actually fed to the model.
    ax.imshow(th_to_tf_encoding(batch_X)[0])
    ax.set_title('Image')
    plt.show()

    # Add leading (batch, channel) axes so the backend metrics see (1, 1, rows, cols).
    cur_iou = K.eval(mean_IOU_gpu(pred[None, None, ...], corr[None, None, ...]))
    cur_dice = K.eval(dice(pred[None, None, ...], corr[None, None, ...]))
    print('IOU: {}\nDice: {}'.format(cur_iou, cur_dice))
    pred_iou.append(cur_iou)
    pred_dice.append(cur_dice)
image #0
IOU: 0.8816186189651489 Dice: 0.944140385555706 image #1
IOU: 0.8961654305458069 Dice: 0.952512119605312 image #2
IOU: 0.9015904664993286 Dice: 0.9553385003360111 image #3
IOU: 0.8172494173049927 Dice: 0.9083420941773463 image #4
IOU: 0.889970064163208 Dice: 0.9492029322292237 image #5
IOU: 0.858206033706665 Dice: 0.9300793217091029 image #6
IOU: 0.8313539028167725 Dice: 0.9152745710378027 image #7
IOU: 0.8521560430526733 Dice: 0.9289183684966821 image #8
IOU: 0.9362050294876099 Dice: 0.971991061814878 image #9
IOU: 0.8662033081054688 Dice: 0.9375384127618763 image #10
IOU: 0.7161670923233032 Dice: 0.8412830031874182 image #11
IOU: 0.8760484457015991 Dice: 0.9408158772195463 image #12
IOU: 0.9020586609840393 Dice: 0.9562561050223877 image #13
IOU: 0.8646368384361267 Dice: 0.9338180973871951 image #14
IOU: 0.8894058465957642 Dice: 0.9477895275961465 image #15
IOU: 0.8265966176986694 Dice: 0.911744908846224 image #16
IOU: 0.8359081149101257 Dice: 0.9180478236485754 image #17
IOU: 0.8375751972198486 Dice: 0.919305203472488 image #18
IOU: 0.8221626281738281 Dice: 0.9099414826076261 image #19
IOU: 0.8900970816612244 Dice: 0.9480220327861764 image #20
IOU: 0.8352692723274231 Dice: 0.9148233774710852 image #21
IOU: 0.8083933591842651 Dice: 0.9002822136197898 image #22
IOU: 0.8826213479042053 Dice: 0.9448930512273794 image #23
IOU: 0.9258196949958801 Dice: 0.9670243083173704 image #24
IOU: 0.9068605303764343 Dice: 0.9595719452492445 image #25
IOU: 0.9389384984970093 Dice: 0.9740978833951929 image #26
IOU: 0.7768411636352539 Dice: 0.880669597062454 image #27
IOU: 0.9319126605987549 Dice: 0.9718210533019517 image #28
IOU: 0.8445137739181519 Dice: 0.923892559623657 image #29
IOU: 0.9158003926277161 Dice: 0.9625049040473955 image #30
IOU: 0.8827614188194275 Dice: 0.9459575162743651 image #31
IOU: 0.7352941036224365 Dice: 0.8552486708050341
Computing mean scores over the validation set:
# Mean IoU, then mean Dice, over all validation images.
print(np.mean(pred_iou))
print(np.mean(pred_dice))
0.8617625 0.9319109034341452
Showing the best and the worst cases:
def show_img_pred_corr(i, file_suffix, file_prefix='od_rim_one_v3_fold_0'):  # i is index of image in test_idx
    """Plot and save prediction, ground truth and raw image for one test-fold case.

    Args:
        i: position within ``test_idx`` (not the dataset index itself).
        file_suffix: tag inserted into the output file names, e.g. 'best'.
        file_prefix: leading part of the output file names. The default keeps
            the original fold-0 names; pass another prefix when evaluating a
            different fold.

    Writes ``{file_prefix}_{file_suffix}_case_image.png``, ``..._case_pred.png``
    and ``..._case_corr.png`` to the current working directory.
    """
    img_no = test_idx[i]
    image = X[img_no]  # raw channels-last image, read once and reused below
    batch_X = tf_to_th_encoding(image[None, ...])
    batch_y = tf_to_th_encoding(Y[img_no:img_no + 1])
    batch_X, batch_y = preprocess(batch_X, batch_y, 'test')
    pred = model.predict(batch_X)[0, 0] > 0.5
    corr = th_to_tf_encoding(batch_y)[0, ..., 0]

    fig = plt.figure(figsize=(9, 4))
    ax = fig.add_subplot(1, 3, 1)
    ax.imshow(pred, cmap=plt.cm.Greys_r)
    ax.set_title('Predicted')
    ax = fig.add_subplot(1, 3, 2)
    ax.imshow(corr, cmap=plt.cm.Greys_r)
    ax.set_title('Correct')
    ax = fig.add_subplot(1, 3, 3)
    ax.imshow(image)
    ax.set_title('Image')
    plt.show()

    plt.imsave('{}_{}_case_image.png'.format(file_prefix, file_suffix), image)
    plt.imsave('{}_{}_case_pred.png'.format(file_prefix, file_suffix), pred, cmap=plt.cm.Greys_r)
    plt.imsave('{}_{}_case_corr.png'.format(file_prefix, file_suffix), corr, cmap=plt.cm.Greys_r)
# Positions (within test_idx) of the highest- and lowest-IoU validation images.
best_idx = np.argmax(pred_iou)
worst_idx = np.argmin(pred_iou)
show_img_pred_corr(best_idx, 'best')
print('IOU: {}, Dice: {} (best)'.format(pred_iou[best_idx], pred_dice[best_idx]))
show_img_pred_corr(worst_idx, 'worst')
print('IOU: {}, Dice: {} (worst)'.format(pred_iou[worst_idx], pred_dice[worst_idx]))
IOU: 0.9389384984970093, Dice: 0.9740978833951929 (best)
IOU: 0.7161670923233032, Dice: 0.8412830031874182 (worst)
# Set load_model = False to skip loading pre-trained weights and log statistics.
load_model = True # lock
if not load_model:
    print('load_model == False')
else:
    # UNCOMMENT APPROPRIATE LINE(S) BELOW:
    # specify file:
    #model_path = '../models_weights/01.11,22:38,U-Net on DRIONS-DB 256 px, Adam, augm, log_dice loss/' \
    #             'weights.ep-20-val_mean_IOU-0.81_val_loss_0.08.hdf5'
    # or get the most recently altered file in a folder:
    model_folder = os.path.join(os.path.dirname(os.getcwd()), 'models_weights', '03.03,14_19,U-Net light, on RIM-ONE v3 256 px fold 0, SGD, high augm, CLAHE, log_dice loss')
    checkpoints = glob.glob(os.path.join(model_folder, '*.hdf5'))
    # Fail with a clear message instead of max() raising on an empty sequence.
    if not checkpoints:
        raise FileNotFoundError('no .hdf5 files found in {}'.format(model_folder))
    model_path = max(checkpoints, key=os.path.getctime)

    if not os.path.exists(model_path):
        raise FileNotFoundError('`model_path` does not exist')
    print('Loading weights from', model_path)
    model.load_weights(model_path)

    # Reading log statistics: training_log.csv is expected next to the weights file.
    # Deriving the folder from model_path also works when model_path is set explicitly.
    log_path = os.path.join(os.path.dirname(model_path), 'training_log.csv')
    if os.path.exists(log_path):
        log = pd.read_csv(log_path)
        if log['epoch'].dtype != 'int64':
            # Rows repeating the header text (e.g. from concatenated logs) make every
            # column string-typed. Drop them and convert back to numbers so that
            # max()/idxmax() compare values numerically rather than as text.
            log = log.loc[log.epoch != 'epoch'].apply(pd.to_numeric, errors='coerce')
        for column, label in (('val_mean_IOU_gpu', 'mean IOU'),
                              ('val_dice_metric', 'dice_metric'),
                              ('val_dice', 'dice')):
            if column in log.columns:
                print('\nmax val {}: {}, at row:'.format(label, log[column].max()))
                print(log.loc[log[column].idxmax()])
Loading weights from /home/artem/Загрузки/optic-nerve-cnn/models_weights/03.03,14_19,U-Net light, on RIM-ONE v3 256 px fold 0, SGD, high augm, CLAHE, log_dice loss/last_checkpoint.hdf5 max val mean IOU: 0.8809200096879999, at row: epoch 196.000000 dice_metric 0.902209 loss 0.071331 mean_IOU_gpu 0.826438 val_dice_metric 0.936309 val_loss 0.042184 val_mean_IOU_gpu 0.880920 Name: 196, dtype: float64 max val dice_metric: 0.936309117824, at row: epoch 196.000000 dice_metric 0.902209 loss 0.071331 mean_IOU_gpu 0.826438 val_dice_metric 0.936309 val_loss 0.042184 val_mean_IOU_gpu 0.880920 Name: 196, dtype: float64