MNIST – Single-Layer Neural Network (Softmax Regression, No Hidden Layer) with TensorFlow – All-in-One

In [4]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
import math
import tensorflow as tf

# Hyper-parameters for mini-batch gradient descent.
batch_size, training_epochs, learning_rate = 100, 100, 0.05

# Per-epoch history filled in by the training loop and read by the plotting
# helpers; diff_index_list collects indices of misclassified test images.
epoch_list, train_error_list, validation_error_list = [], [], []
test_accuracy_list, diff_index_list = [], []

# Load MNIST from ../MNIST_data/ with one-hot encoded labels.
mnist = input_data.read_data_sets("../MNIST_data/", one_hot=True)
    
# Start from an empty default graph so that re-running this cell does not
# accumulate duplicate placeholders/variables/ops left over from earlier runs.
tf.reset_default_graph()

# Inputs: flattened 28x28 = 784-pixel images and 10-class one-hot labels.
x = tf.placeholder(tf.float32, [None, 784])
y_target = tf.placeholder(tf.float32, [None, 10])

# Model: a single fully connected layer (softmax regression, no hidden layer).
# Zero initialisation is acceptable because there are no hidden units that
# would need symmetry breaking. u holds the un-normalised class scores
# (logits); the softmax is applied inside the loss below.
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
u = tf.matmul(x, W) + b

# Loss: mean softmax cross-entropy over the fed examples, minimised by plain
# gradient descent with the configured learning rate.
error = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=u, labels=y_target))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(error)

# Accuracy: fraction of examples whose highest-scoring class matches the label.
prediction_and_ground_truth = tf.equal(tf.argmax(u, 1), tf.argmax(y_target, 1))
accuracy = tf.reduce_mean(tf.cast(prediction_and_ground_truth, tf.float32))

def draw_error_values_and_accuracy():
    """Plot the recorded training history side by side.

    Left panel: train and validation error (mean cross-entropy) per epoch.
    Right panel: test accuracy per epoch. Reads the module-level lists
    ``epoch_list``, ``train_error_list``, ``validation_error_list`` and
    ``test_accuracy_list`` filled in by the training loop.

    The first entry (index 0) is skipped: it is recorded before any training
    step, and its error (~2.3 in the logged run) would presumably dwarf the
    rest of the curve.
    """
    fig, (error_ax, accuracy_ax) = plt.subplots(1, 2, figsize=(20, 5))

    error_ax.plot(epoch_list[1:], train_error_list[1:], 'r', label='Train')
    error_ax.plot(epoch_list[1:], validation_error_list[1:], 'g', label='Validation')
    error_ax.set_title('Train / Validation Error')
    # The loss is tf.reduce_mean(...), i.e. a per-example mean, not a total.
    error_ax.set_ylabel('Error (mean cross-entropy)')
    error_ax.set_xlabel('Epochs')
    error_ax.grid(True)
    error_ax.legend(loc='upper right')

    accuracy_ax.plot(epoch_list[1:], test_accuracy_list[1:], 'b', label='Test')
    accuracy_ax.set_title('Test Accuracy')
    accuracy_ax.set_ylabel('Accuracy')
    accuracy_ax.set_xlabel('Epochs')
    # 0.05-spaced ticks over the full [0, 1] range; np.arange(0.0, 1.0, 0.05)
    # stopped at 0.95 and never labelled the 1.0 upper bound.
    accuracy_ax.set_yticks(np.linspace(0.0, 1.0, 21))
    accuracy_ax.grid(True)
    accuracy_ax.legend(loc='lower right')

    fig.tight_layout()
    plt.show()
    
def draw_false_prediction(num_images=5):
    """Print and display the first few misclassified test images.

    Reads module-level state set up elsewhere in the notebook:
    ``diff_index_list`` (indices of misclassified test examples),
    ``prediction`` / ``ground_truth`` (per-example predicted and true class
    indices, assigned in the training-session cell before this is called)
    and ``mnist.test.images``.

    Args:
        num_images: maximum number of misclassified examples to show. The
            default of 5 matches the original behaviour. Fewer are shown when
            there are not that many misclassifications (previously this raised
            IndexError).
    """
    count = min(num_images, len(diff_index_list))
    if count <= 0:
        print("No false predictions to display.")
        return

    fig = plt.figure(figsize=(20, 5))
    for position, j in enumerate(diff_index_list[:count]):
        print("False Prediction Index: %s, Prediction: %s, Ground Truth: %s" % (j, prediction[j], ground_truth[j]))
        # Each flattened 784-pixel image is viewed as a 28x28 grid for display.
        img = np.asarray(mnist.test.images[j]).reshape(28, 28)
        # (rows, cols, index) form works for any count, unlike the 3-digit
        # subplot code (150 + i), which only supports up to 9 panels.
        ax = fig.add_subplot(1, count, position + 1)
        ax.imshow(img, cmap='gray')
        ax.set_title("Pred: %s / True: %s" % (prediction[j], ground_truth[j]))
    plt.show()
    
# Train with mini-batch gradient descent, log metrics at every epoch boundary,
# then profile the misclassified test images of the final model.
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)

    # ceil() so that a trailing partial batch is still counted when the
    # training-set size is not a multiple of batch_size.
    total_batch = int(math.ceil(mnist.train.num_examples / float(batch_size)))
    print("Total batch: %d" % total_batch)

    # Loop training_epochs + 1 times. Entry k of each history list describes the
    # model after k epochs of training: index 0 is the untrained model and the
    # last entry is the fully trained one. Previously metrics were only taken
    # before each epoch's training, so the final model was never evaluated.
    for epoch in range(training_epochs + 1):
        epoch_list.append(epoch)

        train_error_value = sess.run(error, feed_dict={x: mnist.train.images, y_target: mnist.train.labels})
        train_error_list.append(train_error_value)

        validation_error_value = sess.run(error, feed_dict={x: mnist.validation.images, y_target: mnist.validation.labels})
        validation_error_list.append(validation_error_value)

        test_accuracy_value = sess.run(accuracy, feed_dict={x: mnist.test.images, y_target: mnist.test.labels})
        test_accuracy_list.append(test_accuracy_value)
        print("Epoch: {0:2d}, Train Error: {1:0.5f}, Validation Error: {2:0.5f}, Test Accuracy: {3:0.5f}".format(epoch, train_error_value, validation_error_value, test_accuracy_value))

        # Exactly training_epochs epochs of updates; the last iteration only evaluates.
        if epoch < training_epochs:
            for _ in range(total_batch):
                batch_images, batch_labels = mnist.train.next_batch(batch_size)
                sess.run(optimizer, feed_dict={x: batch_images, y_target: batch_labels})

    # Draw Graph about Error Values & Accuracy Values
    draw_error_values_and_accuracy()

    # False Prediction Profile: predicted class = argmax of the logits; the true
    # class is the argmax of the one-hot labels, which needs no TF op.
    prediction = sess.run(tf.argmax(u, 1), feed_dict={x: mnist.test.images})
    ground_truth = np.argmax(mnist.test.labels, 1)

    print(prediction)
    print(ground_truth)

    # Vectorised search for misclassified test examples. extend() mutates the
    # existing module-level list that draw_false_prediction() reads.
    diff_index_list.extend(np.flatnonzero(prediction != ground_truth).tolist())

    print("Number of False Prediction:", len(diff_index_list))
    draw_false_prediction()
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Total batch: 550
Epoch:  0, Train Error: 2.30271, Validation Error: 2.30257, Test Accuracy: 0.09800
Epoch:  1, Train Error: 0.45268, Validation Error: 0.43121, Test Accuracy: 0.88920
Epoch:  2, Train Error: 0.38823, Validation Error: 0.36660, Test Accuracy: 0.90370
Epoch:  3, Train Error: 0.36133, Validation Error: 0.34099, Test Accuracy: 0.90890
Epoch:  4, Train Error: 0.34381, Validation Error: 0.32437, Test Accuracy: 0.91200
Epoch:  5, Train Error: 0.33246, Validation Error: 0.31375, Test Accuracy: 0.91320
Epoch:  6, Train Error: 0.32394, Validation Error: 0.30603, Test Accuracy: 0.91500
Epoch:  7, Train Error: 0.31746, Validation Error: 0.30059, Test Accuracy: 0.91560
Epoch:  8, Train Error: 0.31214, Validation Error: 0.29624, Test Accuracy: 0.91680
Epoch:  9, Train Error: 0.30817, Validation Error: 0.29277, Test Accuracy: 0.91710
Epoch: 10, Train Error: 0.30382, Validation Error: 0.28908, Test Accuracy: 0.91810
Epoch: 11, Train Error: 0.30079, Validation Error: 0.28666, Test Accuracy: 0.91890
Epoch: 12, Train Error: 0.29812, Validation Error: 0.28494, Test Accuracy: 0.91930
Epoch: 13, Train Error: 0.29524, Validation Error: 0.28253, Test Accuracy: 0.91990
Epoch: 14, Train Error: 0.29342, Validation Error: 0.28064, Test Accuracy: 0.92040
Epoch: 15, Train Error: 0.29139, Validation Error: 0.28033, Test Accuracy: 0.91950
Epoch: 16, Train Error: 0.28926, Validation Error: 0.27790, Test Accuracy: 0.92110
Epoch: 17, Train Error: 0.28726, Validation Error: 0.27691, Test Accuracy: 0.92160
Epoch: 18, Train Error: 0.28580, Validation Error: 0.27549, Test Accuracy: 0.92180
Epoch: 19, Train Error: 0.28480, Validation Error: 0.27497, Test Accuracy: 0.92250
Epoch: 20, Train Error: 0.28287, Validation Error: 0.27353, Test Accuracy: 0.92150
Epoch: 21, Train Error: 0.28197, Validation Error: 0.27298, Test Accuracy: 0.92180
Epoch: 22, Train Error: 0.28047, Validation Error: 0.27159, Test Accuracy: 0.92200
Epoch: 23, Train Error: 0.27936, Validation Error: 0.27101, Test Accuracy: 0.92240
Epoch: 24, Train Error: 0.27886, Validation Error: 0.27110, Test Accuracy: 0.92220
Epoch: 25, Train Error: 0.27783, Validation Error: 0.26968, Test Accuracy: 0.92190
Epoch: 26, Train Error: 0.27652, Validation Error: 0.26936, Test Accuracy: 0.92240
Epoch: 27, Train Error: 0.27559, Validation Error: 0.26898, Test Accuracy: 0.92280
Epoch: 28, Train Error: 0.27555, Validation Error: 0.26977, Test Accuracy: 0.92370
Epoch: 29, Train Error: 0.27399, Validation Error: 0.26819, Test Accuracy: 0.92290
Epoch: 30, Train Error: 0.27316, Validation Error: 0.26754, Test Accuracy: 0.92280
Epoch: 31, Train Error: 0.27262, Validation Error: 0.26733, Test Accuracy: 0.92300
Epoch: 32, Train Error: 0.27181, Validation Error: 0.26704, Test Accuracy: 0.92160
Epoch: 33, Train Error: 0.27137, Validation Error: 0.26673, Test Accuracy: 0.92340
Epoch: 34, Train Error: 0.27055, Validation Error: 0.26663, Test Accuracy: 0.92330
Epoch: 35, Train Error: 0.26984, Validation Error: 0.26597, Test Accuracy: 0.92330
Epoch: 36, Train Error: 0.26916, Validation Error: 0.26527, Test Accuracy: 0.92360
Epoch: 37, Train Error: 0.26878, Validation Error: 0.26520, Test Accuracy: 0.92290
Epoch: 38, Train Error: 0.26841, Validation Error: 0.26433, Test Accuracy: 0.92340
Epoch: 39, Train Error: 0.26784, Validation Error: 0.26436, Test Accuracy: 0.92350
Epoch: 40, Train Error: 0.26712, Validation Error: 0.26463, Test Accuracy: 0.92360
Epoch: 41, Train Error: 0.26702, Validation Error: 0.26505, Test Accuracy: 0.92390
Epoch: 42, Train Error: 0.26610, Validation Error: 0.26391, Test Accuracy: 0.92260
Epoch: 43, Train Error: 0.26556, Validation Error: 0.26336, Test Accuracy: 0.92330
Epoch: 44, Train Error: 0.26537, Validation Error: 0.26327, Test Accuracy: 0.92340
Epoch: 45, Train Error: 0.26465, Validation Error: 0.26280, Test Accuracy: 0.92330
Epoch: 46, Train Error: 0.26447, Validation Error: 0.26348, Test Accuracy: 0.92480
Epoch: 47, Train Error: 0.26404, Validation Error: 0.26349, Test Accuracy: 0.92390
Epoch: 48, Train Error: 0.26354, Validation Error: 0.26289, Test Accuracy: 0.92330
Epoch: 49, Train Error: 0.26311, Validation Error: 0.26254, Test Accuracy: 0.92360
[7 2 1 ..., 4 5 6]
[7 2 1 ..., 4 5 6]
Number of False Prediction: 759
False Prediction Index: 8, Prediction: 6, Ground Truth: 5
False Prediction Index: 33, Prediction: 6, Ground Truth: 4
False Prediction Index: 63, Prediction: 2, Ground Truth: 3
False Prediction Index: 124, Prediction: 4, Ground Truth: 7
False Prediction Index: 149, Prediction: 9, Ground Truth: 2
In [ ]: