!wget https://data.mendeley.com/datasets/rscbjbr9sj/2/files/9e8f7acf-7d3a-487f-8eb5-0bd3255b9685/OCT2017.tar.gz
--2019-07-18 19:01:42--  https://data.mendeley.com/datasets/rscbjbr9sj/2/files/9e8f7acf-7d3a-487f-8eb5-0bd3255b9685/OCT2017.tar.gz
Resolving data.mendeley.com (data.mendeley.com)... 63.32.214.65, 54.229.42.128
Connecting to data.mendeley.com (data.mendeley.com)|63.32.214.65|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://com-mendeley-internal.s3.amazonaws.com/platform/rdm/production/a2609dd9-bc04-4417-9e77-2406e74a4f44?response-content-disposition=inline%3B%20filename%3D%22OCT2017.tar.gz%22%3B%20filename%2A%3DUTF-8%27%27OCT2017.tar.gz&response-content-type=application%2Fgzip&AWSAccessKeyId=AKIAI6DZXOGICLKVGYEA&Expires=1563480102&Signature=BLf3WoP99QCj2vAsYojUdaFrxOs%3D [following]
--2019-07-18 19:01:42--  https://com-mendeley-internal.s3.amazonaws.com/platform/rdm/production/a2609dd9-bc04-4417-9e77-2406e74a4f44?response-content-disposition=inline%3B%20filename%3D%22OCT2017.tar.gz%22%3B%20filename%2A%3DUTF-8%27%27OCT2017.tar.gz&response-content-type=application%2Fgzip&AWSAccessKeyId=AKIAI6DZXOGICLKVGYEA&Expires=1563480102&Signature=BLf3WoP99QCj2vAsYojUdaFrxOs%3D
Resolving com-mendeley-internal.s3.amazonaws.com (com-mendeley-internal.s3.amazonaws.com)... 52.218.57.138
Connecting to com-mendeley-internal.s3.amazonaws.com (com-mendeley-internal.s3.amazonaws.com)|52.218.57.138|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5793183169 (5.4G) [application/gzip]
Saving to: ‘OCT2017.tar.gz’

OCT2017.tar.gz      100%[===================>]   5.39G  53.4MB/s    in 1m 50s

2019-07-18 19:03:33 (50.0 MB/s) - ‘OCT2017.tar.gz’ saved [5793183169/5793183169]
!tar -zxf OCT2017.tar.gz
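Before wiring up the data generators, it is worth confirming the extracted folder layout; a minimal check, assuming the archive unpacks into OCT2017/train and OCT2017/test with one sub-folder per class (CNV, DME, DRUSEN, NORMAL), which is what the generators below expect:

import pathlib

for split in ("train", "test"):
    split_dir = pathlib.Path("OCT2017") / split
    # Count the files under each class folder.
    counts = {d.name: len(list(d.glob("*"))) for d in sorted(split_dir.iterdir()) if d.is_dir()}
    print(split, counts)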
!pip install -q tensorflow-gpu==2.0.0-beta1
|████████████████████████████████| 348.9MB 99kB/s
|████████████████████████████████| 3.1MB 2.0MB/s
|████████████████████████████████| 501kB 57.0MB/s
import tensorflow as tf
tf.__version__
'2.0.0-beta1'
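Since the GPU build is installed, it also helps to confirm that the runtime actually exposes a GPU before starting a multi-hour training run; a quick check using the TF 2.0-era device API:

# Should print a non-empty list such as [PhysicalDevice(name='/physical_device:GPU:0', ...)]
print(tf.config.experimental.list_physical_devices('GPU'))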
from tensorflow.keras.optimizers import SGD
from tensorflow.keras import layers
from tensorflow.keras.utils import Sequence
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.preprocessing import image
import glob
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import os
import PIL
import pathlib
import pandas as pd
import cv2
from imgaug import augmenters as iaa
import imgaug as ia
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix
import itertools
batch_size = 32
img_size = 224
input_img_size = (224, 224, 3)
train_generator = image.ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=20,
    zoom_range=0.05,
    horizontal_flip=True,
    fill_mode="nearest").flow_from_directory(
        "OCT2017/train",
        target_size=(img_size, img_size),
        color_mode="rgb",
        shuffle=True,
        batch_size=batch_size)
val_generator = image.ImageDataGenerator(
    preprocessing_function=preprocess_input).flow_from_directory(
        "OCT2017/test",
        target_size=(img_size, img_size),
        color_mode="rgb",
        shuffle=False,
        batch_size=batch_size)
Found 83484 images belonging to 4 classes.
Found 1000 images belonging to 4 classes.
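As a sanity check on the augmentation and preprocessing, one batch can be pulled from the training generator and displayed; a small sketch, undoing MobileNetV2's [-1, 1] pixel scaling for display:

batch_images, batch_labels = next(train_generator)
class_names = sorted(train_generator.class_indices, key=train_generator.class_indices.get)

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for ax, img, label in zip(axes, batch_images, batch_labels):
    ax.imshow((img + 1) / 2)  # preprocess_input maps pixels to [-1, 1]; shift back for display
    ax.set_title(class_names[np.argmax(label)])
    ax.axis("off")
plt.show()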
def create_model():
    # MobileNetV2 backbone trained from scratch (weights=None), width multiplier alpha=1,
    # without the ImageNet classification head.
    mobile_model = MobileNetV2(
        weights=None,
        input_shape=input_img_size,
        alpha=1,
        include_top=False)
    for layer in mobile_model.layers:
        layer.trainable = True
    # Take the last 1x1 convolution (Conv_1, 1280 channels) and attach a
    # global-average-pooling + 4-way softmax head.
    model = mobile_model.layers[-3].output
    model = layers.GlobalAveragePooling2D()(model)
    model = layers.Dense(num_classes, activation="softmax", kernel_initializer='uniform')(model)
    model = Model(inputs=mobile_model.input, outputs=model)
    return model
def polynomial_decay(epoch):
    # Linear decay (power = 1.0) from the base learning rate down to 0 over `epochs`.
    power = 1.0
    alpha = learning_rate * (1 - (epoch / float(epochs))) ** power
    return alpha
learning_rate = 3e-4
epochs = 10
num_classes = 4
weights_name = "epoch={epoch:02d}|accuracy={acc:.4f}|val_accuracy={val_acc:.4f}.h5"
checkpoint = ModelCheckpoint(weights_name, monitor="val_acc", verbose=1, save_best_only=True,
                             save_weights_only=True, mode="max", period=1)
lr_decay = LearningRateScheduler(polynomial_decay)
optimizer = SGD(lr=learning_rate, momentum=0.9)
WARNING: Logging before flag parsing goes to stderr.
W0718 19:05:31.943247 139825010534272 callbacks.py:859] `period` argument is deprecated. Please use `save_freq` to specify the frequency in number of samples seen.
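Before training starts, printing the schedule confirms that polynomial_decay with power 1.0 ramps the learning rate linearly from 3e-4 toward 0 over the 10 planned epochs:

for e in range(epochs):
    print(f"epoch {e:2d}: lr = {polynomial_decay(e):.6f}")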
model = None
model = create_model()
model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['acc'])
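A quick look at the last few layers confirms the custom head: the Conv_1 feature map followed by global average pooling and the 4-way softmax (layer names come from the Keras MobileNetV2 implementation):

for layer in model.layers[-3:]:
    print(layer.name, layer.output_shape)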
train_steps = int(83484 / batch_size)
val_steps = int(1000 / batch_size)
train_steps, val_steps
(2608, 31)
trained_model = model.fit_generator(train_generator,
                                    epochs=epochs,
                                    steps_per_epoch=train_steps,
                                    callbacks=[checkpoint, lr_decay],
                                    validation_data=val_generator,
                                    validation_steps=val_steps,
                                    verbose=1)
Epoch 1/10
W0624 01:20:55.483824 139892176177024 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version. Instructions for updating: Use tf.where in 2.0, which has the same broadcast rule as np.where
2607/2608 [============================>.] - ETA: 0s - loss: 0.7372 - acc: 0.7142
Epoch 00001: val_acc improved from -inf to 0.24395, saving model to epoch=01|accuracy=0.7143|val_accuracy=0.2440.h5
2608/2608 [==============================] - 1628s 624ms/step - loss: 0.7371 - acc: 0.7143 - val_loss: 1.5402 - val_acc: 0.2440
Epoch 2/10
2607/2608 [============================>.] - ETA: 0s - loss: 0.3870 - acc: 0.8628
Epoch 00002: val_acc did not improve from 0.24395
2608/2608 [==============================] - 1567s 601ms/step - loss: 0.3870 - acc: 0.8628 - val_loss: 1.7461 - val_acc: 0.2440
Epoch 3/10
2607/2608 [============================>.] - ETA: 0s - loss: 0.2993 - acc: 0.8955
Epoch 00003: val_acc improved from 0.24395 to 0.86694, saving model to epoch=03|accuracy=0.8954|val_accuracy=0.8669.h5
2608/2608 [==============================] - 1543s 592ms/step - loss: 0.2993 - acc: 0.8954 - val_loss: 0.3796 - val_acc: 0.8669
Epoch 4/10
2607/2608 [============================>.] - ETA: 0s - loss: 0.2547 - acc: 0.9113
Epoch 00004: val_acc improved from 0.86694 to 0.91331, saving model to epoch=04|accuracy=0.9114|val_accuracy=0.9133.h5
2608/2608 [==============================] - 1536s 589ms/step - loss: 0.2547 - acc: 0.9114 - val_loss: 0.2160 - val_acc: 0.9133
Epoch 5/10
2607/2608 [============================>.] - ETA: 0s - loss: 0.2307 - acc: 0.9213
Epoch 00005: val_acc improved from 0.91331 to 0.94657, saving model to epoch=05|accuracy=0.9213|val_accuracy=0.9466.h5
2608/2608 [==============================] - 1537s 590ms/step - loss: 0.2308 - acc: 0.9213 - val_loss: 0.1943 - val_acc: 0.9466
Epoch 6/10
2607/2608 [============================>.] - ETA: 0s - loss: 0.2129 - acc: 0.9264
Epoch 00006: val_acc improved from 0.94657 to 0.95665, saving model to epoch=06|accuracy=0.9265|val_accuracy=0.9567.h5
2608/2608 [==============================] - 1531s 587ms/step - loss: 0.2128 - acc: 0.9265 - val_loss: 0.1287 - val_acc: 0.9567
Epoch 7/10
2607/2608 [============================>.] - ETA: 0s - loss: 0.1998 - acc: 0.9314
Epoch 00007: val_acc improved from 0.95665 to 0.97379, saving model to epoch=07|accuracy=0.9314|val_accuracy=0.9738.h5
2608/2608 [==============================] - 1535s 589ms/step - loss: 0.1998 - acc: 0.9314 - val_loss: 0.0855 - val_acc: 0.9738
Epoch 8/10
2607/2608 [============================>.] - ETA: 0s - loss: 0.1892 - acc: 0.9351
Epoch 00008: val_acc improved from 0.97379 to 0.97681, saving model to epoch=08|accuracy=0.9350|val_accuracy=0.9768.h5
2608/2608 [==============================] - 1536s 589ms/step - loss: 0.1893 - acc: 0.9350 - val_loss: 0.0944 - val_acc: 0.9768
Epoch 9/10
2607/2608 [============================>.] - ETA: 0s - loss: 0.1813 - acc: 0.9367
Epoch 00009: val_acc improved from 0.97681 to 0.98488, saving model to epoch=09|accuracy=0.9367|val_accuracy=0.9849.h5
2608/2608 [==============================] - 1536s 589ms/step - loss: 0.1813 - acc: 0.9367 - val_loss: 0.0778 - val_acc: 0.9849
Epoch 10/10
2607/2608 [============================>.] - ETA: 0s - loss: 0.1743 - acc: 0.9400
Epoch 00010: val_acc did not improve from 0.98488
2608/2608 [==============================] - 1540s 591ms/step - loss: 0.1744 - acc: 0.9400 - val_loss: 0.0723 - val_acc: 0.9829
def plot_validation_training(metric, trained_model):
    validation_metric = trained_model.history[f'val_{metric}']
    training_metric = trained_model.history[metric]
    epochs = range(len(training_metric))
    plt.plot(epochs, training_metric, 'b', label=f'Training {metric}')
    plt.plot(epochs, validation_metric, 'r', label=f'Validation {metric}')
    plt.ylim(bottom=0)
    plt.xlabel('Epochs', fontsize=16)
    plt.ylabel(metric, fontsize=16)
    loc = 'upper right' if metric == "loss" else 'lower right'
    plt.legend(loc=loc)
    plt.title(f'Training and validation {metric}', fontsize=20)
    plt.show()
plot_validation_training("loss", trained_model)
plot_validation_training("acc", trained_model)
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
val_generator.reset()
val_generator.class_indices
{'CNV': 0, 'DME': 1, 'DRUSEN': 2, 'NORMAL': 3}
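The generator assigns label indices alphabetically, so the class order used for the confusion-matrix axes can be recovered directly from class_indices (the same order as the hard-coded list used below):

classes_names = sorted(val_generator.class_indices, key=val_generator.class_indices.get)
print(classes_names)  # ['CNV', 'DME', 'DRUSEN', 'NORMAL']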
def test_model(generator):
    # Predicted class probabilities for every image yielded by the generator.
    predicted_probs = model.predict_generator(generator, verbose=1)
    max_class_predicted = np.argmax(predicted_probs, axis=1)
    classes_names = ["CNV", "DME", "DRUSEN", "NORMAL"]
    labels = val_generator.classes
    cm = confusion_matrix(labels, max_class_predicted, labels=range(num_classes))
    plot_confusion_matrix(cm, classes_names)
    # Sensitivity/specificity computed from the CNV (class 0) vs DME (class 1)
    # cells of the confusion matrix only.
    sensitivity = cm[0, 0] / (cm[0, 0] + cm[0, 1])
    print(f'sensitivity: {sensitivity}')
    specificity = cm[1, 1] / (cm[1, 1] + cm[1, 0])
    print(f'specificity: {specificity}')
    incorrect_labels = np.where(max_class_predicted != labels)[0]
    return max_class_predicted, incorrect_labels
val_class_predicted, val_incorrect_labels = test_model(val_generator)
32/32 [==============================] - 6s 199ms/step
sensitivity: 0.996
specificity: 0.967741935483871
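The sensitivity and specificity above only look at the CNV/DME corner of the 4×4 matrix; for per-class precision and recall across all four classes, sklearn's classification_report is a convenient cross-check (a sketch that re-runs prediction on the validation generator):

from sklearn.metrics import classification_report

val_generator.reset()
val_probs = model.predict_generator(val_generator, verbose=0)
print(classification_report(val_generator.classes,
                            np.argmax(val_probs, axis=1),
                            target_names=["CNV", "DME", "DRUSEN", "NORMAL"],
                            digits=3))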
def get_activation_map(image_path, image_class_vector):
    # Load and resize the scan; grayscale images are stacked to 3 channels.
    image_loaded = PIL.Image.open(image_path)
    image_loaded = image_loaded.resize((img_size, img_size))
    image_loaded = np.asarray(image_loaded)
    if len(image_loaded.shape) < 3:
        image_loaded = np.stack([image_loaded.copy()] * 3, axis=2)
    preprocessed_image = preprocess_input(image_loaded)
    preprocessed_image = np.expand_dims(preprocessed_image, axis=0)
    image_class = np.argmax(image_class_vector)
    # Class activation map: weight the final conv feature maps by the softmax
    # weights of the target class.
    class_weights = model.layers[-1].get_weights()[0]
    final_conv_layer = model.layers[-3]
    get_output = tf.keras.backend.function([model.layers[0].input],
                                           [final_conv_layer.output, model.layers[-1].output])
    [conv_outputs, predictions] = get_output(preprocessed_image)
    conv_outputs = conv_outputs[0, :, :, :]
    cam = np.zeros(dtype=np.float32, shape=conv_outputs.shape[0:2])
    for index, weight in enumerate(class_weights[:, image_class]):
        cam += weight * conv_outputs[:, :, index]
    class_predicted = np.argmax(predictions[0])
    predictions = f'Class predicted: {class_predicted} | Real class: {image_class}'
    # Upscale the map to the input size and overlay it as a JET heatmap.
    cam /= np.max(cam)
    cam = cv2.resize(cam, (img_size, img_size))
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap[np.where(cam < 0.2)] = 0
    img = heatmap * 0.5 + image_loaded
    cv2.imwrite("heatmap.jpg", img)
    heatmap = mpimg.imread("heatmap.jpg")
    scaled_image = (((img - img.min()) * 255) / (img.max() - img.min())).astype(np.uint8)
    fig, ax = plt.subplots(1, 2, figsize=(10, 10))
    ax[0].imshow(image_loaded)
    ax[0].set_title('Original image')
    ax[1].imshow(heatmap)
    ax[1].set_title(predictions)
model.load_weights("epoch=09|accuracy=0.9367|val_accuracy=0.9849.h5")
os.listdir("OCT2017/test")
['NORMAL', 'DRUSEN', 'DME', 'CNV']
val_generator.class_indices
{'CNV': 0, 'DME': 1, 'DRUSEN': 2, 'NORMAL': 3}
normal_path = "OCT2017/test/NORMAL"
drusen_path = "OCT2017/test/DRUSEN"
dme_path = "OCT2017/test/DME"
cnv_path = "OCT2017/test/CNV"
get_activation_map(f'{normal_path}/{os.listdir(normal_path)[0]}', [0, 0, 0, 1])
get_activation_map(f'{drusen_path}/{os.listdir(drusen_path)[0]}', [0, 0, 1, 0])
get_activation_map(f'{dme_path}/{os.listdir(dme_path)[1]}', [0, 1, 0, 0])
get_activation_map(f'{cnv_path}/{os.listdir(cnv_path)[0]}', [1, 0, 0, 0])
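Finally, the trained network can be persisted so it can be reloaded for inference without rebuilding create_model(); a small sketch, with an arbitrary file name:

model.save("oct2017_mobilenetv2.h5")
# Later: restored = tf.keras.models.load_model("oct2017_mobilenetv2.h5")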