import warnings
warnings.simplefilter(action='ignore')
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
print('x_train:', x_train.shape)
print('y_train:', y_train.shape)
print('x_test:', x_test.shape)
print('y_test:', y_test.shape)
x_train: (50000, 32, 32, 3) y_train: (50000, 1) x_test: (10000, 32, 32, 3) y_test: (10000, 1)
x_train_normalize = x_train.astype('float32') / 255.0
x_test_normalize = x_test.astype('float32') / 255.0
y_train_one_hot = tf.keras.utils.to_categorical(y_train)
y_test_one_hot = tf.keras.utils.to_categorical(y_test)
y_test_one_hot.shape
(10000, 10)
model = tf.keras.models.Sequential([
# 卷积层1与池化层1
tf.keras.layers.Conv2D(input_shape=(32, 32, 3), filters=32, kernel_size=(3, 3), padding='same', activation='relu'),
tf.keras.layers.Dropout(0.3),
tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), padding='same', activation='relu'),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
# 卷积层2与池化层2
tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu'),
tf.keras.layers.Dropout(0.3),
tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu'),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
# 卷积层3与池化层3
tf.keras.layers.Conv2D(filters=128, kernel_size=(3, 3), padding='same', activation='relu'),
tf.keras.layers.Dropout(0.3),
tf.keras.layers.Conv2D(filters=128, kernel_size=(3, 3), padding='same', activation='relu'),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
# 平坦层
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.3),
# 隐藏层1(2500个神经元)
tf.keras.layers.Dense(2500, activation='relu'),
tf.keras.layers.Dropout(0.3),
# 隐藏层2(1500个神经元)
tf.keras.layers.Dense(1500, activation='relu'),
tf.keras.layers.Dropout(0.3),
# 输出层
tf.keras.layers.Dense(10, activation='softmax')
])
print(model.summary())
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d (Conv2D) (None, 32, 32, 32) 896 _________________________________________________________________ dropout (Dropout) (None, 32, 32, 32) 0 _________________________________________________________________ conv2d_1 (Conv2D) (None, 32, 32, 32) 9248 _________________________________________________________________ max_pooling2d (MaxPooling2D) (None, 16, 16, 32) 0 _________________________________________________________________ conv2d_2 (Conv2D) (None, 16, 16, 64) 18496 _________________________________________________________________ dropout_1 (Dropout) (None, 16, 16, 64) 0 _________________________________________________________________ conv2d_3 (Conv2D) (None, 16, 16, 64) 36928 _________________________________________________________________ max_pooling2d_1 (MaxPooling2 (None, 8, 8, 64) 0 _________________________________________________________________ conv2d_4 (Conv2D) (None, 8, 8, 128) 73856 _________________________________________________________________ dropout_2 (Dropout) (None, 8, 8, 128) 0 _________________________________________________________________ conv2d_5 (Conv2D) (None, 8, 8, 128) 147584 _________________________________________________________________ max_pooling2d_2 (MaxPooling2 (None, 4, 4, 128) 0 _________________________________________________________________ flatten (Flatten) (None, 2048) 0 _________________________________________________________________ dropout_3 (Dropout) (None, 2048) 0 _________________________________________________________________ dense (Dense) (None, 2500) 5122500 _________________________________________________________________ dropout_4 (Dropout) (None, 2500) 0 _________________________________________________________________ dense_1 (Dense) (None, 1500) 3751500 _________________________________________________________________ dropout_5 (Dropout) (None, 1500) 0 _________________________________________________________________ dense_2 (Dense) (None, 10) 15010 ================================================================= Total params: 9,176,018 Trainable params: 9,176,018 Non-trainable params: 0 _________________________________________________________________ None
try:
model.load_weights('save_model/model_cifar10_cnn_deeper.h5')
print('加载模型成功! 继续训练模型.')
except:
print('加载模型失败, 开始训练一个新模型')
加载模型失败, 开始训练一个新模型
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
train_history = model.fit(x=x_train_normalize, y=y_train_one_hot, validation_split=0.2,
epochs=10, batch_size=128, verbose=2)
Train on 40000 samples, validate on 10000 samples Epoch 1/10 - 327s - loss: 1.4048 - acc: 0.4818 - val_loss: 1.3247 - val_acc: 0.5224 Epoch 3/10 - 330s - loss: 1.2062 - acc: 0.5643 - val_loss: 1.1510 - val_acc: 0.5969 Epoch 4/10 - 331s - loss: 1.0788 - acc: 0.6155 - val_loss: 0.9656 - val_acc: 0.6586 Epoch 5/10 - 330s - loss: 0.9780 - acc: 0.6520 - val_loss: 0.9960 - val_acc: 0.6565 Epoch 6/10 - 326s - loss: 0.9038 - acc: 0.6809 - val_loss: 0.8209 - val_acc: 0.7134 Epoch 7/10 - 329s - loss: 0.8324 - acc: 0.7054 - val_loss: 0.7892 - val_acc: 0.7240 Epoch 8/10 - 328s - loss: 0.7756 - acc: 0.7239 - val_loss: 0.7598 - val_acc: 0.7374 Epoch 9/10 - 327s - loss: 0.7320 - acc: 0.7431 - val_loss: 0.7315 - val_acc: 0.7497 Epoch 10/10 - 327s - loss: 0.6898 - acc: 0.7553 - val_loss: 0.7132 - val_acc: 0.7483
def show_train_history(train_history, train, validation):
plt.plot(train_history.history[train])
plt.plot(train_history.history[validation])
plt.title('Train History')
plt.xlabel('Epoch')
plt.ylabel(train)
plt.legend(['train', 'validation'], loc='upper left')
plt.show()
show_train_history(train_history, 'acc', 'val_acc')
show_train_history(train_history, 'loss', 'val_loss')
scores = model.evaluate(x_test_normalize, y_test_one_hot)
print()
print('accuracy:', scores[1])
10000/10000 [==============================] - 21s 2ms/step accuracy: 0.7447
predictions = model.predict_classes(x_test_normalize)
predictions[:10]
array([8, 8, 8, 8, 6, 6, 1, 6, 3, 9])
label_dict = {
0: 'airplane',
1: 'automobile',
2: 'bird',
3: 'cat',
4: 'deer',
5: 'dog',
6: 'frog',
7: 'horse',
8: 'ship',
9: 'truck'
}
def plot_images_labels_prediction(images, labels, predictions, idx, num=10):
"""
images: 图像数组
labels: 真实值数组, 其实每个元素的数字代表一种图像类别的名称
predictions: 预测结果数据
idx: 开始显示的数据index
num: 要显示的数据项数, 默认为10, 不超过25
"""
fig = plt.gcf()
fig.set_size_inches(12, 14)
if num > 25:
num = 25
for i in range(0, num):
ax = plt.subplot(5, 5, i+1)
ax.imshow(images[idx], cmap='binary')
title = str(labels[idx][0]) + ',' + label_dict[labels[idx][0]]
if len(predictions) > 0:
title += '=>' + label_dict[predictions[idx]]
ax.set_title(title, fontsize=10)
ax.set_xticks([])
ax.set_yticks([])
idx += 1
plt.show()
plot_images_labels_prediction(x_test, y_test, predictions, idx=0, num=10)
predicted_probability = model.predict(x_test_normalize)
def show_predicted_probability(x, y, predictions, predicted_probability, idx):
print('label:', label_dict[y[idx][0]])
print('predict:', label_dict[predictions[idx]])
plt.figure(figsize=(2, 2))
plt.imshow(x[idx].reshape((32, 32, 3)))
plt.show()
for j in range(10):
print(label_dict[j], ':', predicted_probability[idx][j])
show_predicted_probability(x_test, y_test, predictions, predicted_probability, 0)
label: cat predict: ship
airplane : 0.005173354 automobile : 0.015833795 bird : 0.020151181 cat : 0.3291398 deer : 0.0040378636 dog : 0.1168231 frog : 0.13499218 horse : 0.0046175853 ship : 0.3520366 truck : 0.017194541
show_predicted_probability(x_test, y_test, predictions, predicted_probability, 3)
label: airplane predict: ship
airplane : 0.35164014 automobile : 0.030136198 bird : 0.00475265 cat : 0.0015198193 deer : 0.0004779164 dog : 0.000117743686 frog : 0.00021989802 horse : 0.00014916781 ship : 0.57337856 truck : 0.037607934
print('predictons shape:', predictions.shape)
predictons shape: (10000,)
print('y_test shape:', y_test.shape)
print(y_test)
y_test shape: (10000, 1) [[3] [8] [8] ... [5] [1] [7]]
y_test.reshape(-1)
array([3, 8, 8, ..., 5, 1, 7])
pd.crosstab(y_test.reshape(-1), predictions, rownames=['label'], colnames=['predict'])
predict | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
---|---|---|---|---|---|---|---|---|---|---|
label | ||||||||||
0 | 662 | 21 | 46 | 19 | 12 | 3 | 14 | 3 | 168 | 52 |
1 | 4 | 907 | 1 | 4 | 2 | 2 | 4 | 0 | 26 | 50 |
2 | 47 | 7 | 621 | 63 | 75 | 52 | 95 | 11 | 21 | 8 |
3 | 14 | 13 | 53 | 575 | 47 | 138 | 93 | 16 | 23 | 28 |
4 | 13 | 8 | 55 | 75 | 705 | 24 | 64 | 37 | 13 | 6 |
5 | 5 | 9 | 46 | 194 | 35 | 615 | 40 | 27 | 16 | 13 |
6 | 2 | 10 | 25 | 41 | 19 | 15 | 868 | 5 | 13 | 2 |
7 | 13 | 9 | 30 | 57 | 60 | 59 | 9 | 723 | 6 | 34 |
8 | 20 | 22 | 5 | 8 | 1 | 1 | 4 | 3 | 920 | 16 |
9 | 13 | 90 | 3 | 7 | 2 | 3 | 7 | 4 | 20 | 851 |
print(label_dict)
{0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', 5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
model_json = model.to_json()
with open('save_model/model_cifar10_cnn_deeper.json', 'w') as json_file:
json_file.write(model_json)
model_yaml = model.to_yaml()
with open('save_model/model_cifar10_cnn_deeper.yaml', 'w') as yaml_file:
yaml_file.write(model_yaml)
model.save_weights('save_model/model_cifar10_cnn_deeper.h5', save_format='h5')
print('模型已保存到:', 'save_model/model_cifar10_cnn_deeper.h5')
模型已保存到: save_model/model_cifar10_cnn_deeper.h5