import os
import numpy as np
import matplotlib.pyplot as plt
from numpy import pad
import tensorflow as tf
import pandas as pd
from model.Model import Model
from preprocess.preprocess import Dataset, PreprocessFrame, make_augments
import seaborn as sns; sns.set()
import warnings
warnings.filterwarnings("ignore")
# original image size (pixels) as found in the HKR dataset
img_width = 900
img_height = 120
# target size images are resized to before being fed to the model
new_img_width = 350
new_img_height = 50
batch_size = 16
# default paths to the HKR dataset and its metadata
WORKING_DIR = os.path.join('/home', 'HTR')
ann_path = os.path.join(WORKING_DIR, 'HKR_Dataset_Words_Public', 'ann')
img_path = os.path.join(WORKING_DIR, 'HKR_Dataset_Words_Public', 'img')
metadata = os.path.join(WORKING_DIR, 'metadata', 'metadata.tsv')
# collect metadata (one-time step; uncomment to regenerate metadata.tsv)
# meta_collect(ann_path, metadata)
# get preprocessed metadata dataframe
df = PreprocessFrame(metadata=metadata,
img_height=img_height, img_width=img_width)
# Make augments file (if they exists: comment or delete line)
aug_df = None
# aug_df = make_augments(df=df, img_path=img_path, WORKING_DIR=WORKING_DIR,
# new_img_height=new_img_height, new_img_width=new_img_width)
# get augments metadata dataframe from original dataframe if not starting make_augments
# (augmented images are assumed to live under 'aug_1/aug_<name>' -- TODO confirm
# against make_augments)
if not isinstance(aug_df, pd.DataFrame):
aug_df = df.copy()
aug_df.index = aug_df.index.to_series().apply(lambda x: os.path.join('aug_1', 'aug_' + x))
# split into train / test / val tf-style datasets (10% test, 5% val of the
# remainder -- exact semantics defined by preprocess.Dataset)
train, test, val = list(Dataset(df, aug_df=aug_df,
test_size=0.1,
val_size=0.05,
img_path=img_path,
img_height=img_height,
img_width=img_width,
new_img_height=new_img_height,
new_img_width=new_img_width,
WORKING_DIR=WORKING_DIR,
shuffle=True,
random_state=12))
# number of batches in each split
print(len(train), len(test), len(val))
7377 399 200
# training / model hyper-parameters consumed by model.Model
params = {
'callbacks': ['checkpoint', 'csv_log', 'tb_log', 'early_stopping'],
'metrics': ['cer', 'accuracy'],
'checkpoint_path': os.path.join(WORKING_DIR, 'checkpoints/training_2/cp.ckpt'),
'csv_log_path': os.path.join(WORKING_DIR, 'logs/csv_logs/log_2.csv'),
'tb_log_path': os.path.join(WORKING_DIR, 'logs/tb_logs/log2'),
'tb_update_freq': 200,
'epochs': 50,
'batch_size': batch_size,
'early_stopping_patience': 10,
'input_img_shape': (new_img_width, new_img_height, 1),  # width-first: images are fed transposed
'vocab_len': 75,  # alphabet size including the CTC blank
'max_label_len': 22,
'chars_path': os.path.join(os.path.split(metadata)[0], 'symbols.txt'),
'blank': '#',  # character used to render the CTC blank
'blank_index': 74,  # last index of the 75-symbol vocabulary
'corpus': os.path.join(os.path.split(metadata)[0], 'corpus.txt')
}
model = Model(params)
model.build()
model.get_summary()
# resume from a previous run instead of training from scratch:
#model.load_weights('checkpoints/training_2/cp.ckpt')
Model: "htr_model" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== image (InputLayer) [(None, 350, 50, 1)] 0 __________________________________________________________________________________________________ Conv1 (Conv2D) (None, 350, 50, 32) 832 image[0][0] __________________________________________________________________________________________________ pool1 (MaxPooling2D) (None, 175, 25, 32) 0 Conv1[0][0] __________________________________________________________________________________________________ Conv2 (Conv2D) (None, 175, 25, 64) 18496 pool1[0][0] __________________________________________________________________________________________________ pool2 (MaxPooling2D) (None, 87, 12, 64) 0 Conv2[0][0] __________________________________________________________________________________________________ Conv3 (Conv2D) (None, 87, 12, 128) 73856 pool2[0][0] __________________________________________________________________________________________________ pool3 (MaxPooling2D) (None, 43, 6, 128) 0 Conv3[0][0] __________________________________________________________________________________________________ Conv4 (Conv2D) (None, 43, 6, 256) 131328 pool3[0][0] __________________________________________________________________________________________________ reshape (Reshape) (None, 43, 1536) 0 Conv4[0][0] __________________________________________________________________________________________________ dense1 (Dense) (None, 43, 64) 98368 reshape[0][0] __________________________________________________________________________________________________ dropout (Dropout) (None, 43, 64) 0 dense1[0][0] __________________________________________________________________________________________________ bidirectional (Bidirectional) (None, 43, 256) 197632 dropout[0][0] 
__________________________________________________________________________________________________ bidirectional_1 (Bidirectional) (None, 43, 128) 164352 bidirectional[0][0] __________________________________________________________________________________________________ label (InputLayer) [(None, None)] 0 __________________________________________________________________________________________________ dense2 (Dense) (None, 43, 75) 9675 bidirectional_1[0][0] __________________________________________________________________________________________________ ctc_loss (CTCLayer) (None, 43, 75) 0 label[0][0] dense2[0][0] ================================================================================================== Total params: 694,539 Trainable params: 694,539 Non-trainable params: 0 __________________________________________________________________________________________________
# train with validation each epoch; checkpointing / early stopping are driven
# by the 'callbacks' entries configured in params above
model.fit(train, val)
Epoch 1/50 3540/3540 [==============================] - 543s 153ms/step - loss: 24.3410 - cer: 0.6626 - accuracy: 0.0186 - val_loss: 9.6557 - val_cer: 0.2501 - val_accuracy: 0.1375 Epoch 00001: val_loss improved from inf to 9.65573, saving model to /workspace/pybooks/HTR/checkpoints/training_2/cp.ckpt Epoch 2/50 3540/3540 [==============================] - 539s 152ms/step - loss: 8.3436 - cer: 0.2246 - accuracy: 0.2000 - val_loss: 5.3913 - val_cer: 0.1448 - val_accuracy: 0.3519 Epoch 00002: val_loss improved from 9.65573 to 5.39129, saving model to /workspace/pybooks/HTR/checkpoints/training_2/cp.ckpt Epoch 3/50 3540/3540 [==============================] - 535s 151ms/step - loss: 5.7637 - cer: 0.1564 - accuracy: 0.3444 - val_loss: 4.2013 - val_cer: 0.1124 - val_accuracy: 0.4570 Epoch 00003: val_loss improved from 5.39129 to 4.20127, saving model to /workspace/pybooks/HTR/checkpoints/training_2/cp.ckpt Epoch 4/50 3540/3540 [==============================] - 532s 150ms/step - loss: 4.6814 - cer: 0.1273 - accuracy: 0.4304 - val_loss: 3.6643 - val_cer: 0.0980 - val_accuracy: 0.5049 Epoch 00004: val_loss improved from 4.20127 to 3.66432, saving model to /workspace/pybooks/HTR/checkpoints/training_2/cp.ckpt Epoch 5/50 3540/3540 [==============================] - 530s 150ms/step - loss: 4.0092 - cer: 0.1098 - accuracy: 0.4853 - val_loss: 3.2595 - val_cer: 0.0863 - val_accuracy: 0.5594 Epoch 00005: val_loss improved from 3.66432 to 3.25945, saving model to /workspace/pybooks/HTR/checkpoints/training_2/cp.ckpt Epoch 6/50 3540/3540 [==============================] - 530s 150ms/step - loss: 3.5419 - cer: 0.0973 - accuracy: 0.5286 - val_loss: 3.0124 - val_cer: 0.0808 - val_accuracy: 0.5824 Epoch 00006: val_loss improved from 3.25945 to 3.01236, saving model to /workspace/pybooks/HTR/checkpoints/training_2/cp.ckpt Epoch 7/50 3540/3540 [==============================] - 529s 149ms/step - loss: 3.2181 - cer: 0.0885 - accuracy: 0.5645 - val_loss: 2.9389 - val_cer: 0.0789 - 
val_accuracy: 0.5981 Epoch 00007: val_loss improved from 3.01236 to 2.93888, saving model to /workspace/pybooks/HTR/checkpoints/training_2/cp.ckpt Epoch 8/50 3540/3540 [==============================] - 530s 150ms/step - loss: 2.9199 - cer: 0.0803 - accuracy: 0.5926 - val_loss: 2.8247 - val_cer: 0.0759 - val_accuracy: 0.6196 Epoch 00008: val_loss improved from 2.93888 to 2.82470, saving model to /workspace/pybooks/HTR/checkpoints/training_2/cp.ckpt Epoch 9/50 3540/3540 [==============================] - 532s 150ms/step - loss: 2.7135 - cer: 0.0750 - accuracy: 0.6141 - val_loss: 2.5880 - val_cer: 0.0687 - val_accuracy: 0.6438 Epoch 00009: val_loss improved from 2.82470 to 2.58799, saving model to /workspace/pybooks/HTR/checkpoints/training_2/cp.ckpt Epoch 10/50 3540/3540 [==============================] - 531s 150ms/step - loss: 2.5273 - cer: 0.0699 - accuracy: 0.6335 - val_loss: 2.5905 - val_cer: 0.0676 - val_accuracy: 0.6549 Epoch 00010: val_loss did not improve from 2.58799 Epoch 11/50 3540/3540 [==============================] - 531s 150ms/step - loss: 2.3940 - cer: 0.0664 - accuracy: 0.6490 - val_loss: 2.5240 - val_cer: 0.0667 - val_accuracy: 0.6573 Epoch 00011: val_loss improved from 2.58799 to 2.52397, saving model to /workspace/pybooks/HTR/checkpoints/training_2/cp.ckpt Epoch 12/50 3540/3540 [==============================] - 530s 150ms/step - loss: 2.2442 - cer: 0.0626 - accuracy: 0.6641 - val_loss: 2.5058 - val_cer: 0.0630 - val_accuracy: 0.6720 Epoch 00012: val_loss improved from 2.52397 to 2.50584, saving model to /workspace/pybooks/HTR/checkpoints/training_2/cp.ckpt Epoch 13/50 3540/3540 [==============================] - 529s 149ms/step - loss: 2.1351 - cer: 0.0593 - accuracy: 0.6768 - val_loss: 2.3859 - val_cer: 0.0618 - val_accuracy: 0.6760 Epoch 00013: val_loss improved from 2.50584 to 2.38591, saving model to /workspace/pybooks/HTR/checkpoints/training_2/cp.ckpt Epoch 14/50 3540/3540 [==============================] - 531s 150ms/step - loss: 2.0563 
- cer: 0.0574 - accuracy: 0.6849 - val_loss: 2.4305 - val_cer: 0.0617 - val_accuracy: 0.6780 Epoch 00014: val_loss did not improve from 2.38591 Epoch 15/50 3540/3540 [==============================] - 532s 150ms/step - loss: 1.9658 - cer: 0.0548 - accuracy: 0.6956 - val_loss: 2.3480 - val_cer: 0.0603 - val_accuracy: 0.6840 Epoch 00015: val_loss improved from 2.38591 to 2.34805, saving model to /workspace/pybooks/HTR/checkpoints/training_2/cp.ckpt Epoch 16/50 3540/3540 [==============================] - 531s 150ms/step - loss: 1.8858 - cer: 0.0532 - accuracy: 0.7044 - val_loss: 2.3585 - val_cer: 0.0573 - val_accuracy: 0.7013 Epoch 00016: val_loss did not improve from 2.34805 Epoch 17/50 3540/3540 [==============================] - 532s 150ms/step - loss: 1.8475 - cer: 0.0519 - accuracy: 0.7100 - val_loss: 2.2807 - val_cer: 0.0571 - val_accuracy: 0.7019 Epoch 00017: val_loss improved from 2.34805 to 2.28071, saving model to /workspace/pybooks/HTR/checkpoints/training_2/cp.ckpt Epoch 18/50 3540/3540 [==============================] - 532s 150ms/step - loss: 1.7724 - cer: 0.0501 - accuracy: 0.7201 - val_loss: 2.2580 - val_cer: 0.0563 - val_accuracy: 0.7017 Epoch 00018: val_loss improved from 2.28071 to 2.25803, saving model to /workspace/pybooks/HTR/checkpoints/training_2/cp.ckpt Epoch 19/50 3540/3540 [==============================] - 532s 150ms/step - loss: 1.7113 - cer: 0.0481 - accuracy: 0.7267 - val_loss: 2.2584 - val_cer: 0.0545 - val_accuracy: 0.7167 Epoch 00019: val_loss did not improve from 2.25803 Epoch 20/50 3540/3540 [==============================] - 530s 150ms/step - loss: 1.6861 - cer: 0.0476 - accuracy: 0.7302 - val_loss: 2.2924 - val_cer: 0.0580 - val_accuracy: 0.6971 Epoch 00020: val_loss did not improve from 2.25803 Epoch 21/50 3540/3540 [==============================] - 531s 150ms/step - loss: 1.6132 - cer: 0.0455 - accuracy: 0.7381 - val_loss: 2.2598 - val_cer: 0.0555 - val_accuracy: 0.7077 Epoch 00021: val_loss did not improve from 2.25803 Epoch 
22/50 3540/3540 [==============================] - 531s 150ms/step - loss: 1.6214 - cer: 0.0455 - accuracy: 0.7403 - val_loss: 2.1676 - val_cer: 0.0538 - val_accuracy: 0.7201 Epoch 00022: val_loss improved from 2.25803 to 2.16758, saving model to /workspace/pybooks/HTR/checkpoints/training_2/cp.ckpt Epoch 23/50 3540/3540 [==============================] - 531s 150ms/step - loss: 1.5687 - cer: 0.0443 - accuracy: 0.7447 - val_loss: 2.0911 - val_cer: 0.0511 - val_accuracy: 0.7281 Epoch 00023: val_loss improved from 2.16758 to 2.09106, saving model to /workspace/pybooks/HTR/checkpoints/training_2/cp.ckpt Epoch 24/50 3540/3540 [==============================] - 531s 150ms/step - loss: 1.5066 - cer: 0.0425 - accuracy: 0.7524 - val_loss: 2.2059 - val_cer: 0.0537 - val_accuracy: 0.7257 Epoch 00024: val_loss did not improve from 2.09106 Epoch 25/50 3540/3540 [==============================] - 531s 150ms/step - loss: 1.5082 - cer: 0.0427 - accuracy: 0.7541 - val_loss: 2.2165 - val_cer: 0.0543 - val_accuracy: 0.7285 Epoch 00025: val_loss did not improve from 2.09106 Epoch 26/50 3540/3540 [==============================] - 530s 150ms/step - loss: 1.4745 - cer: 0.0419 - accuracy: 0.7570 - val_loss: 2.2537 - val_cer: 0.0552 - val_accuracy: 0.7226 Epoch 00026: val_loss did not improve from 2.09106 Epoch 27/50 3540/3540 [==============================] - 530s 150ms/step - loss: 1.4514 - cer: 0.0411 - accuracy: 0.7586 - val_loss: 2.1681 - val_cer: 0.0511 - val_accuracy: 0.7395 Epoch 00027: val_loss did not improve from 2.09106 Epoch 28/50 3540/3540 [==============================] - 532s 150ms/step - loss: 1.4182 - cer: 0.0402 - accuracy: 0.7652 - val_loss: 2.2615 - val_cer: 0.0537 - val_accuracy: 0.7210 Epoch 00028: val_loss did not improve from 2.09106 Epoch 29/50 3540/3540 [==============================] - 531s 150ms/step - loss: 1.4128 - cer: 0.0401 - accuracy: 0.7639 - val_loss: 2.2255 - val_cer: 0.0518 - val_accuracy: 0.7295 Epoch 00029: val_loss did not improve from 2.09106 
Epoch 30/50 3540/3540 [==============================] - 531s 150ms/step - loss: 1.4127 - cer: 0.0404 - accuracy: 0.7640 - val_loss: 2.1086 - val_cer: 0.0520 - val_accuracy: 0.7324 Epoch 00030: val_loss did not improve from 2.09106 Epoch 31/50 3540/3540 [==============================] - 530s 150ms/step - loss: 1.3425 - cer: 0.0383 - accuracy: 0.7719 - val_loss: 2.1460 - val_cer: 0.0507 - val_accuracy: 0.7326 Epoch 00031: val_loss did not improve from 2.09106 Epoch 32/50 3540/3540 [==============================] - 531s 150ms/step - loss: 1.3325 - cer: 0.0378 - accuracy: 0.7751 - val_loss: 2.2742 - val_cer: 0.0554 - val_accuracy: 0.7208 Epoch 00032: val_loss did not improve from 2.09106 Epoch 33/50 3540/3540 [==============================] - 532s 150ms/step - loss: 1.3074 - cer: 0.0372 - accuracy: 0.7772 - val_loss: 2.3394 - val_cer: 0.0535 - val_accuracy: 0.7324 Epoch 00033: val_loss did not improve from 2.09106 model weights saved at /workspace/pybooks/HTR/checkpoints/training_2/cp.ckpt
def show_history(history, indicators, nrows=2, ncols=2, width=15, height=10):
    """Plot train/val curves for the given metrics on an nrows x ncols grid.

    history    -- dict-like mapping metric name -> per-epoch values, with the
                  matching validation series stored under 'val_<name>'
    indicators -- metric names to plot; only the first nrows*ncols are drawn
    nrows, ncols -- grid shape of the subplot figure
    width, height -- figure size in inches
    """
    # squeeze=False guarantees a 2-D axes array even when nrows or ncols == 1,
    # so ax[row, col] indexing below never breaks
    _, ax = plt.subplots(nrows, ncols, figsize=(width, height), squeeze=False)
    for i in range(min(len(indicators), nrows * ncols)):
        indicator_name = indicators[i]
        axis = ax[i // ncols, i % ncols]
        axis.plot(history[indicator_name])
        axis.plot(history['val_' + indicator_name])
        axis.set_title('Model ' + indicator_name)
        axis.set_ylabel(indicator_name)
        axis.set_xlabel('epoch')
        axis.legend(['Train', 'Val'], loc='lower left')
    plt.show()
# (optional) restore a saved history instead of reading it from the model:
# import json
# with open('history_2.json', 'r') as f:
# h = json.load(f)
# per-epoch metric history recorded during model.fit
history = model.get_history()
# (optional) persist the history for later inspection:
# import json
# with open('history_2.json', 'w') as f:
# json.dump(history, f)
show_history(history, ['loss', 'cer', 'accuracy'])
# final evaluation on the held-out test split: [loss, cer, accuracy]
test_metrics = model.evaluate(test)
399/399 [==============================] - 12s 30ms/step - loss: 2.1444 - cer: 0.0525 - accuracy: 0.7179
# report test metrics (message text is Russian for "Metrics on test data")
print(f'Метрики на тестовых данных:\n\
ctc loss: {round(test_metrics[0], 2)}\n\
CER: {round(test_metrics[1] * 100, 2)}%\n\
accuracy: {round(test_metrics[2] * 100, 2)}%\n\
')
Метрики на тестовых данных: ctc loss: 2.14 CER: 5.25% accuracy: 71.79%
def show_preds_on_batch(model, batch, batch_size=batch_size):
    """Display every image of a batch with its predicted and original label.

    batch is expected to be a dict with 'image' and 'label' tensors as yielded
    by the preprocess.Dataset splits -- TODO confirm against Dataset.
    """
    batch_images = batch['image']
    batch_labels = batch['label']
    pred_texts = model.predict(batch)
    orig_texts = []
    for label in batch_labels:
        # decode label indices back to text and strip the CTC blank '#'
        text = tf.strings.reduce_join(model.num_to_char(label)).numpy().decode('utf-8').replace('#', '')
        orig_texts.append(text)
    # the final batch of a dataset may hold fewer than batch_size samples --
    # only draw rows that actually exist (the original indexed past the end)
    n_rows = min(batch_size, len(orig_texts))
    # squeeze=False keeps `ax` 2-D even when n_rows == 1
    _, ax = plt.subplots(n_rows, 1, figsize=(10, 50), squeeze=False)
    for i in range(n_rows):
        # undo the normalization (values look shifted by -0.5 -- presumably
        # images are stored in [-0.5, 0.5]; verify against preprocessing)
        img = ((batch_images[i, :, :, 0] + 0.5) * 255).numpy().astype('uint8')
        img = img.T  # images are stored transposed (width, height)
        title = f"Prediction: {pred_texts[i]}\nOriginal: {orig_texts[i]}"
        ax[i, 0].set_title(title)
        ax[i, 0].grid(False)
        ax[i, 0].imshow(img, cmap="gray")
# sanity check: evaluate a single test batch, then visualize its predictions
for batch in test.take(1):
model.evaluate(batch)
show_preds_on_batch(model, batch)
1/1 [==============================] - 0s 52ms/step - loss: 2.0155 - cer: 0.0680 - accuracy: 0.8125
Попробуем посмотреть, на каких буквах чаще всего ошибается модель. Для этого построим confusion matrix. Так как в общем случае трудно разобрать, где модель перепутала символ, а где распознала как два или вообще не распознала - выберем лишь те пары ответ/ожидание, которые одинаковы по длине.
def rm_spaces(s):
    """Return *s* with every space character removed."""
    return ''.join(s.split(' '))
# collect model predictions and ground-truth labels over the whole test set
# (spaces are removed on both sides so per-character comparison lines up)
y_pred = np.empty((0, ))
y_true = np.empty((0, ))
for batch in test:
y_pred = np.append(y_pred, model.predict(batch))
for label in batch['label']:
# decode indices to text, drop the CTC blank '#' and spaces
label = tf.strings.reduce_join(model.num_to_char(label)).numpy().decode('utf-8')\
.replace('#', '').replace(' ', '')
y_true = np.append(y_true, label)
# predictions still contain spaces; strip them element-wise
y_pred = np.vectorize(rm_spaces)(y_pred)
# pair up (prediction, ground truth) per sample
y_comp = np.stack((y_pred, y_true), axis=-1)
y_comp
array([['Лакик', 'Макат'], ['крупу', 'крупу'], ['Киргизия', 'Киргизия'], ..., ['Глушьиснег.', 'Глушьиснег.'], ['РК', 'РК'], ['длянассадится', 'длянассадится']], dtype='<U32')
# keep only wrong predictions whose length matches the ground truth --
# only then is a per-position character comparison meaningful
vec_len = np.vectorize(len)
confusions = y_comp[(y_pred != y_true) & (vec_len(y_pred) == vec_len(y_true))]
confusions
array([['Лакик', 'Макат'], ['начлего', 'ночлега'], ['Итам,гдесердцу', 'Итам,гдесердце'], ..., ['водканианисе,', 'водканаанисе,'], ['Чтооказалось', 'Чтооказалась'], ['Сзтзнениеписаиле', 'сознаниябессилья']], dtype='<U32')
# collect every (predicted, expected) character mismatch across the
# equal-length confusion pairs
mismatches = [(p_ch, t_ch)
              for pred_word, true_word in confusions
              for p_ch, t_ch in zip(pred_word, true_word)
              if p_ch != t_ch]
# alphabet = sorted set of every character involved in at least one mismatch
alphabet = sorted({ch for pair in mismatches for ch in pair})
char_to_i = dict(map(reversed, enumerate(alphabet)))
# symmetric confusion-count matrix: each mismatch bumps both (a, b) and (b, a)
conf_matrix = np.zeros((len(alphabet), len(alphabet)), dtype=int)
for p_ch, t_ch in mismatches:
    conf_matrix[char_to_i[p_ch], char_to_i[t_ch]] += 1
    conf_matrix[char_to_i[t_ch], char_to_i[p_ch]] += 1
# render the confusion matrix as a character-by-character heatmap
fig, axes = plt.subplots(figsize=(16, 16))
# lift all non-zero counts so rare confusions stay visible against the background
conf_matrix[conf_matrix != 0] += 10  # for contrast
sns.heatmap(conf_matrix, xticklabels=alphabet, yticklabels=alphabet,
            ax=axes, cbar=False, cmap='gray_r',
            linewidths=.5, linecolor='#f0f0f0')
plt.show()
Как видно из матрицы, достаточно часто модель путает большие буквы с маленькими, что в целом не настолько серьёзная проблема, учитывая то, что сама буква в конечном счёте определяется верно. А судя по выведенным массивам пар, это чаще происходит тогда, когда всё слово состоит из больших букв. Можно также заметить частую путаницу между «о» и «а», «о» и «е», «л» и «н» и другими буквами, схожими по написанию. Это ожидаемый результат, т. к. в некоторых записях их и человеку трудно отличить. В данном случае эту проблему сможет почти полностью решить постпроцессинг результата — поиск по словарю и/или оценка контекста.