Training a deep neural network for ground segmentation.
# Fetch the CMU corridor dataset into ./dataset
from utils.get_dataset import get_cmu_corridor_dataset
get_cmu_corridor_dataset(dataset_path='./dataset')
Compressing dataset: 100%|██████████| 2906/2906 [00:01<00:00, 1495.70it/s]
from utils.load_dataset import load_cmu_corridor_dataset
X_train, X_test, y_train, y_test = load_cmu_corridor_dataset(dataset_path='./dataset/cmu_corridor_dataset',
                                                             train_test_split=True,
                                                             image_size=(240, 320),
                                                             gray=True,
                                                             verbose=True)
print('X_train: %s, y_train: %s' % (X_train.shape, y_train.shape))
print('X_test: %s, y_test: %s' % (X_test.shape, y_test.shape))
Loading images: 100%|██████████| 967/967 [00:01<00:00, 551.44it/s]
Loading labels: 100%|██████████| 967/967 [00:00<00:00, 1785.36it/s]
X_train: (725, 240, 320, 1), y_train: (725, 240, 320, 1)
X_test: (242, 240, 320, 1), y_test: (242, 240, 320, 1)
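The loader returns grayscale images and single-channel ground masks of matching shape. A quick sanity check on the label values (a sketch; it assumes the masks are stored as 0/1 floats):

import numpy as np

# Masks should be binary: 0 = background, 1 = ground (assumption)
print(np.unique(y_train))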
from matplotlib import pyplot as plt

plt.rcParams['figure.figsize'] = [5, 8]

def plot_image(ax, img, title):
    ax.axis('off')
    ax.set_title(title)
    if img.shape[-1] == 1:
        img = img[..., 0]
        ax.imshow(img, cmap='gray')
    else:
        ax.imshow(img)

def plot_label(ax, label, title):
    ax.axis('off')
    ax.set_title(title)
    label = label[..., 0]
    ax.imshow(label, cmap='gray')
NUM_PLOTTING = 4
fig = plt.figure("Insight dataset")
images, labels = X_train, y_train
for index in range(NUM_PLOTTING):
    image, label = images[index], labels[index]
    image_title = '' if index != 0 else 'Raw image'
    label_title = '' if index != 0 else 'Ground segmentation'
    ax = fig.add_subplot(NUM_PLOTTING, 2, index*2 + 1)
    plot_image(ax, image, title=image_title)
    ax = fig.add_subplot(NUM_PLOTTING, 2, index*2 + 2)
    plot_label(ax, label, title=label_title)
plt.show()
# Infer the model input shape from the training data
img_height, img_width, img_channel = X_train[0].shape
from tensorflow.keras import backend as K
import tensorflow as tf
from tensorflow.keras import metrics
import numpy as np
# Binary focal loss (compatible with the tensorflow backend): down-weights
# easy, well-classified pixels so training focuses on the hard ones
def focal_loss(gamma=2., alpha=.75):
    def focal_loss_fixed(y_true, y_pred):
        # pt_1: predictions at ground pixels, pt_0: predictions at background pixels
        pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
        pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
        return -K.mean(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1)) \
               - K.mean((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0))
    return focal_loss_fixed
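For reference, this implements the binary focal loss of Lin et al. (2017),

    FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)

where p_t is the predicted probability of a pixel's true class: gamma down-weights easy, well-classified pixels, while alpha balances the ground and background classes.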
def mean_IoU(target, pred):
    # Mean IoU over the two classes (background, ground) after
    # thresholding the sigmoid output at 0.5
    m = metrics.MeanIoU(num_classes=2)
    return m(target, pred > 0.5)
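As a cross-check, the foreground IoU can also be computed directly with NumPy. A minimal sketch (the binary_iou helper is ours, not part of the project); note it reports the ground-class IoU only, whereas metrics.MeanIoU averages over both classes:

import numpy as np

def binary_iou(y_true, y_pred, threshold=0.5):
    # Intersection-over-union of the foreground masks, assuming labels in {0, 1}
    true = y_true > threshold
    pred = y_pred > threshold
    union = np.logical_or(true, pred).sum()
    if union == 0:
        return 1.0  # both masks empty: perfect agreement by convention
    return np.logical_and(true, pred).sum() / union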
from tensorflow.keras import Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import metrics
from utils.unet import get_unet
input_img = Input((img_height, img_width, img_channel), name='img')
model = get_unet(input_img, n_filters=16, dropout=0.05, batchnorm=True)
model.compile(optimizer=Adam(), loss="binary_crossentropy",
              metrics=["accuracy", mean_IoU])
# Alternatively, train with the focal loss defined above:
# model.compile(optimizer=Adam(), loss=focal_loss(gamma=2, alpha=.75),
#               metrics=["accuracy", mean_IoU])
model.summary()
Model: "unet" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== img (InputLayer) [(None, 240, 320, 1) 0 __________________________________________________________________________________________________ conv2d_1 (Conv2D) (None, 240, 320, 16) 160 img[0][0] __________________________________________________________________________________________________ batch_normalization_1 (BatchNor (None, 240, 320, 16) 64 conv2d_1[0][0] __________________________________________________________________________________________________ activation_1 (Activation) (None, 240, 320, 16) 0 batch_normalization_1[0][0] __________________________________________________________________________________________________ max_pooling2d (MaxPooling2D) (None, 120, 160, 16) 0 activation_1[0][0] __________________________________________________________________________________________________ dropout (Dropout) (None, 120, 160, 16) 0 max_pooling2d[0][0] __________________________________________________________________________________________________ conv2d_3 (Conv2D) (None, 120, 160, 32) 4640 dropout[0][0] __________________________________________________________________________________________________ batch_normalization_3 (BatchNor (None, 120, 160, 32) 128 conv2d_3[0][0] __________________________________________________________________________________________________ activation_3 (Activation) (None, 120, 160, 32) 0 batch_normalization_3[0][0] __________________________________________________________________________________________________ max_pooling2d_1 (MaxPooling2D) (None, 60, 80, 32) 0 activation_3[0][0] __________________________________________________________________________________________________ dropout_1 (Dropout) (None, 60, 80, 32) 0 max_pooling2d_1[0][0] __________________________________________________________________________________________________ conv2d_5 (Conv2D) (None, 60, 80, 64) 18496 dropout_1[0][0] __________________________________________________________________________________________________ batch_normalization_5 (BatchNor (None, 60, 80, 64) 256 conv2d_5[0][0] __________________________________________________________________________________________________ activation_5 (Activation) (None, 60, 80, 64) 0 batch_normalization_5[0][0] __________________________________________________________________________________________________ max_pooling2d_2 (MaxPooling2D) (None, 30, 40, 64) 0 activation_5[0][0] __________________________________________________________________________________________________ dropout_2 (Dropout) (None, 30, 40, 64) 0 max_pooling2d_2[0][0] __________________________________________________________________________________________________ conv2d_7 (Conv2D) (None, 30, 40, 128) 73856 dropout_2[0][0] __________________________________________________________________________________________________ batch_normalization_7 (BatchNor (None, 30, 40, 128) 512 conv2d_7[0][0] __________________________________________________________________________________________________ activation_7 (Activation) (None, 30, 40, 128) 0 batch_normalization_7[0][0] __________________________________________________________________________________________________ max_pooling2d_3 (MaxPooling2D) (None, 15, 20, 128) 0 activation_7[0][0] 
__________________________________________________________________________________________________ dropout_3 (Dropout) (None, 15, 20, 128) 0 max_pooling2d_3[0][0] __________________________________________________________________________________________________ conv2d_9 (Conv2D) (None, 15, 20, 256) 295168 dropout_3[0][0] __________________________________________________________________________________________________ batch_normalization_9 (BatchNor (None, 15, 20, 256) 1024 conv2d_9[0][0] __________________________________________________________________________________________________ activation_9 (Activation) (None, 15, 20, 256) 0 batch_normalization_9[0][0] __________________________________________________________________________________________________ conv2d_transpose (Conv2DTranspo (None, 30, 40, 128) 295040 activation_9[0][0] __________________________________________________________________________________________________ concatenate (Concatenate) (None, 30, 40, 256) 0 conv2d_transpose[0][0] activation_7[0][0] __________________________________________________________________________________________________ dropout_4 (Dropout) (None, 30, 40, 256) 0 concatenate[0][0] __________________________________________________________________________________________________ conv2d_11 (Conv2D) (None, 30, 40, 128) 295040 dropout_4[0][0] __________________________________________________________________________________________________ batch_normalization_11 (BatchNo (None, 30, 40, 128) 512 conv2d_11[0][0] __________________________________________________________________________________________________ activation_11 (Activation) (None, 30, 40, 128) 0 batch_normalization_11[0][0] __________________________________________________________________________________________________ conv2d_transpose_1 (Conv2DTrans (None, 60, 80, 64) 73792 activation_11[0][0] __________________________________________________________________________________________________ concatenate_1 (Concatenate) (None, 60, 80, 128) 0 conv2d_transpose_1[0][0] activation_5[0][0] __________________________________________________________________________________________________ dropout_5 (Dropout) (None, 60, 80, 128) 0 concatenate_1[0][0] __________________________________________________________________________________________________ conv2d_13 (Conv2D) (None, 60, 80, 64) 73792 dropout_5[0][0] __________________________________________________________________________________________________ batch_normalization_13 (BatchNo (None, 60, 80, 64) 256 conv2d_13[0][0] __________________________________________________________________________________________________ activation_13 (Activation) (None, 60, 80, 64) 0 batch_normalization_13[0][0] __________________________________________________________________________________________________ conv2d_transpose_2 (Conv2DTrans (None, 120, 160, 32) 18464 activation_13[0][0] __________________________________________________________________________________________________ concatenate_2 (Concatenate) (None, 120, 160, 64) 0 conv2d_transpose_2[0][0] activation_3[0][0] __________________________________________________________________________________________________ dropout_6 (Dropout) (None, 120, 160, 64) 0 concatenate_2[0][0] __________________________________________________________________________________________________ conv2d_15 (Conv2D) (None, 120, 160, 32) 18464 dropout_6[0][0] __________________________________________________________________________________________________ batch_normalization_15 
(BatchNo (None, 120, 160, 32) 128 conv2d_15[0][0] __________________________________________________________________________________________________ activation_15 (Activation) (None, 120, 160, 32) 0 batch_normalization_15[0][0] __________________________________________________________________________________________________ conv2d_transpose_3 (Conv2DTrans (None, 240, 320, 16) 4624 activation_15[0][0] __________________________________________________________________________________________________ concatenate_3 (Concatenate) (None, 240, 320, 32) 0 conv2d_transpose_3[0][0] activation_1[0][0] __________________________________________________________________________________________________ dropout_7 (Dropout) (None, 240, 320, 32) 0 concatenate_3[0][0] __________________________________________________________________________________________________ conv2d_17 (Conv2D) (None, 240, 320, 16) 4624 dropout_7[0][0] __________________________________________________________________________________________________ batch_normalization_17 (BatchNo (None, 240, 320, 16) 64 conv2d_17[0][0] __________________________________________________________________________________________________ activation_17 (Activation) (None, 240, 320, 16) 0 batch_normalization_17[0][0] __________________________________________________________________________________________________ conv2d_18 (Conv2D) (None, 240, 320, 1) 17 activation_17[0][0] ================================================================================================== Total params: 1,179,121 Trainable params: 1,177,649 Non-trainable params: 1,472 __________________________________________________________________________________________________
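The summary shows the pattern utils.unet repeats at every scale: Conv → BatchNorm → ReLU, with max-pooling and dropout on the way down and transposed convolutions plus skip concatenations on the way up. A rough sketch of one such block (our illustration, not the exact utils.unet code):

from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation

def conv2d_block(x, n_filters, kernel_size=3, batchnorm=True):
    # One encoder/decoder stage: Conv -> (BatchNorm) -> ReLU
    x = Conv2D(n_filters, (kernel_size, kernel_size), padding='same')(x)
    if batchnorm:
        x = BatchNormalization()(x)
    return Activation('relu')(x)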
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard

callbacks = [
    EarlyStopping(patience=10, verbose=1),                                 # stop after 10 epochs without improvement
    ReduceLROnPlateau(factor=0.1, patience=5, min_lr=0.00001, verbose=1),  # shrink the LR when val_loss plateaus
    ModelCheckpoint('models/model-gs-unet.h5', verbose=1,
                    save_best_only=True, save_weights_only=False),         # keep only the best model
    TensorBoard(log_dir='.logs')
]
results = model.fit(X_train, y_train, batch_size=16, epochs=50, callbacks=callbacks,
                    validation_data=(X_test, y_test))
Train on 725 samples, validate on 242 samples
Epoch 1/50
Epoch 00001: val_loss improved from inf to 0.67455, saving model to models/model-gs-unet.h5
725/725 [==============================] - 14s 20ms/sample - loss: 0.5000 - accuracy: 0.7835 - mean_IoU: 0.4498 - val_loss: 0.6746 - val_accuracy: 0.8255 - val_mean_IoU: 0.5803
Epoch 2/50
Epoch 00002: val_loss improved from 0.67455 to 0.43366, saving model to models/model-gs-unet.h5
725/725 [==============================] - 11s 15ms/sample - loss: 0.3086 - accuracy: 0.9030 - mean_IoU: 0.6234 - val_loss: 0.4337 - val_accuracy: 0.8157 - val_mean_IoU: 0.6415
Epoch 3/50
Epoch 00003: val_loss did not improve from 0.43366
725/725 [==============================] - 10s 14ms/sample - loss: 0.2580 - accuracy: 0.9144 - mean_IoU: 0.6548 - val_loss: 0.8242 - val_accuracy: 0.6535 - val_mean_IoU: 0.6575
Epoch 4/50
Epoch 00004: val_loss did not improve from 0.43366
725/725 [==============================] - 10s 14ms/sample - loss: 0.2304 - accuracy: 0.9193 - mean_IoU: 0.6608 - val_loss: 0.7008 - val_accuracy: 0.6848 - val_mean_IoU: 0.6643
Epoch 5/50
Epoch 00005: val_loss did not improve from 0.43366
725/725 [==============================] - 10s 14ms/sample - loss: 0.1978 - accuracy: 0.9308 - mean_IoU: 0.6704 - val_loss: 0.5901 - val_accuracy: 0.7582 - val_mean_IoU: 0.6759
Epoch 6/50
Epoch 00006: val_loss did not improve from 0.43366
725/725 [==============================] - 10s 14ms/sample - loss: 0.1888 - accuracy: 0.9316 - mean_IoU: 0.6811 - val_loss: 0.5757 - val_accuracy: 0.7628 - val_mean_IoU: 0.6856
Epoch 7/50
Epoch 00007: val_loss improved from 0.43366 to 0.39979, saving model to models/model-gs-unet.h5
725/725 [==============================] - 11s 15ms/sample - loss: 0.1817 - accuracy: 0.9314 - mean_IoU: 0.6904 - val_loss: 0.3998 - val_accuracy: 0.8377 - val_mean_IoU: 0.6960
Epoch 8/50
Epoch 00008: val_loss improved from 0.39979 to 0.39095, saving model to models/model-gs-unet.h5
725/725 [==============================] - 11s 14ms/sample - loss: 0.1667 - accuracy: 0.9380 - mean_IoU: 0.7014 - val_loss: 0.3909 - val_accuracy: 0.8437 - val_mean_IoU: 0.7066
Epoch 9/50
Epoch 00009: val_loss improved from 0.39095 to 0.30219, saving model to models/model-gs-unet.h5
725/725 [==============================] - 11s 15ms/sample - loss: 0.1550 - accuracy: 0.9422 - mean_IoU: 0.7113 - val_loss: 0.3022 - val_accuracy: 0.8775 - val_mean_IoU: 0.7166
Epoch 10/50
Epoch 00010: val_loss did not improve from 0.30219
725/725 [==============================] - 10s 14ms/sample - loss: 0.1466 - accuracy: 0.9448 - mean_IoU: 0.7215 - val_loss: 0.3188 - val_accuracy: 0.8721 - val_mean_IoU: 0.7255
Epoch 11/50
Epoch 00011: val_loss did not improve from 0.30219
725/725 [==============================] - 10s 14ms/sample - loss: 0.1381 - accuracy: 0.9483 - mean_IoU: 0.7296 - val_loss: 0.5125 - val_accuracy: 0.8174 - val_mean_IoU: 0.7325
Epoch 12/50
Epoch 00012: val_loss improved from 0.30219 to 0.26944, saving model to models/model-gs-unet.h5
725/725 [==============================] - 11s 15ms/sample - loss: 0.1342 - accuracy: 0.9488 - mean_IoU: 0.7353 - val_loss: 0.2694 - val_accuracy: 0.9013 - val_mean_IoU: 0.7391
Epoch 13/50
Epoch 00013: val_loss improved from 0.26944 to 0.19169, saving model to models/model-gs-unet.h5
725/725 [==============================] - 10s 14ms/sample - loss: 0.1289 - accuracy: 0.9506 - mean_IoU: 0.7427 - val_loss: 0.1917 - val_accuracy: 0.9251 - val_mean_IoU: 0.7465
Epoch 14/50
Epoch 00014: val_loss improved from 0.19169 to 0.18700, saving model to models/model-gs-unet.h5
725/725 [==============================] - 10s 14ms/sample - loss: 0.1344 - accuracy: 0.9476 - mean_IoU: 0.7499 - val_loss: 0.1870 - val_accuracy: 0.9274 - val_mean_IoU: 0.7528
Epoch 15/50
Epoch 00015: val_loss did not improve from 0.18700
725/725 [==============================] - 10s 14ms/sample - loss: 0.1303 - accuracy: 0.9501 - mean_IoU: 0.7559 - val_loss: 0.2278 - val_accuracy: 0.9132 - val_mean_IoU: 0.7585
Epoch 16/50
Epoch 00016: val_loss did not improve from 0.18700
725/725 [==============================] - 10s 14ms/sample - loss: 0.1166 - accuracy: 0.9554 - mean_IoU: 0.7614 - val_loss: 0.1888 - val_accuracy: 0.9268 - val_mean_IoU: 0.7641
Epoch 17/50
Epoch 00017: val_loss did not improve from 0.18700
725/725 [==============================] - 10s 14ms/sample - loss: 0.1129 - accuracy: 0.9563 - mean_IoU: 0.7670 - val_loss: 0.2084 - val_accuracy: 0.9233 - val_mean_IoU: 0.7695
Epoch 18/50
Epoch 00018: val_loss did not improve from 0.18700
725/725 [==============================] - 10s 14ms/sample - loss: 0.1150 - accuracy: 0.9556 - mean_IoU: 0.7718 - val_loss: 0.3585 - val_accuracy: 0.8675 - val_mean_IoU: 0.7733
Epoch 19/50
Epoch 00019: ReduceLROnPlateau reducing learning rate to 0.00010000000474974513.
Epoch 00019: val_loss did not improve from 0.18700
725/725 [==============================] - 10s 14ms/sample - loss: 0.1094 - accuracy: 0.9577 - mean_IoU: 0.7750 - val_loss: 0.2544 - val_accuracy: 0.9046 - val_mean_IoU: 0.7769
Epoch 20/50
Epoch 00020: val_loss improved from 0.18700 to 0.16215, saving model to models/model-gs-unet.h5
725/725 [==============================] - 10s 14ms/sample - loss: 0.0963 - accuracy: 0.9638 - mean_IoU: 0.7789 - val_loss: 0.1622 - val_accuracy: 0.9348 - val_mean_IoU: 0.7813
Epoch 21/50
Epoch 00021: val_loss did not improve from 0.16215
725/725 [==============================] - 10s 14ms/sample - loss: 0.0886 - accuracy: 0.9666 - mean_IoU: 0.7837 - val_loss: 0.1775 - val_accuracy: 0.9279 - val_mean_IoU: 0.7858
Epoch 22/50
Epoch 00022: val_loss did not improve from 0.16215
725/725 [==============================] - 10s 14ms/sample - loss: 0.0845 - accuracy: 0.9687 - mean_IoU: 0.7880 - val_loss: 0.1651 - val_accuracy: 0.9323 - val_mean_IoU: 0.7901
Epoch 23/50
Epoch 00023: val_loss did not improve from 0.16215
725/725 [==============================] - 10s 14ms/sample - loss: 0.0842 - accuracy: 0.9689 - mean_IoU: 0.7921 - val_loss: 0.1622 - val_accuracy: 0.9328 - val_mean_IoU: 0.7940
Epoch 24/50
Epoch 00024: val_loss did not improve from 0.16215
725/725 [==============================] - 10s 14ms/sample - loss: 0.0843 - accuracy: 0.9687 - mean_IoU: 0.7959 - val_loss: 0.1669 - val_accuracy: 0.9315 - val_mean_IoU: 0.7978
Epoch 25/50
Epoch 00025: ReduceLROnPlateau reducing learning rate to 1.0000000474974514e-05.
Epoch 00025: val_loss did not improve from 0.16215
725/725 [==============================] - 10s 14ms/sample - loss: 0.0829 - accuracy: 0.9690 - mean_IoU: 0.7995 - val_loss: 0.1628 - val_accuracy: 0.9334 - val_mean_IoU: 0.8011
Epoch 26/50
Epoch 00026: val_loss did not improve from 0.16215
725/725 [==============================] - 10s 14ms/sample - loss: 0.0791 - accuracy: 0.9708 - mean_IoU: 0.8028 - val_loss: 0.1805 - val_accuracy: 0.9260 - val_mean_IoU: 0.8043
Epoch 27/50
Epoch 00027: val_loss did not improve from 0.16215
725/725 [==============================] - 10s 14ms/sample - loss: 0.0807 - accuracy: 0.9704 - mean_IoU: 0.8059 - val_loss: 0.1920 - val_accuracy: 0.9214 - val_mean_IoU: 0.8073
Epoch 28/50
Epoch 00028: val_loss did not improve from 0.16215
725/725 [==============================] - 10s 14ms/sample - loss: 0.0802 - accuracy: 0.9704 - mean_IoU: 0.8086 - val_loss: 0.2008 - val_accuracy: 0.9183 - val_mean_IoU: 0.8098
Epoch 29/50
Epoch 00029: val_loss did not improve from 0.16215
725/725 [==============================] - 10s 14ms/sample - loss: 0.0785 - accuracy: 0.9713 - mean_IoU: 0.8112 - val_loss: 0.2049 - val_accuracy: 0.9170 - val_mean_IoU: 0.8124
Epoch 30/50
Epoch 00030: ReduceLROnPlateau reducing learning rate to 1e-05.
Epoch 00030: val_loss did not improve from 0.16215
725/725 [==============================] - 10s 14ms/sample - loss: 0.0787 - accuracy: 0.9711 - mean_IoU: 0.8135 - val_loss: 0.2074 - val_accuracy: 0.9163 - val_mean_IoU: 0.8147
Epoch 00030: early stopping
plt.figure(figsize=(8, 8))
plt.title("Learning curve")
plt.plot(results.history["loss"], label="loss")
plt.plot(results.history["val_loss"], label="val_loss")
plt.plot(np.argmin(results.history["val_loss"]), np.min(results.history["val_loss"]),
         marker="x", color="r", label="best model")
plt.xlabel("Epochs")
plt.ylabel("log_loss")
plt.legend();
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import metrics
from tensorflow.keras.models import load_model

# Reload the best checkpoint saved during training and recompile it with the custom metric
model = load_model('models/model-gs-unet.h5', compile=False)
model.compile(optimizer=Adam(), loss="binary_crossentropy", metrics=["accuracy", mean_IoU])
# Evaluate on the train set
model.evaluate(X_train, y_train, verbose=1)
725/725 [==============================] - 4s 5ms/sample - loss: 0.1232 - accuracy: 0.9499 - mean_IoU: 0.8752
[0.12322666003786284, 0.9499445, 0.875184]
# Evaluate on the test set: the loss should match the best val_loss saved by the checkpoint (0.16215)
model.evaluate(X_test, y_test, verbose=1)
242/242 [==============================] - 1s 5ms/sample - loss: 0.1622 - accuracy: 0.9348 - mean_IoU: 0.8639
[0.16215478450306192, 0.9348493, 0.863872]
# Predict on train and test
preds_train = model.predict(X_train, verbose=1)
preds_test = model.predict(X_test, verbose=1)
725/725 [==============================] - 2s 2ms/sample
242/242 [==============================] - 1s 2ms/sample
from utils.plot import plot_samples
plot_samples(X_train, y_train, preds_train, num_samples=4, prob=0.5, seed=1)
plot_samples(X_test, y_test, preds_test, num_samples=4, prob=0.5, seed=1)
import glob
from path import Path
from tqdm import tqdm_notebook
import tensorflow as tf
from tensorflow.keras.preprocessing.image import array_to_img, img_to_array, load_img

# Load a handful of demo images that are not part of the dataset
DEMO_PATH = Path('./demo')
image_files = glob.glob(DEMO_PATH / '*.*')
X_misc = np.zeros((len(image_files), img_height, img_width, img_channel), dtype=np.float32)
for i, image_file in enumerate(tqdm_notebook(image_files, total=len(image_files), desc="Loading images")):
    img = load_img(image_file, grayscale=(img_channel == 1))
    x_img = img_to_array(img)
    x_img = tf.image.resize(x_img, (img_height, img_width))
    X_misc[i] = x_img / 255.0   # same [0, 1] scaling as the training data
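Since these images bypass the dataset loader, it is worth verifying that they match the input contract the network was trained on (a quick sketch; the [0, 1] range follows from the /255 scaling above):

# Demo tensors must match the training input shape and value range
assert X_misc.shape[1:] == (img_height, img_width, img_channel)
assert 0.0 <= X_misc.min() <= X_misc.max() <= 1.0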
preds_misc = model.predict(X_misc, verbose=1)
9/9 [==============================] - 0s 39ms/sample
# Plot the demo predictions in two halves (plot_samples expects the `indexs` keyword)
indices = list(range(len(X_misc)))
plot_samples(X_misc, None, preds_misc, indexs=indices[:len(indices)//2], prob=0.5)
plot_samples(X_misc, None, preds_misc, indexs=indices[len(indices)//2:], prob=0.5)
from tensorflow.keras.preprocessing.image import save_img
from tensorflow.image import resize
import os

OUTPUT_PATH = DEMO_PATH / 'predict'
if not os.path.exists(OUTPUT_PATH):
    os.mkdir(OUTPUT_PATH)

# Threshold the sigmoid outputs into {0, 1} masks
binary_preds = (preds_misc > 0.5).astype(np.uint8)
for index, (image, label) in enumerate(tqdm_notebook(zip(X_misc, binary_preds), total=len(X_misc), desc="Saving images and labels")):
    image_fname = OUTPUT_PATH / ('rgb_%04d.png' % index)
    label_fname = OUTPUT_PATH / ('label_%04d.png' % index)
    # Upscale back to 480x640; scale=True stretches values to the 0-255
    # range so the binary masks are viewable as black-and-white images
    image = resize(image, (480, 640))
    save_img(image_fname, image, scale=True)
    label = resize(label, (480, 640))
    save_img(label_fname, label, scale=True)