import warnings
# Ignore all Python warnings; set before the TensorFlow import below runs.
warnings.simplefilter(action='ignore')
import os
from time import time
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR) # filter out TensorFlow's warning messages
import tensorflow.examples.tutorials.mnist.input_data as input_data
import numpy as np
import matplotlib.pyplot as plt
# Load the MNIST data sets from the data/ directory with one-hot encoded labels.
mnist = input_data.read_data_sets('data/', one_hot=True)
# Output:
# Extracting data/train-images-idx3-ubyte.gz
# Extracting data/train-labels-idx1-ubyte.gz
# Extracting data/t10k-images-idx3-ubyte.gz
# Extracting data/t10k-labels-idx1-ubyte.gz
# Report the image and label array shapes of each MNIST split, with a blank
# line separating consecutive splits.
for position, split_name in enumerate(('train', 'validation', 'test')):
    if position > 0:
        print()
    split = getattr(mnist, split_name)
    print('{} images shape:'.format(split_name), split.images.shape)
    print('{} labels shape:'.format(split_name), split.labels.shape)
# Output:
# train images shape: (55000, 784)
# train labels shape: (55000, 10)
#
# validation images shape: (5000, 784)
# validation labels shape: (5000, 10)
#
# test images shape: (10000, 784)
# test labels shape: (10000, 10)
def layer(output_dim, input_dim, inputs, activation=None):
    """Build one fully connected layer and return its output tensor.

    Creates a weight variable of shape [input_dim, output_dim] and a bias
    variable of shape [1, output_dim], both initialised from
    tf.random_normal, and computes matmul(inputs, weights) + bias.

    output_dim: number of units produced by the layer
    input_dim: size of the last dimension of `inputs`
    inputs: tensor fed into the layer
    activation: optional callable applied to the affine result; when None
        the affine result is returned unchanged
    """
    weights = tf.Variable(tf.random_normal([input_dim, output_dim]))
    bias = tf.Variable(tf.random_normal([1, output_dim]))
    affine = tf.matmul(inputs, weights) + bias
    return affine if activation is None else activation(affine)
# Placeholders: 784-value image rows in, 10-value label rows to compare against.
x = tf.placeholder(tf.float32, shape=[None, 784])
y_label = tf.placeholder(tf.float32, shape=[None, 10])

# Network: one hidden ReLU layer (widened to 1000 units) followed by a
# linear 10-unit output layer whose values are used as logits.
h1 = layer(output_dim=1000, input_dim=784, inputs=x, activation=tf.nn.relu)
y_predict = layer(output_dim=10, input_dim=1000, inputs=h1, activation=None)

# Loss: mean softmax cross-entropy, minimised with Adam.
loss_function = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_label, logits=y_predict)
)
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss_function)

# Accuracy: fraction of rows whose arg-max prediction equals the arg-max label.
correct_prediction = tf.equal(
    tf.argmax(y_predict, axis=1),
    tf.argmax(y_label, axis=1)
)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# Training configuration.
train_epochs = 15
batch_size = 100
total_batch = mnist.train.num_examples // batch_size

# Per-epoch history, plotted after training.
epoch_list, loss_list, acc_list = [], [], []

start_time = time()
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# The validation feed is the same every epoch, so build it once.
validation_feed = {x: mnist.validation.images, y_label: mnist.validation.labels}

for epoch in range(train_epochs):
    epoch_number = epoch + 1

    # One optimisation step per mini-batch.
    for _ in range(total_batch):
        batch_images, batch_labels = mnist.train.next_batch(batch_size=batch_size)
        sess.run(optimizer, feed_dict={x: batch_images, y_label: batch_labels})

    # Evaluate on the validation split once the epoch's batches are done.
    val_loss, val_acc = sess.run([loss_function, accuracy], feed_dict=validation_feed)
    epoch_list.append(epoch_number)
    loss_list.append(val_loss)
    acc_list.append(val_acc)

    # val_acc is passed as its own print argument so it is rendered with
    # str(), exactly as before.
    print('train epoch:', '{:02d}'.format(epoch_number),
          'loss:', '{:.9f}'.format(val_loss),
          'acc:', val_acc)

print()
print('train finished. takes', time() - start_time, 'seconds')
# Output:
# train epoch: 01 loss: 9.910659790 acc: 0.8862
# train epoch: 02 loss: 6.398650646 acc: 0.9122
# train epoch: 03 loss: 5.056443214 acc: 0.9276
# train epoch: 04 loss: 4.601753712 acc: 0.9248
# train epoch: 05 loss: 3.620224953 acc: 0.9388
# train epoch: 06 loss: 3.622395992 acc: 0.9406
# train epoch: 07 loss: 2.953754187 acc: 0.9474
# train epoch: 08 loss: 2.867266178 acc: 0.948
# train epoch: 09 loss: 2.808548927 acc: 0.9466
# train epoch: 10 loss: 2.648730993 acc: 0.9514
# train epoch: 11 loss: 2.724977016 acc: 0.9482
# train epoch: 12 loss: 2.598264694 acc: 0.9518
# train epoch: 13 loss: 2.520519495 acc: 0.9532
# train epoch: 14 loss: 2.709043980 acc: 0.9508
# train epoch: 15 loss: 2.382934809 acc: 0.9562
#
# train finished. takes 74.44282412528992 seconds
def show_train_history(x_values, y_values, title):
    """Plot one metric curve against the epoch axis and display the figure.

    x_values: x coordinates of the curve (the axis is labelled 'Epoch')
    y_values: metric values to plot
    title: metric name, used as the y-axis label and as the legend entry
    """
    plt.plot(x_values, y_values, label=title)
    plt.legend([title], loc='upper left')
    plt.ylabel(title)
    plt.xlabel('Epoch')
    plt.show()
# Plot the recorded accuracy and loss histories, one figure each.
for history, metric_name in ((acc_list, 'acc'), (loss_list, 'loss')):
    show_train_history(epoch_list, history, metric_name)

# Evaluate the trained model on the test split.
test_feed = {x: mnist.test.images, y_label: mnist.test.labels}
print('accuracy:', sess.run(accuracy, feed_dict=test_feed))
# Output:
# accuracy: 0.9542
# Predicted digit for every test image: index of the largest output per row.
predicted_digit = tf.argmax(y_predict, axis=1)
prediction_result = sess.run(predicted_digit, feed_dict={x: mnist.test.images})
# Bare expression: has no effect in a plain script; in a notebook cell it
# displays the first 10 predictions.
prediction_result[:10]
# Output:
# array([7, 2, 1, 0, 4, 1, 4, 9, 4, 9])
def plot_images_labels_prediction(images, labels, predictions, idx, num=10):
    """Show digit images in a 5x5 grid, each titled with its label and prediction.

    images: array of digit images; each entry is reshaped to 28x28 for display
    labels: array of label vectors; the arg-max of each entry is shown as the label
    predictions: predicted digits indexed like `images`; may be empty or None,
        in which case only the label is shown
    idx: index of the first item to display
    num: number of items to display, default 10; capped at 25 (the grid
        capacity) and at the number of items available from `idx` onward
    """
    fig = plt.gcf()
    fig.set_size_inches(12, 14)

    # Cap at the grid capacity and at what is left in `images`, so a request
    # that runs past the end draws the remaining items instead of raising
    # IndexError partway through the figure.
    num = min(num, 25, len(images) - idx)
    has_predictions = predictions is not None and len(predictions) > 0

    for position in range(num):
        item = idx + position
        ax = plt.subplot(5, 5, position + 1)
        ax.imshow(images[item].reshape(28, 28), cmap='binary')

        title = 'label=' + str(np.argmax(labels[item]))
        if has_predictions:
            title += ',predict=' + str(predictions[item])
        ax.set_title(title, fontsize=10)

        # Tick marks carry no information for an image thumbnail.
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()
# Show the first 10 test images with their labels and predictions.
plot_images_labels_prediction(mnist.test.images, mnist.test.labels, prediction_result, 0, 10)

# List every misclassified sample among the first 300 test images.
true_digits = np.argmax(mnist.test.labels[:300], axis=1)
for sample_index in range(300):
    actual = true_digits[sample_index]
    guessed = prediction_result[sample_index]
    if guessed == actual:
        continue
    print('i={}'.format(sample_index),
          'label={}'.format(actual),
          'predict={}'.format(guessed))
# Output:
# i=8 label=5 predict=4
# i=38 label=2 predict=3
# i=125 label=9 predict=4
# i=149 label=2 predict=9
# i=151 label=9 predict=8
# i=241 label=9 predict=8
# i=245 label=3 predict=5
# i=247 label=4 predict=6
# i=257 label=8 predict=1
# i=259 label=6 predict=0
# Close the TensorFlow session opened for training and evaluation.
sess.close()