MNIST-Neural Network-Two Hidden Layers with Tensorflow - All-in-One

In [13]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
import math
import tensorflow as tf

mnist = input_data.read_data_sets("../MNIST_data/", one_hot=True)

batch_size = 100
training_epochs = 100
learning_rate = 0.05

epoch_list = []
train_error_list = []
validation_error_list = []
test_accuracy_list = []
diff_index_list = []

# Network Parameters
n_input = 784 # MNIST data input (img shape: 28*28)
n_hidden_1 = 128 # 1st layer number of features
n_hidden_2 = 128 # 2nd layer number of features
n_classes = 10 # MNIST total classes (0-9 digits)
    
# Data Preparation
x = tf.placeholder(tf.float32, [None, n_input])
y_target = tf.placeholder(tf.float32, [None, n_classes])

# Model Construction
# Store layers weight & bias
weights = {
    'W1': tf.Variable(tf.random_normal([n_input, n_hidden_1])),
    'W2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),
    'out': tf.Variable(tf.random_normal([n_hidden_2, n_classes]))
}

biases = {
    'b1': tf.Variable(tf.random_normal([n_hidden_1])),
    'b2': tf.Variable(tf.random_normal([n_hidden_2])),
    'out': tf.Variable(tf.random_normal([n_classes]))
}

# 1st Hidden layer with RELU activation
u2 = tf.matmul(x, weights['W1']) + biases['b1']
z2 = tf.nn.relu(u2)

# 2ndHidden layer with RELU activation
u3 = tf.matmul(z2, weights['W2']) + biases['b2']
z3 = tf.nn.relu(u3)

# Output layer with linear activation
u_out = tf.matmul(z3, weights['out']) + biases['out']

# Target Setup
error = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=u_out, labels=y_target))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(error)

# Accuracy   
prediction_and_ground_truth = tf.equal(tf.argmax(u_out, 1), tf.argmax(y_target, 1))
accuracy = tf.reduce_mean(tf.cast(prediction_and_ground_truth, tf.float32))

def draw_error_values_and_accuracy():
    # Draw Error Values and Accuracy
    fig = plt.figure(figsize=(20, 5))
    plt.subplot(121)
    plt.plot(epoch_list[1:], train_error_list[1:], 'r', label='Train')
    plt.plot(epoch_list[1:], validation_error_list[1:], 'g', label='Validation')
    plt.ylabel('Total Error')
    plt.xlabel('Epochs')
    plt.grid(True)
    plt.legend(loc='upper right')

    plt.subplot(122)
    plt.plot(epoch_list[1:], test_accuracy_list[1:], 'b', label='Test')
    plt.ylabel('Accuracy')
    plt.xlabel('Epochs')
    plt.yticks(np.arange(0.0, 1.0, 0.05))
    plt.grid(True)
    plt.legend(loc='lower right')
    plt.show()
    
def draw_false_prediction():
    fig = plt.figure(figsize=(20, 5))
    for i in range(5):
        j = diff_index_list[i]
        print("False Prediction Index: %s, Prediction: %s, Ground Truth: %s" % (j, prediction[j], ground_truth[j]))
        img = np.array(mnist.test.images[j])
        img.shape = (28, 28)
        plt.subplot(150 + (i+1))
        plt.imshow(img, cmap='gray')
    
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    
    total_batch = int(math.ceil(mnist.train.num_examples/float(batch_size)))
    print("Total batch: %d" % total_batch)    

    for epoch in range(training_epochs):
        epoch_list.append(epoch)
        # Train Error Value
        train_error_value = sess.run(error, feed_dict={x: mnist.train.images, y_target: mnist.train.labels})
        train_error_list.append(train_error_value)
        
        validation_error_value = sess.run(error, feed_dict={x: mnist.validation.images, y_target: mnist.validation.labels})
        validation_error_list.append(validation_error_value)
        
        test_accuracy_value = sess.run(accuracy, feed_dict={x: mnist.test.images, y_target: mnist.test.labels})
        test_accuracy_list.append(test_accuracy_value) 
        print("Epoch: {0:2d}, Train Error: {1:0.5f}, Validation Error: {2:0.5f}, Test Accuracy: {3:0.5f}".format(epoch, train_error_value, validation_error_value, test_accuracy_value))
        
        for i in range(total_batch):
            batch_images, batch_labels = mnist.train.next_batch(batch_size)
            sess.run(optimizer, feed_dict={x: batch_images, y_target: batch_labels})
    

    # Draw Graph about Error Values & Accuracy Values
    draw_error_values_and_accuracy()
    
    # False Prediction Profile
    prediction = sess.run(tf.argmax(u_out, 1), feed_dict={x:mnist.test.images})
    ground_truth = sess.run(tf.argmax(y_target, 1), feed_dict={y_target:mnist.test.labels})

    print(prediction)
    print(ground_truth)

    for i in range(mnist.test.num_examples):
        if (prediction[i] != ground_truth[i]):
            diff_index_list.append(i)
            
    print("Number of False Prediction:", len(diff_index_list))
    draw_false_prediction()
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Total batch: 550
Epoch:  0, Train Error: 812.29523, Validation Error: 808.12689, Test Accuracy: 0.07790
Epoch:  1, Train Error: 1.79483, Validation Error: 1.84282, Test Accuracy: 0.50580
Epoch:  2, Train Error: 1.25159, Validation Error: 1.29967, Test Accuracy: 0.59770
Epoch:  3, Train Error: 1.05341, Validation Error: 1.12201, Test Accuracy: 0.65020
Epoch:  4, Train Error: 0.94076, Validation Error: 1.02594, Test Accuracy: 0.71630
Epoch:  5, Train Error: 0.82792, Validation Error: 0.89349, Test Accuracy: 0.74610
Epoch:  6, Train Error: 0.79810, Validation Error: 0.86793, Test Accuracy: 0.75910
Epoch:  7, Train Error: 0.84932, Validation Error: 0.91554, Test Accuracy: 0.75460
Epoch:  8, Train Error: 0.66226, Validation Error: 0.73082, Test Accuracy: 0.81400
Epoch:  9, Train Error: 0.61916, Validation Error: 0.69496, Test Accuracy: 0.81790
Epoch: 10, Train Error: 0.57571, Validation Error: 0.65443, Test Accuracy: 0.83680
Epoch: 11, Train Error: 0.59305, Validation Error: 0.67415, Test Accuracy: 0.83550
Epoch: 12, Train Error: 0.53465, Validation Error: 0.61159, Test Accuracy: 0.84810
Epoch: 13, Train Error: 0.50453, Validation Error: 0.57334, Test Accuracy: 0.85340
Epoch: 14, Train Error: 0.52441, Validation Error: 0.60729, Test Accuracy: 0.84840
Epoch: 15, Train Error: 0.46773, Validation Error: 0.54659, Test Accuracy: 0.86400
Epoch: 16, Train Error: 0.46526, Validation Error: 0.53466, Test Accuracy: 0.86760
Epoch: 17, Train Error: 0.48935, Validation Error: 0.57058, Test Accuracy: 0.86380
Epoch: 18, Train Error: 0.48435, Validation Error: 0.55283, Test Accuracy: 0.85410
Epoch: 19, Train Error: 0.41527, Validation Error: 0.49609, Test Accuracy: 0.87720
Epoch: 20, Train Error: 0.40882, Validation Error: 0.48294, Test Accuracy: 0.87770
Epoch: 21, Train Error: 0.42988, Validation Error: 0.51087, Test Accuracy: 0.87480
Epoch: 22, Train Error: 0.39772, Validation Error: 0.46901, Test Accuracy: 0.87940
Epoch: 23, Train Error: 0.41309, Validation Error: 0.49045, Test Accuracy: 0.87580
Epoch: 24, Train Error: 0.37628, Validation Error: 0.44864, Test Accuracy: 0.88630
Epoch: 25, Train Error: 0.39727, Validation Error: 0.46964, Test Accuracy: 0.87960
Epoch: 26, Train Error: 0.38017, Validation Error: 0.45607, Test Accuracy: 0.88850
Epoch: 27, Train Error: 0.38129, Validation Error: 0.45064, Test Accuracy: 0.88420
Epoch: 28, Train Error: 0.35800, Validation Error: 0.43231, Test Accuracy: 0.89520
Epoch: 29, Train Error: 0.36063, Validation Error: 0.43067, Test Accuracy: 0.89040
Epoch: 30, Train Error: 0.35008, Validation Error: 0.44063, Test Accuracy: 0.89580
Epoch: 31, Train Error: 0.32868, Validation Error: 0.41023, Test Accuracy: 0.89680
Epoch: 32, Train Error: 0.34583, Validation Error: 0.41745, Test Accuracy: 0.89250
Epoch: 33, Train Error: 0.34128, Validation Error: 0.42384, Test Accuracy: 0.89550
Epoch: 34, Train Error: 0.34642, Validation Error: 0.43611, Test Accuracy: 0.89150
Epoch: 35, Train Error: 0.31529, Validation Error: 0.39924, Test Accuracy: 0.90340
Epoch: 36, Train Error: 0.32012, Validation Error: 0.40372, Test Accuracy: 0.90070
Epoch: 37, Train Error: 0.31602, Validation Error: 0.40089, Test Accuracy: 0.90020
Epoch: 38, Train Error: 0.29867, Validation Error: 0.38319, Test Accuracy: 0.90540
Epoch: 39, Train Error: 0.31006, Validation Error: 0.39565, Test Accuracy: 0.90120
Epoch: 40, Train Error: 0.31541, Validation Error: 0.40505, Test Accuracy: 0.89980
Epoch: 41, Train Error: 0.29779, Validation Error: 0.38346, Test Accuracy: 0.90540
Epoch: 42, Train Error: 0.28780, Validation Error: 0.37945, Test Accuracy: 0.90760
Epoch: 43, Train Error: 0.28159, Validation Error: 0.36461, Test Accuracy: 0.90790
Epoch: 44, Train Error: 0.30380, Validation Error: 0.39420, Test Accuracy: 0.90660
Epoch: 45, Train Error: 0.30246, Validation Error: 0.39710, Test Accuracy: 0.90310
Epoch: 46, Train Error: 0.28050, Validation Error: 0.36612, Test Accuracy: 0.90940
Epoch: 47, Train Error: 0.27489, Validation Error: 0.36211, Test Accuracy: 0.90990
Epoch: 48, Train Error: 0.27424, Validation Error: 0.36144, Test Accuracy: 0.90920
Epoch: 49, Train Error: 0.27405, Validation Error: 0.36339, Test Accuracy: 0.91060
[7 2 1 ..., 4 5 6]
[7 2 1 ..., 4 5 6]
Number of False Prediction: 911
False Prediction Index: 8, Prediction: 2, Ground Truth: 5
False Prediction Index: 33, Prediction: 6, Ground Truth: 4
False Prediction Index: 38, Prediction: 3, Ground Truth: 2
False Prediction Index: 61, Prediction: 2, Ground Truth: 8
False Prediction Index: 62, Prediction: 4, Ground Truth: 9
In [ ]: