In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

Load necessary libraries

In [10]:
# Synthetic dataset: a cubic y = x^3 - 14x^2 + 59x - 70 with a sinusoidal
# component and unit-variance Gaussian noise, sampled at 100 points on [1, 8].
data_x = np.linspace(1, 8, 100).reshape(-1, 1)
true_coeffs = [1, -14, 59, -70]  # highest power first, as np.polyval expects
data_y = (np.polyval(true_coeffs, data_x)
          + 1.5 * np.sin(data_x)
          + np.random.randn(100, 1))

Generate our data

In [3]:
model_order = 5
# Expand the single column x into polynomial features [x^0, x^1, ..., x^4]
# by broadcasting against the exponent row, then scale each column so its
# maximum is 1 (keeps gradient descent well-conditioned).
# NOTE(review): this cell is not idempotent — re-running it re-applies the
# expansion to the already-expanded matrix.
exponents = np.arange(model_order)
data_x = data_x ** exponents
data_x = data_x / data_x.max(axis=0)

Add intercept data and normalize

In [4]:
# Shuffle the row indices and hold out the first `portion` rows as the test
# set; everything else becomes the training set.
portion = 20
shuffled = np.random.permutation(len(data_x))
test_idx, train_idx = shuffled[:portion], shuffled[portion:]
test_x, test_y = data_x[test_idx], data_y[test_idx]
train_x, train_y = data_x[train_idx], data_y[train_idx]

Shuffle data and produce train and test sets

In [5]:
def init_param(shape):
    """Return a float32 tensor of zeros for initialising a weight of `shape`."""
    return tf.zeros(shape, dtype=tf.float32)

with tf.name_scope("IO"):
    # Graph inputs: polynomial feature matrix and its targets (batch size open).
    inputs = tf.placeholder(tf.float32, [None, model_order], name="X")
    outputs = tf.placeholder(tf.float32, [None, 1], name="Yhat")

with tf.name_scope("LR"):
    # Linear regression over the polynomial features: y = X W.
    W = tf.Variable(init_param([model_order, 1]), name="W")
    y = tf.matmul(inputs, W)

with tf.name_scope("train"):
    # Non-trainable variable so the rate can be reassigned at session time.
    learning_rate = tf.Variable(0.5, trainable=False)
    cost_op = tf.reduce_mean(tf.pow(y - outputs, 2))  # mean squared error
    train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost_op)

Create TensorFlow graph

In [6]:
# Full-batch gradient descent on the polynomial regression graph, stopping
# when the 100-epoch change in training MSE drops below `tolerance` or the
# epoch budget is exhausted.  (Despite the comment below, every step feeds
# the whole training set, so this is batch — not stochastic — descent.)
tolerance = 1e-3

# Perform Stochastic Gradient Descent
epochs = 1
last_cost = 0      # cost from the previous 100-epoch checkpoint
alpha = 0.4        # learning rate actually used (overrides the graph's 0.5)
max_epochs = 50000

sess = tf.Session()
print "Beginning Training"
with sess.as_default():
    # NOTE(review): tf.initialize_all_variables and tf.train.SummaryWriter are
    # pre-1.0 TensorFlow APIs (later tf.global_variables_initializer /
    # tf.summary.FileWriter) — confirm the installed TF version supports them.
    init = tf.initialize_all_variables()
    sess.run(init)
    sess.run(tf.assign(learning_rate, alpha))
    writer = tf.train.SummaryWriter("/tmp/tboard", sess.graph) # Create TensorBoard files
    while True:

        # One gradient step over the entire training set.
        sess.run(train_op, feed_dict={inputs: train_x, outputs: train_y})
            
        # Keep track of our performance
        if epochs%100==0:
            cost = sess.run(cost_op, feed_dict={inputs: train_x, outputs: train_y})
            print "Epoch: %d - Error: %.4f" %(epochs, cost)

            # Stopping Condition
            # (checked only at 100-epoch checkpoints; first check compares
            # against the initial last_cost of 0)
            if abs(last_cost - cost) < tolerance or epochs > max_epochs:
                print "Converged."
                break
            last_cost = cost
            
        epochs += 1
    
    # Pull the learned weights out of the graph and report held-out error.
    w = W.eval()
    print "w =", w
    print "Test Cost =", sess.run(cost_op, feed_dict={inputs: test_x, outputs: test_y})
Beginning Training
Epoch: 100 - Error: 46.5217
Epoch: 200 - Error: 45.4410
Epoch: 300 - Error: 44.4665
Epoch: 400 - Error: 43.5364
Epoch: 500 - Error: 42.6334
Epoch: 600 - Error: 41.7520
Epoch: 700 - Error: 40.8904
Epoch: 800 - Error: 40.0480
Epoch: 900 - Error: 39.2242
Epoch: 1000 - Error: 38.4185
Epoch: 1100 - Error: 37.6305
Epoch: 1200 - Error: 36.8598
Epoch: 1300 - Error: 36.1061
Epoch: 1400 - Error: 35.3690
Epoch: 1500 - Error: 34.6481
Epoch: 1600 - Error: 33.9430
Epoch: 1700 - Error: 33.2535
Epoch: 1800 - Error: 32.5791
Epoch: 1900 - Error: 31.9195
Epoch: 2000 - Error: 31.2745
Epoch: 2100 - Error: 30.6436
Epoch: 2200 - Error: 30.0266
Epoch: 2300 - Error: 29.4232
Epoch: 2400 - Error: 28.8331
Epoch: 2500 - Error: 28.2559
Epoch: 2600 - Error: 27.6915
Epoch: 2700 - Error: 27.1394
Epoch: 2800 - Error: 26.5995
Epoch: 2900 - Error: 26.0715
Epoch: 3000 - Error: 25.5550
Epoch: 3100 - Error: 25.0500
Epoch: 3200 - Error: 24.5560
Epoch: 3300 - Error: 24.0729
Epoch: 3400 - Error: 23.6004
Epoch: 3500 - Error: 23.1383
Epoch: 3600 - Error: 22.6864
Epoch: 3700 - Error: 22.2444
Epoch: 3800 - Error: 21.8121
Epoch: 3900 - Error: 21.3894
Epoch: 4000 - Error: 20.9759
Epoch: 4100 - Error: 20.5715
Epoch: 4200 - Error: 20.1761
Epoch: 4300 - Error: 19.7893
Epoch: 4400 - Error: 19.4110
Epoch: 4500 - Error: 19.0410
Epoch: 4600 - Error: 18.6792
Epoch: 4700 - Error: 18.3253
Epoch: 4800 - Error: 17.9792
Epoch: 4900 - Error: 17.6408
Epoch: 5000 - Error: 17.3097
Epoch: 5100 - Error: 16.9860
Epoch: 5200 - Error: 16.6693
Epoch: 5300 - Error: 16.3597
Epoch: 5400 - Error: 16.0568
Epoch: 5500 - Error: 15.7606
Epoch: 5600 - Error: 15.4709
Epoch: 5700 - Error: 15.1875
Epoch: 5800 - Error: 14.9104
Epoch: 5900 - Error: 14.6394
Epoch: 6000 - Error: 14.3743
Epoch: 6100 - Error: 14.1151
Epoch: 6200 - Error: 13.8616
Epoch: 6300 - Error: 13.6136
Epoch: 6400 - Error: 13.3711
Epoch: 6500 - Error: 13.1339
Epoch: 6600 - Error: 12.9020
Epoch: 6700 - Error: 12.6751
Epoch: 6800 - Error: 12.4532
Epoch: 6900 - Error: 12.2362
Epoch: 7000 - Error: 12.0240
Epoch: 7100 - Error: 11.8164
Epoch: 7200 - Error: 11.6134
Epoch: 7300 - Error: 11.4149
Epoch: 7400 - Error: 11.2207
Epoch: 7500 - Error: 11.0308
Epoch: 7600 - Error: 10.8450
Epoch: 7700 - Error: 10.6634
Epoch: 7800 - Error: 10.4857
Epoch: 7900 - Error: 10.3119
Epoch: 8000 - Error: 10.1420
Epoch: 8100 - Error: 9.9758
Epoch: 8200 - Error: 9.8132
Epoch: 8300 - Error: 9.6542
Epoch: 8400 - Error: 9.4987
Epoch: 8500 - Error: 9.3466
Epoch: 8600 - Error: 9.1979
Epoch: 8700 - Error: 9.0524
Epoch: 8800 - Error: 8.9102
Epoch: 8900 - Error: 8.7710
Epoch: 9000 - Error: 8.6349
Epoch: 9100 - Error: 8.5018
Epoch: 9200 - Error: 8.3716
Epoch: 9300 - Error: 8.2443
Epoch: 9400 - Error: 8.1198
Epoch: 9500 - Error: 7.9980
Epoch: 9600 - Error: 7.8789
Epoch: 9700 - Error: 7.7624
Epoch: 9800 - Error: 7.6484
Epoch: 9900 - Error: 7.5370
Epoch: 10000 - Error: 7.4280
Epoch: 10100 - Error: 7.3214
Epoch: 10200 - Error: 7.2172
Epoch: 10300 - Error: 7.1152
Epoch: 10400 - Error: 7.0155
Epoch: 10500 - Error: 6.9179
Epoch: 10600 - Error: 6.8225
Epoch: 10700 - Error: 6.7292
Epoch: 10800 - Error: 6.6380
Epoch: 10900 - Error: 6.5487
Epoch: 11000 - Error: 6.4614
Epoch: 11100 - Error: 6.3760
Epoch: 11200 - Error: 6.2925
Epoch: 11300 - Error: 6.2108
Epoch: 11400 - Error: 6.1310
Epoch: 11500 - Error: 6.0528
Epoch: 11600 - Error: 5.9764
Epoch: 11700 - Error: 5.9017
Epoch: 11800 - Error: 5.8286
Epoch: 11900 - Error: 5.7571
Epoch: 12000 - Error: 5.6871
Epoch: 12100 - Error: 5.6188
Epoch: 12200 - Error: 5.5519
Epoch: 12300 - Error: 5.4864
Epoch: 12400 - Error: 5.4224
Epoch: 12500 - Error: 5.3599
Epoch: 12600 - Error: 5.2986
Epoch: 12700 - Error: 5.2388
Epoch: 12800 - Error: 5.1802
Epoch: 12900 - Error: 5.1229
Epoch: 13000 - Error: 5.0669
Epoch: 13100 - Error: 5.0121
Epoch: 13200 - Error: 4.9585
Epoch: 13300 - Error: 4.9061
Epoch: 13400 - Error: 4.8548
Epoch: 13500 - Error: 4.8047
Epoch: 13600 - Error: 4.7556
Epoch: 13700 - Error: 4.7076
Epoch: 13800 - Error: 4.6607
Epoch: 13900 - Error: 4.6148
Epoch: 14000 - Error: 4.5699
Epoch: 14100 - Error: 4.5260
Epoch: 14200 - Error: 4.4831
Epoch: 14300 - Error: 4.4411
Epoch: 14400 - Error: 4.4000
Epoch: 14500 - Error: 4.3598
Epoch: 14600 - Error: 4.3205
Epoch: 14700 - Error: 4.2820
Epoch: 14800 - Error: 4.2444
Epoch: 14900 - Error: 4.2076
Epoch: 15000 - Error: 4.1716
Epoch: 15100 - Error: 4.1364
Epoch: 15200 - Error: 4.1020
Epoch: 15300 - Error: 4.0683
Epoch: 15400 - Error: 4.0354
Epoch: 15500 - Error: 4.0032
Epoch: 15600 - Error: 3.9717
Epoch: 15700 - Error: 3.9408
Epoch: 15800 - Error: 3.9107
Epoch: 15900 - Error: 3.8812
Epoch: 16000 - Error: 3.8523
Epoch: 16100 - Error: 3.8241
Epoch: 16200 - Error: 3.7965
Epoch: 16300 - Error: 3.7695
Epoch: 16400 - Error: 3.7431
Epoch: 16500 - Error: 3.7173
Epoch: 16600 - Error: 3.6920
Epoch: 16700 - Error: 3.6673
Epoch: 16800 - Error: 3.6431
Epoch: 16900 - Error: 3.6195
Epoch: 17000 - Error: 3.5963
Epoch: 17100 - Error: 3.5737
Epoch: 17200 - Error: 3.5515
Epoch: 17300 - Error: 3.5299
Epoch: 17400 - Error: 3.5087
Epoch: 17500 - Error: 3.4880
Epoch: 17600 - Error: 3.4677
Epoch: 17700 - Error: 3.4479
Epoch: 17800 - Error: 3.4285
Epoch: 17900 - Error: 3.4095
Epoch: 18000 - Error: 3.3909
Epoch: 18100 - Error: 3.3728
Epoch: 18200 - Error: 3.3550
Epoch: 18300 - Error: 3.3376
Epoch: 18400 - Error: 3.3206
Epoch: 18500 - Error: 3.3040
Epoch: 18600 - Error: 3.2877
Epoch: 18700 - Error: 3.2718
Epoch: 18800 - Error: 3.2562
Epoch: 18900 - Error: 3.2410
Epoch: 19000 - Error: 3.2261
Epoch: 19100 - Error: 3.2115
Epoch: 19200 - Error: 3.1972
Epoch: 19300 - Error: 3.1832
Epoch: 19400 - Error: 3.1696
Epoch: 19500 - Error: 3.1562
Epoch: 19600 - Error: 3.1432
Epoch: 19700 - Error: 3.1304
Epoch: 19800 - Error: 3.1179
Epoch: 19900 - Error: 3.1056
Epoch: 20000 - Error: 3.0936
Epoch: 20100 - Error: 3.0819
Epoch: 20200 - Error: 3.0704
Epoch: 20300 - Error: 3.0592
Epoch: 20400 - Error: 3.0483
Epoch: 20500 - Error: 3.0375
Epoch: 20600 - Error: 3.0270
Epoch: 20700 - Error: 3.0167
Epoch: 20800 - Error: 3.0067
Epoch: 20900 - Error: 2.9968
Epoch: 21000 - Error: 2.9872
Epoch: 21100 - Error: 2.9778
Epoch: 21200 - Error: 2.9685
Epoch: 21300 - Error: 2.9595
Epoch: 21400 - Error: 2.9507
Epoch: 21500 - Error: 2.9421
Epoch: 21600 - Error: 2.9336
Epoch: 21700 - Error: 2.9253
Epoch: 21800 - Error: 2.9172
Epoch: 21900 - Error: 2.9093
Epoch: 22000 - Error: 2.9015
Epoch: 22100 - Error: 2.8940
Epoch: 22200 - Error: 2.8865
Epoch: 22300 - Error: 2.8793
Epoch: 22400 - Error: 2.8721
Epoch: 22500 - Error: 2.8652
Epoch: 22600 - Error: 2.8584
Epoch: 22700 - Error: 2.8517
Epoch: 22800 - Error: 2.8452
Epoch: 22900 - Error: 2.8388
Epoch: 23000 - Error: 2.8325
Epoch: 23100 - Error: 2.8264
Epoch: 23200 - Error: 2.8204
Epoch: 23300 - Error: 2.8146
Epoch: 23400 - Error: 2.8088
Epoch: 23500 - Error: 2.8032
Epoch: 23600 - Error: 2.7977
Epoch: 23700 - Error: 2.7924
Epoch: 23800 - Error: 2.7871
Epoch: 23900 - Error: 2.7819
Epoch: 24000 - Error: 2.7769
Epoch: 24100 - Error: 2.7719
Epoch: 24200 - Error: 2.7671
Epoch: 24300 - Error: 2.7624
Epoch: 24400 - Error: 2.7577
Epoch: 24500 - Error: 2.7532
Epoch: 24600 - Error: 2.7487
Epoch: 24700 - Error: 2.7444
Epoch: 24800 - Error: 2.7401
Epoch: 24900 - Error: 2.7360
Epoch: 25000 - Error: 2.7319
Epoch: 25100 - Error: 2.7279
Epoch: 25200 - Error: 2.7240
Epoch: 25300 - Error: 2.7201
Epoch: 25400 - Error: 2.7164
Epoch: 25500 - Error: 2.7127
Epoch: 25600 - Error: 2.7091
Epoch: 25700 - Error: 2.7056
Epoch: 25800 - Error: 2.7021
Epoch: 25900 - Error: 2.6987
Epoch: 26000 - Error: 2.6954
Epoch: 26100 - Error: 2.6922
Epoch: 26200 - Error: 2.6890
Epoch: 26300 - Error: 2.6859
Epoch: 26400 - Error: 2.6828
Epoch: 26500 - Error: 2.6799
Epoch: 26600 - Error: 2.6769
Epoch: 26700 - Error: 2.6741
Epoch: 26800 - Error: 2.6713
Epoch: 26900 - Error: 2.6685
Epoch: 27000 - Error: 2.6658
Epoch: 27100 - Error: 2.6632
Epoch: 27200 - Error: 2.6606
Epoch: 27300 - Error: 2.6581
Epoch: 27400 - Error: 2.6556
Epoch: 27500 - Error: 2.6531
Epoch: 27600 - Error: 2.6508
Epoch: 27700 - Error: 2.6484
Epoch: 27800 - Error: 2.6461
Epoch: 27900 - Error: 2.6439
Epoch: 28000 - Error: 2.6417
Epoch: 28100 - Error: 2.6395
Epoch: 28200 - Error: 2.6374
Epoch: 28300 - Error: 2.6353
Epoch: 28400 - Error: 2.6333
Epoch: 28500 - Error: 2.6313
Epoch: 28600 - Error: 2.6294
Epoch: 28700 - Error: 2.6274
Epoch: 28800 - Error: 2.6256
Epoch: 28900 - Error: 2.6237
Epoch: 29000 - Error: 2.6219
Epoch: 29100 - Error: 2.6202
Epoch: 29200 - Error: 2.6184
Epoch: 29300 - Error: 2.6167
Epoch: 29400 - Error: 2.6150
Epoch: 29500 - Error: 2.6134
Epoch: 29600 - Error: 2.6118
Epoch: 29700 - Error: 2.6102
Epoch: 29800 - Error: 2.6087
Epoch: 29900 - Error: 2.6072
Epoch: 30000 - Error: 2.6057
Epoch: 30100 - Error: 2.6042
Epoch: 30200 - Error: 2.6028
Epoch: 30300 - Error: 2.6014
Epoch: 30400 - Error: 2.6000
Epoch: 30500 - Error: 2.5987
Epoch: 30600 - Error: 2.5973
Epoch: 30700 - Error: 2.5960
Epoch: 30800 - Error: 2.5947
Epoch: 30900 - Error: 2.5935
Epoch: 31000 - Error: 2.5923
Epoch: 31100 - Error: 2.5910
Epoch: 31200 - Error: 2.5899
Epoch: 31300 - Error: 2.5887
Epoch: 31400 - Error: 2.5875
Epoch: 31500 - Error: 2.5864
Epoch: 31600 - Error: 2.5853
Epoch: 31700 - Error: 2.5842
Epoch: 31800 - Error: 2.5832
Epoch: 31900 - Error: 2.5821
Epoch: 32000 - Error: 2.5811
Epoch: 32100 - Error: 2.5801
Epoch: 32200 - Error: 2.5791
Converged.
w = [[ -47.71784592]
 [ 279.96063232]
 [-335.62420654]
 [-165.66163635]
 [ 290.33401489]]
Test Cost = 2.91072

Perform gradient descent to learn model

In [7]:
# Plot the fit in normalized-x space.  Because feature column k was divided by
# its maximum 8**k, the normalized feature equals (x/8)**k, so evaluating a
# plain polynomial with coefficients w (reversed to highest-power-first for
# np.polyval) over t in [0, 1] reproduces the model's predictions exactly.
t = np.linspace(0, 1, 200)
plt.plot(t, np.polyval(w[::-1], t), c='g', label='Model')
plt.scatter(train_x[:, 1], train_y, c='b', label='Train Set')
plt.scatter(test_x[:, 1], test_y, c='r', label='Test Set')
plt.grid()
plt.legend(loc='upper left')
plt.xlabel('X')
plt.ylabel('Y')
plt.xlim(0, 1)
plt.show()

Plot the model obtained