import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
# Define model and Loss
class Model(object):
    """Linear model with one scalar weight and one scalar bias.

    Calling an instance evaluates ``W * inputs + b``.

    Attributes:
        W: ``tf.Variable`` holding the weight, initialised to 10.0.
        b: ``tf.Variable`` holding the bias, initialised to -5.0.
    """

    def __init__(self):
        # Variables (rather than constants) so they can be updated in place
        # with assign_sub during training.
        self.W = tf.Variable(10.0)
        self.b = tf.Variable(-5.0)

    def __call__(self, inputs):
        """Return the model's prediction ``W * inputs + b`` for ``inputs``."""
        return self.W * inputs + self.b
def compute_loss(y_true, y_pred):
    """Return the mean squared error between ``y_true`` and ``y_pred``.

    Args:
        y_true: Target values.
        y_pred: Predicted values; must be subtractable from ``y_true``.

    Returns:
        The mean of the element-wise squared differences, as computed by
        ``tf.reduce_mean``.
    """
    return tf.reduce_mean(tf.square(y_true - y_pred))
# Ground-truth weight and bias used to generate the data.
TRUE_W = 3.0
TRUE_b = 2.0

# Synthetic training set: points on the line TRUE_W * x + TRUE_b, perturbed
# by normally distributed noise.
NUM_EXAMPLES = 1000
inputs = tf.random.normal(shape=[NUM_EXAMPLES])
noise = tf.random.normal(shape=[NUM_EXAMPLES])
outputs = TRUE_W * inputs + TRUE_b + noise

# The model instance that the rest of the script trains and plots.
model = Model()
# Before we train the model let's visualize where the model stands right now.
# We'll plot the model's predictions in red and the training data in blue.
def plot(epoch):
    """Show the training data and the model's current predictions.

    Draws the training data in blue and the predictions in red, puts the
    epoch number and current loss in the title, keeps the figure up for
    one second via ``plt.pause(1)``, then closes it.

    Args:
        epoch: Epoch number shown in the figure title.
    """
    # Evaluate the model once and reuse the result for the plot and the loss.
    predictions = model(inputs)
    loss = compute_loss(outputs, predictions)
    # Label both series so plt.legend() has entries to display; without
    # labels matplotlib warns "No handles with labels found to put in legend".
    plt.scatter(inputs, outputs, c='b', label='training data')
    plt.scatter(inputs, predictions, c='r', label='model predictions')
    plt.title("epoch %2d, loss = %s" % (epoch, str(loss.numpy())))
    plt.legend()
    plt.draw()
    plt.ion()
    plt.pause(1)
    plt.close()
# Define a training loop
learning_rate = 0.1
for epoch in range(30):
    # Run the forward pass under the tape so the loss can be differentiated
    # with respect to the model's variables.
    with tf.GradientTape() as tape:
        loss = compute_loss(outputs, model(inputs))
    dW, db = tape.gradient(loss, [model.W, model.b])
    # Gradient-descent step: move each variable against its gradient.
    model.W.assign_sub(learning_rate * dW)
    model.b.assign_sub(learning_rate * db)
    # NOTE: `loss` was computed before this epoch's update, while W and b
    # printed below are the values after it.
    print("=> epoch %2d: w_true= %.2f, w_pred= %.2f; b_true= %.2f, b_pred= %.2f, loss= %.2f" % (
        epoch + 1, TRUE_W, model.W.numpy(), TRUE_b, model.b.numpy(), loss.numpy()))
    # Plot on the 1st, 11th and 21st epochs.
    if epoch % 10 == 0:
        plot(epoch + 1)
# Sample output from a run of this script:
#
# WARNING: Logging before flag parsing goes to stderr.
# W0308 18:57:55.500271 4321252224 legend.py:1289] No handles with labels found to put in legend.
# => epoch 1: w_true= 3.00, w_pred= 8.67; b_true= 2.00, b_pred= -3.62, loss= 96.15
# W0308 18:57:56.937305 4321252224 legend.py:1289] No handles with labels found to put in legend.
# => epoch 2: w_true= 3.00, w_pred= 7.59; b_true= 2.00, b_pred= -2.51, loss= 62.87
# => epoch 3: w_true= 3.00, w_pred= 6.71; b_true= 2.00, b_pred= -1.62, loss= 41.23
# => epoch 4: w_true= 3.00, w_pred= 6.00; b_true= 2.00, b_pred= -0.91, loss= 27.16
# => epoch 5: w_true= 3.00, w_pred= 5.43; b_true= 2.00, b_pred= -0.33, loss= 18.02
# => epoch 6: w_true= 3.00, w_pred= 4.96; b_true= 2.00, b_pred= 0.13, loss= 12.07
# => epoch 7: w_true= 3.00, w_pred= 4.58; b_true= 2.00, b_pred= 0.49, loss= 8.20
# => epoch 8: w_true= 3.00, w_pred= 4.28; b_true= 2.00, b_pred= 0.79, loss= 5.68
# => epoch 9: w_true= 3.00, w_pred= 4.03; b_true= 2.00, b_pred= 1.03, loss= 4.05
# => epoch 10: w_true= 3.00, w_pred= 3.83; b_true= 2.00, b_pred= 1.22, loss= 2.98
# => epoch 11: w_true= 3.00, w_pred= 3.66; b_true= 2.00, b_pred= 1.37, loss= 2.29
# W0308 18:57:58.325464 4321252224 legend.py:1289] No handles with labels found to put in legend.
# => epoch 12: w_true= 3.00, w_pred= 3.53; b_true= 2.00, b_pred= 1.49, loss= 1.84
# => epoch 13: w_true= 3.00, w_pred= 3.42; b_true= 2.00, b_pred= 1.59, loss= 1.55
# => epoch 14: w_true= 3.00, w_pred= 3.34; b_true= 2.00, b_pred= 1.67, loss= 1.36
# => epoch 15: w_true= 3.00, w_pred= 3.27; b_true= 2.00, b_pred= 1.73, loss= 1.23
# => epoch 16: w_true= 3.00, w_pred= 3.21; b_true= 2.00, b_pred= 1.78, loss= 1.15
# => epoch 17: w_true= 3.00, w_pred= 3.16; b_true= 2.00, b_pred= 1.82, loss= 1.10
# => epoch 18: w_true= 3.00, w_pred= 3.12; b_true= 2.00, b_pred= 1.85, loss= 1.06
# => epoch 19: w_true= 3.00, w_pred= 3.09; b_true= 2.00, b_pred= 1.88, loss= 1.04
# => epoch 20: w_true= 3.00, w_pred= 3.07; b_true= 2.00, b_pred= 1.90, loss= 1.03
# => epoch 21: w_true= 3.00, w_pred= 3.05; b_true= 2.00, b_pred= 1.92, loss= 1.02
# => epoch 22: w_true= 3.00, w_pred= 3.03; b_true= 2.00, b_pred= 1.93, loss= 1.01
# => epoch 23: w_true= 3.00, w_pred= 3.02; b_true= 2.00, b_pred= 1.94, loss= 1.01
# => epoch 24: w_true= 3.00, w_pred= 3.01; b_true= 2.00, b_pred= 1.95, loss= 1.00
# => epoch 25: w_true= 3.00, w_pred= 3.00; b_true= 2.00, b_pred= 1.96, loss= 1.00
# => epoch 26: w_true= 3.00, w_pred= 2.99; b_true= 2.00, b_pred= 1.96, loss= 1.00
# => epoch 27: w_true= 3.00, w_pred= 2.99; b_true= 2.00, b_pred= 1.97, loss= 1.00
# => epoch 28: w_true= 3.00, w_pred= 2.98; b_true= 2.00, b_pred= 1.97, loss= 1.00
# => epoch 29: w_true= 3.00, w_pred= 2.98; b_true= 2.00, b_pred= 1.97, loss= 1.00
# => epoch 30: w_true= 3.00, w_pred= 2.98; b_true= 2.00, b_pred= 1.98, loss= 1.00