Images were cropped to the optic-disc area at 512 px and resized to 128 px.
You can either train your model or load a pre-trained one from: ../models_weights/02.03,13:57,OD Cup, U-Net light on DRISHTI-GS 512 px cropped to OD 128 px fold 0, SGD, log_dice loss/last_checkpoint.hdf5
%load_ext autoreload
%autoreload 2
import os
import glob
from datetime import datetime
#import warnings
#warnings.simplefilter('ignore')
import scipy as sp
import scipy.ndimage
import numpy as np
import pandas as pd
import tensorflow as tf
import skimage
import skimage.exposure
import skimage.transform
import mahotas as mh
from sklearn.model_selection import KFold
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline
import h5py
from tqdm import tqdm_notebook
from IPython.display import display
from dual_IDG import DualImageDataGenerator
Using TensorFlow backend.
import keras
from keras.models import Sequential, Model
from keras.layers import Dense, Dropout, Activation, Flatten, BatchNormalization, \
Conv2D, MaxPooling2D, ZeroPadding2D, Input, Embedding, \
Lambda, UpSampling2D, Cropping2D, Concatenate
from keras.utils import np_utils
from keras.optimizers import SGD, Adam
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, ReduceLROnPlateau, CSVLogger
from keras.preprocessing.image import ImageDataGenerator
from keras import backend as K
print('Keras version:', keras.__version__)
print('TensorFlow version:', tf.__version__)
Keras version: 2.3.1 TensorFlow version: 2.0.0
# Use Theano-style NCHW layout everywhere (the U-Net below concatenates on axis=1).
K.set_image_data_format('channels_first')
def mean_IOU_gpu(X, Y):
    """Mean Intersection-over-Union for two batches of binary masks.

    X and Y are assumed to be of shape (n_images, w, h); each mask is
    clipped to [0, 1] and binarized at 0.5 before the IOU is computed.
    """
    x_bin = K.cast(K.greater(K.clip(K.batch_flatten(X), 0., 1.), 0.5), 'float32')
    y_bin = K.cast(K.greater(K.clip(K.batch_flatten(Y), 0., 1.), 0.5), 'float32')
    intersection = K.sum(x_bin * y_bin, axis=1)
    union = K.sum(K.maximum(x_bin, y_bin), axis=1)
    # An empty union implies an empty intersection, so substituting 1 for 0
    # yields a well-defined score of 0 instead of NaN.
    safe_union = K.switch(K.equal(union, 0), K.ones_like(union), union)
    return K.mean(intersection / K.cast(safe_union, 'float32'))
def mean_IOU_gpu_loss(X, Y):
    """Negated mean IOU, suitable as a Keras loss (lower is better)."""
    score = mean_IOU_gpu(X, Y)
    return -score
def dice(y_true, y_pred):
    """Soft (differentiable) Dice coefficient for batches of masks.

    Both tensors are flattened per sample and clipped to [0, 1]; no hard
    thresholding is applied, so gradients flow through y_pred.
    """
    # Workaround for shape bug. For some reason y_true shape was not being set correctly
    #y_true.set_shape(y_pred.get_shape())
    # Without K.clip, K.sum() behaves differently when compared to np.count_nonzero()
    y_true_f = K.clip(K.batch_flatten(y_true), 0., 1.)
    y_pred_f = K.clip(K.batch_flatten(y_pred), 0., 1.)
    intersection = 2 * K.sum(y_true_f * y_pred_f, axis=1)
    union = K.sum(y_true_f * y_true_f, axis=1) + K.sum(y_pred_f * y_pred_f, axis=1)
    # FIX: guard against 0/0 when both masks are empty — previously this
    # produced NaN, which poisons the mean and the training loss.
    return K.mean(intersection / (union + K.epsilon()))
def dice_loss(y_true, y_pred):
    """Negated soft Dice score, to be minimized."""
    score = dice(y_true, y_pred)
    return -score
def log_dice_loss(y_true, y_pred):
    """Negative log of the soft Dice score; penalizes low Dice more sharply."""
    score = dice(y_true, y_pred)
    return -K.log(score)
def dice_metric(y_true, y_pred):
    """Exact (hard) Dice score: binarize both tensors at 0.5, then score."""
    true_bin = K.cast(K.greater(y_true, 0.5), 'float32')
    pred_bin = K.cast(K.greater(y_pred, 0.5), 'float32')
    return dice(true_bin, pred_bin)
def tf_to_th_encoding(X):
    """Convert a batch from NHWC (TensorFlow) to NCHW (Theano) layout."""
    return np.moveaxis(X, 3, 1)
def th_to_tf_encoding(X):
    """Convert a batch from NCHW (Theano) back to NHWC (TensorFlow) layout."""
    return np.moveaxis(X, 1, 3)
# h5f = h5py.File(os.path.join(os.path.dirname(os.getcwd()), 'data', 'hdf5_datasets', 'all_data.hdf5'), 'r')
# Open the DRISHTI-GS dataset read-only; expected at ../data/hdf5_datasets/ relative to this notebook.
h5f = h5py.File(os.path.join(os.path.dirname(os.getcwd()), 'data', 'hdf5_datasets', 'DRISHTI_GS.hdf5'), 'r')
def get_unet_light(img_rows=256, img_cols=256):
    """Build a lightweight U-Net for binary segmentation.

    Input: a (3, img_rows, img_cols) channels-first RGB image.
    Output: a (1, img_rows, img_cols) sigmoid probability map.
    There are four 2x pooling stages, so img_rows and img_cols should be
    divisible by 16 for the skip connections to align.
    """
    inputs = Input((3, img_rows, img_cols))
    # --- Encoder: four conv/conv/pool stages with dropout for regularization ---
    conv1 = Conv2D(32, kernel_size=3, activation='relu', padding='same')(inputs)
    conv1 = Dropout(0.3)(conv1)
    conv1 = Conv2D(32, kernel_size=3, activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(pool1)
    conv2 = Dropout(0.3)(conv2)
    conv2 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(pool2)
    conv3 = Dropout(0.3)(conv3)
    conv3 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(pool3)
    conv4 = Dropout(0.3)(conv4)
    conv4 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
    # --- Bottleneck ---
    conv5 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(pool4)
    conv5 = Dropout(0.3)(conv5)
    conv5 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(conv5)
    # --- Decoder: upsample and concatenate skip connections on the channel
    # axis (axis=1, since the data format is channels_first) ---
    up6 = Concatenate(axis=1)([UpSampling2D(size=(2, 2))(conv5), conv4])
    conv6 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(up6)
    conv6 = Dropout(0.3)(conv6)
    conv6 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(conv6)
    up7 = Concatenate(axis=1)([UpSampling2D(size=(2, 2))(conv6), conv3])
    conv7 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(up7)
    conv7 = Dropout(0.3)(conv7)
    conv7 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(conv7)
    up8 = Concatenate(axis=1)([UpSampling2D(size=(2, 2))(conv7), conv2])
    conv8 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(up8)
    conv8 = Dropout(0.3)(conv8)
    conv8 = Conv2D(64, kernel_size=3, activation='relu', padding='same')(conv8)
    up9 = Concatenate(axis=1)([UpSampling2D(size=(2, 2))(conv8), conv1])
    conv9 = Conv2D(32, kernel_size=3, activation='relu', padding='same')(up9)
    conv9 = Dropout(0.3)(conv9)
    conv9 = Conv2D(32, kernel_size=3, activation='relu', padding='same')(conv9)
    # 1x1 conv + sigmoid produces the per-pixel probability map.
    conv10 = Conv2D(1, kernel_size=1, activation='sigmoid', padding='same')(conv9)
    #conv10 = Flatten()(conv10)
    # FIX: use the Keras 2 keyword arguments `inputs`/`outputs` — the old
    # `input`/`output` spelling triggered the deprecation UserWarning.
    model = Model(inputs=inputs, outputs=conv10)
    return model
# 128 px crops; SGD with momentum and the log-Dice loss defined above.
model = get_unet_light(img_rows=128, img_cols=128)
model.compile(optimizer=SGD(lr=1e-4, momentum=0.95),
              loss=log_dice_loss,
              metrics=[mean_IOU_gpu, dice_metric])
/home/artem/miniconda3/lib/python3.7/site-packages/ipykernel_launcher.py:50: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=Tensor("in..., outputs=Tensor("co...)`
# Print the layer-by-layer architecture and parameter counts.
model.summary()
Model: "model_1" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_1 (InputLayer) (None, 3, 128, 128) 0 __________________________________________________________________________________________________ conv2d_1 (Conv2D) (None, 32, 128, 128) 896 input_1[0][0] __________________________________________________________________________________________________ dropout_1 (Dropout) (None, 32, 128, 128) 0 conv2d_1[0][0] __________________________________________________________________________________________________ conv2d_2 (Conv2D) (None, 32, 128, 128) 9248 dropout_1[0][0] __________________________________________________________________________________________________ max_pooling2d_1 (MaxPooling2D) (None, 32, 64, 64) 0 conv2d_2[0][0] __________________________________________________________________________________________________ conv2d_3 (Conv2D) (None, 64, 64, 64) 18496 max_pooling2d_1[0][0] __________________________________________________________________________________________________ dropout_2 (Dropout) (None, 64, 64, 64) 0 conv2d_3[0][0] __________________________________________________________________________________________________ conv2d_4 (Conv2D) (None, 64, 64, 64) 36928 dropout_2[0][0] __________________________________________________________________________________________________ max_pooling2d_2 (MaxPooling2D) (None, 64, 32, 32) 0 conv2d_4[0][0] __________________________________________________________________________________________________ conv2d_5 (Conv2D) (None, 64, 32, 32) 36928 max_pooling2d_2[0][0] __________________________________________________________________________________________________ dropout_3 (Dropout) (None, 64, 32, 32) 0 conv2d_5[0][0] 
__________________________________________________________________________________________________ conv2d_6 (Conv2D) (None, 64, 32, 32) 36928 dropout_3[0][0] __________________________________________________________________________________________________ max_pooling2d_3 (MaxPooling2D) (None, 64, 16, 16) 0 conv2d_6[0][0] __________________________________________________________________________________________________ conv2d_7 (Conv2D) (None, 64, 16, 16) 36928 max_pooling2d_3[0][0] __________________________________________________________________________________________________ dropout_4 (Dropout) (None, 64, 16, 16) 0 conv2d_7[0][0] __________________________________________________________________________________________________ conv2d_8 (Conv2D) (None, 64, 16, 16) 36928 dropout_4[0][0] __________________________________________________________________________________________________ max_pooling2d_4 (MaxPooling2D) (None, 64, 8, 8) 0 conv2d_8[0][0] __________________________________________________________________________________________________ conv2d_9 (Conv2D) (None, 64, 8, 8) 36928 max_pooling2d_4[0][0] __________________________________________________________________________________________________ dropout_5 (Dropout) (None, 64, 8, 8) 0 conv2d_9[0][0] __________________________________________________________________________________________________ conv2d_10 (Conv2D) (None, 64, 8, 8) 36928 dropout_5[0][0] __________________________________________________________________________________________________ up_sampling2d_1 (UpSampling2D) (None, 64, 16, 16) 0 conv2d_10[0][0] __________________________________________________________________________________________________ concatenate_1 (Concatenate) (None, 128, 16, 16) 0 up_sampling2d_1[0][0] conv2d_8[0][0] __________________________________________________________________________________________________ conv2d_11 (Conv2D) (None, 64, 16, 16) 73792 concatenate_1[0][0] 
__________________________________________________________________________________________________ dropout_6 (Dropout) (None, 64, 16, 16) 0 conv2d_11[0][0] __________________________________________________________________________________________________ conv2d_12 (Conv2D) (None, 64, 16, 16) 36928 dropout_6[0][0] __________________________________________________________________________________________________ up_sampling2d_2 (UpSampling2D) (None, 64, 32, 32) 0 conv2d_12[0][0] __________________________________________________________________________________________________ concatenate_2 (Concatenate) (None, 128, 32, 32) 0 up_sampling2d_2[0][0] conv2d_6[0][0] __________________________________________________________________________________________________ conv2d_13 (Conv2D) (None, 64, 32, 32) 73792 concatenate_2[0][0] __________________________________________________________________________________________________ dropout_7 (Dropout) (None, 64, 32, 32) 0 conv2d_13[0][0] __________________________________________________________________________________________________ conv2d_14 (Conv2D) (None, 64, 32, 32) 36928 dropout_7[0][0] __________________________________________________________________________________________________ up_sampling2d_3 (UpSampling2D) (None, 64, 64, 64) 0 conv2d_14[0][0] __________________________________________________________________________________________________ concatenate_3 (Concatenate) (None, 128, 64, 64) 0 up_sampling2d_3[0][0] conv2d_4[0][0] __________________________________________________________________________________________________ conv2d_15 (Conv2D) (None, 64, 64, 64) 73792 concatenate_3[0][0] __________________________________________________________________________________________________ dropout_8 (Dropout) (None, 64, 64, 64) 0 conv2d_15[0][0] __________________________________________________________________________________________________ conv2d_16 (Conv2D) (None, 64, 64, 64) 36928 dropout_8[0][0] 
__________________________________________________________________________________________________ up_sampling2d_4 (UpSampling2D) (None, 64, 128, 128) 0 conv2d_16[0][0] __________________________________________________________________________________________________ concatenate_4 (Concatenate) (None, 96, 128, 128) 0 up_sampling2d_4[0][0] conv2d_2[0][0] __________________________________________________________________________________________________ conv2d_17 (Conv2D) (None, 32, 128, 128) 27680 concatenate_4[0][0] __________________________________________________________________________________________________ dropout_9 (Dropout) (None, 32, 128, 128) 0 conv2d_17[0][0] __________________________________________________________________________________________________ conv2d_18 (Conv2D) (None, 32, 128, 128) 9248 dropout_9[0][0] __________________________________________________________________________________________________ conv2d_19 (Conv2D) (None, 1, 128, 128) 33 conv2d_18[0][0] ================================================================================================== Total params: 656,257 Trainable params: 656,257 Non-trainable params: 0 __________________________________________________________________________________________________
Accessing data, preparing train/validation sets division:
# Loading full images of desired resolution:
X = h5f['DRISHTI-GS/512 px/images']  # (50, 512, 512, 3) uint8 fundus images
Y = h5f['DRISHTI-GS/512 px/cup']     # (50, 512, 512, 1) uint8 averaged soft cup masks
# Per-image optic-disc bounding boxes; indexed below as [row0, col0, row1, col1].
disc_locations = h5f['DRISHTI-GS/512 px/disc_locations']
X, Y
(<HDF5 dataset "images": shape (50, 512, 512, 3), type "|u1">, <HDF5 dataset "cup": shape (50, 512, 512, 1), type "|u1">)
#train_idx = h5f['RIM-ONE v3/train_idx_driu']
#test_idx = h5f['RIM-ONE v3/test_idx_driu']
# Cross-validation fold 0: 40 training and 10 test image indices.
train_idx = h5f['DRISHTI-GS/train_idx_cv'][0]
test_idx = h5f['DRISHTI-GS/test_idx_cv'][0]
print(len(X), len(train_idx), len(test_idx))
print(train_idx, test_idx)
50 40 10 [ 0 1 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 28 30 32 33 34 36 37 39 41 42 43 44 45 47 49] [ 2 3 27 29 31 35 38 40 46 48]
# Training-time augmentation: flips, small rotations/shifts/zooms, applied
# jointly to image and mask by DualImageDataGenerator.
train_idg = DualImageDataGenerator(#rescale=1/255.0,
                                   #samplewise_center=True, samplewise_std_normalization=True,
                                   horizontal_flip=True, vertical_flip=True,
                                   rotation_range=20, width_shift_range=0.1, height_shift_range=0.1,
                                   zoom_range=(0.8, 1.2),
                                   fill_mode='constant', cval=0.0)
# No augmentation at test time.
test_idg = DualImageDataGenerator()
def preprocess(batch_X, batch_y, train_or_test='train'):
    """Scale, binarize, augment, and contrast-equalize one NCHW batch."""
    batch_X = batch_X / 255.0
    # DRISHTI-GS masks are averaged soft maps; threshold them to binary.
    batch_y = batch_y >= 0.5
    # Pick the matching augmenter; any other mode skips augmentation.
    idg = {'train': train_idg, 'test': test_idg}.get(train_or_test)
    if idg is not None:
        batch_X, batch_y = next(idg.flow(batch_X, batch_y,
                                         batch_size=len(batch_X), shuffle=False))
    # equalize_adapthist expects HWC images, so round-trip through TF layout.
    batch_X = th_to_tf_encoding(batch_X)
    batch_X = np.array([skimage.exposure.equalize_adapthist(img) for img in batch_X])
    batch_X = tf_to_th_encoding(batch_X)
    return batch_X, batch_y
def data_generator(X, y, resize_to=128, train_or_test='train', batch_size=3, return_orig=False, stationary=False):
    """Endless generator of (image, mask) batches cropped to the optic disc.

    Picks `batch_size` indices from the global train/test split, crops each
    image and mask to its disc bounding box, resizes to `resize_to`, converts
    to NCHW, and runs `preprocess` (scaling, thresholding, augmentation).
    With `stationary=True` (test mode only) the first `batch_size` test
    indices are used deterministically. With `return_orig=True` the
    pre-`preprocess` crops are also yielded.
    NOTE(review): if `train_or_test` is neither 'train' nor 'test', `idx`
    is never assigned and the first iteration raises NameError.
    """
    while True:
        if train_or_test == 'train':
            idx = np.random.choice(train_idx, size=batch_size)
        elif train_or_test == 'test':
            if stationary:
                idx = test_idx[:batch_size]
            else:
                idx = np.random.choice(test_idx, size=batch_size)
        # Crop each image to its optic-disc box: rows [0]:[2], cols [1]:[3].
        batch_X = [X[i][disc_locations[i][0]:disc_locations[i][2], disc_locations[i][1]:disc_locations[i][3]]
                   for i in idx]
        # Round-trip HWC -> CHW -> HWC before resize; presumably a legacy
        # layout dance — the net effect on layout is a no-op. TODO confirm.
        batch_X = [np.rollaxis(img, 2) for img in batch_X]
        batch_X = [skimage.transform.resize(np.rollaxis(img, 0, 3), (resize_to, resize_to))
                   for img in batch_X]
        batch_X = np.array(batch_X).copy()
        # Same crop for the masks; drop the singleton channel, resize, restore it.
        batch_y = [y[i][disc_locations[i][0]:disc_locations[i][2], disc_locations[i][1]:disc_locations[i][3]]
                   for i in idx]
        batch_y = [img[..., 0] for img in batch_y]
        batch_y = [skimage.transform.resize(img, (resize_to, resize_to))[..., None] for img in batch_y]
        batch_y = np.array(batch_y).copy()
        # Convert both to channels-first for the model.
        batch_X = tf_to_th_encoding(batch_X)
        batch_y = tf_to_th_encoding(batch_y)
        if return_orig:
            # Snapshot the crops before augmentation/equalization.
            batch_X_orig, batch_Y_orig = batch_X.copy(), batch_y.copy()
        batch_X, batch_y = preprocess(batch_X, batch_y, train_or_test)
        if not return_orig:
            yield batch_X, batch_y
        else:
            yield batch_X, batch_y, batch_X_orig, batch_Y_orig
Testing the data generator and generator for augmented data:
# Smoke-test the generator with a single training batch.
gen = data_generator(X, Y, 128, 'train', batch_size=1)
batch = next(gen)
batch[0].shape
(1, 3, 128, 128)
# Show the augmented image (converted back to HWC) and its binary cup mask.
fig = plt.imshow(np.rollaxis(batch[0][0], 0, 3))
#plt.colorbar(mappable=fig)
plt.show()
plt.imshow(batch[1][0][0], cmap=plt.cm.Greys_r); plt.colorbar(); plt.show()
# Checkpoints and logs go to ../models_weights/<dd.mm,HH:MM>,<arch_name>/.
arch_name = "OD Cup, U-Net light on DRISHTI-GS 512 px cropped to OD 128 px fold 0, SGD, log_dice loss"
weights_folder = os.path.join(os.path.dirname(os.getcwd()), 'models_weights',
                              '{},{}'.format(datetime.now().strftime('%d.%m,%H:%M'), arch_name))
print(weights_folder)
def folder(folder_name):
    """Ensure that `folder_name` exists (creating parents) and return it.

    FIX: the original exists-then-makedirs pair had a race — the directory
    could appear between the check and the creation, raising FileExistsError.
    `exist_ok=True` makes the call atomic and idempotent.
    """
    os.makedirs(folder_name, exist_ok=True)
    return folder_name
# Fixed validation batch: the first 100 test indices (only 10 exist, so all of them).
X_valid, Y_valid = next(data_generator(X, Y, train_or_test='test', batch_size=100, stationary=True))
plt.imshow(np.rollaxis(X_valid[0], 0, 3)); plt.show()
print(X_valid.shape, Y_valid.shape)
(10, 3, 128, 128) (10, 1, 128, 128)
If a pretrained model needs to be used, first run the "Loading model" section below and then go to the "Comprehensive visual check", skipping this section.
# Train for 500 epochs of 99 steps; only the best model by val_loss is kept
# as last_checkpoint.hdf5, and per-epoch metrics go to training_log.csv.
history = model.fit_generator(data_generator(X, Y, train_or_test='train', batch_size=1),
                              steps_per_epoch=99,
                              max_queue_size=1,
                              validation_data=(X_valid, Y_valid),
                              #validation_data=data_generator(X, Y, train_or_test='test', batch_size=1),
                              #nb_val_samples=100,
                              epochs=500, verbose=1,
                              callbacks=[CSVLogger(os.path.join(folder(weights_folder), 'training_log.csv')),
                                         #ReduceLROnPlateau(monitor='val_loss', mode='min', factor=0.5, verbose=1, patience=40),
                                         ModelCheckpoint(os.path.join(folder(weights_folder),
                                                                      #'weights.ep-{epoch:02d}-val_mean_IOU-{val_mean_IOU_gpu:.2f}_val_loss_{val_loss:.2f}.hdf5',
                                                                      'last_checkpoint.hdf5'),
                                                         monitor='val_loss', mode='min', save_best_only=True,
                                                         save_weights_only=False, verbose=0)])
# Per-image evaluation on the validation set: predict, binarize at 0.5,
# plot prediction / ground truth / input side by side, and collect IOU/Dice.
pred_iou, pred_dice = [], []
for i, img_no in enumerate(test_idx):
    print('image #{}'.format(img_no))
    img = X[img_no]
    # X_valid/Y_valid rows are ordered like test_idx, so position i matches img_no.
    batch_X = X_valid[i:i + 1]
    batch_y = Y_valid[i:i + 1]
    # Hard prediction mask from the sigmoid output.
    pred = (model.predict(batch_X)[0, 0] > 0.5).astype(np.float64)
    #corr = Y[img_no][..., 0]
    # Ground truth converted back to HWC and squeezed to 2-D.
    corr = th_to_tf_encoding(batch_y)[0, ..., 0]
    # mean filtering:
    #pred = mh.mean_filter(pred, Bc=mh.disk(10)) > 0.5
    fig = plt.figure(figsize=(9, 4))
    ax = fig.add_subplot(1, 3, 1)
    ax.imshow(pred, cmap=plt.cm.Greys_r)
    ax.set_title('Predicted')
    ax = fig.add_subplot(1, 3, 2)
    ax.imshow(corr, cmap=plt.cm.Greys_r)
    ax.set_title('Correct')
    ax = fig.add_subplot(1, 3, 3)
    #ax.imshow(img)
    ax.imshow(th_to_tf_encoding(batch_X)[0])
    ax.set_title('Image')
    plt.show()
    # Scores are computed on (1, 1, h, w) tensors via the Keras backend.
    cur_iou = K.eval(mean_IOU_gpu(pred[None, None, ...], corr[None, None, ...]))
    cur_dice = K.eval(dice(pred[None, None, ...], corr[None, None, ...]))
    print('IOU: {}\nDice: {}'.format(cur_iou, cur_dice))
    pred_iou.append(cur_iou)
    pred_dice.append(cur_dice)
image #2
IOU: 0.6064245700836182 Dice: 0.754999130585985 image #3
IOU: 0.7233606576919556 Dice: 0.8394768133174791 image #27
IOU: 0.9320363402366638 Dice: 0.9648227712137487 image #29
IOU: 0.8829098343849182 Dice: 0.9378142365705213 image #31
IOU: 0.8561776280403137 Dice: 0.9225169006760271 image #35
IOU: 0.828913152217865 Dice: 0.9064543645152118 image #38
IOU: 0.854667067527771 Dice: 0.9216393442622951 image #40
IOU: 0.8452631831169128 Dice: 0.9161437535653166 image #46
IOU: 0.49794802069664 Dice: 0.6648401826484018 image #48
IOU: 0.5322209596633911 Dice: 0.694705219677056
Acquiring scores for the validation set:
# Average IOU and Dice over the validation images.
print(np.mean(pred_iou))
print(np.mean(pred_dice))
0.7559922 0.8523412717032043
# Loading a previously trained model and inspecting its training log.
load_model = True  # lock
if not load_model:
    print('load_model == False')
else:
    # specify file:
    #model_path = '../models_weights/01.11,22:38,U-Net on DRIONS-DB 256 px, Adam, augm, log_dice loss/' \
    #             'weights.ep-20-val_mean_IOU-0.81_val_loss_0.08.hdf5'
    # or get the most recent file in a folder:
    model_folder = os.path.join(os.path.dirname(os.getcwd()), 'models_weights', '02.03,13_57,OD Cup, U-Net light on DRISHTI-GS 512 px cropped to OD 128 px fold 0, SGD, log_dice loss')
    # Most recently created checkpoint in the folder.
    model_path = max(glob.glob(os.path.join(model_folder, '*.hdf5')), key=os.path.getctime)
# NOTE(review): the two statements below reference model_path even when
# load_model is False, in which case it is undefined — verify before reuse.
if load_model and not os.path.exists(model_path):
    raise Exception('`model_path` does not exist')
print('Loading weights from', model_path)
if load_model:
    #with open(model_path + ' arch.json') as arch_file:
    #    json_string = arch_file.read()
    #new_model = model_from_json(json_string)
    # Weights only; the architecture is the compiled model built above.
    model.load_weights(model_path)
# Reading log statistics
import pandas as pd
log_path = os.path.join(model_folder, 'training_log.csv')
if os.path.exists(log_path):
    log = pd.read_csv(log_path)
    # Drop repeated header rows left by resumed CSVLogger runs.
    if log['epoch'].dtype != 'int64':
        log = log.loc[log.epoch != 'epoch']
    print('\nmax val mean IOU: {}, at row:'.format(log['val_mean_IOU_gpu'].max()))
    print(log.loc[log['val_mean_IOU_gpu'].idxmax()])
    if 'val_dice_metric' in log.columns:
        print('\n' + 'max val dice_metric: {}, at row:'.format(log['val_dice_metric'].max()))
        print(log.loc[log['val_dice_metric'].idxmax()])
Loading weights from /home/artem/Загрузки/optic-nerve-cnn/models_weights/02.03,13_57,OD Cup, U-Net light on DRISHTI-GS 512 px cropped to OD 128 px fold 0, SGD, log_dice loss/last_checkpoint.hdf5 max val mean IOU: 0.745275825262, at row: epoch 495.000000 dice_metric 0.810811 loss 0.163871 mean_IOU_gpu 0.695188 val_dice_metric 0.846246 val_loss 0.138325 val_mean_IOU_gpu 0.745276 Name: 495, dtype: float64 max val dice_metric: 0.8462459504600001, at row: epoch 495.000000 dice_metric 0.810811 loss 0.163871 mean_IOU_gpu 0.695188 val_dice_metric 0.846246 val_loss 0.138325 val_mean_IOU_gpu 0.745276 Name: 495, dtype: float64