# Adapted from the fast.ai (part 1, v3) forum thread "Share your work here" by joshfp (José Fernández Portal).
# %matplotlib inline  # Jupyter cell magic (not valid plain-Python syntax); re-enable when running in a notebook
import torch
import matplotlib.pyplot as plt
import numpy as np
# Build a toy linear-regression dataset: n points with a feature drawn
# uniformly from [-1, 1) and a constant bias column of ones.
n = 100
x = torch.ones(n, 2)
x[:, 0].uniform_(-1., 1)
x[:5]
# Example output (values vary per run):
# tensor([[-0.2581, 1.0000], [ 0.9924, 1.0000], [ 0.7040, 1.0000],
#         [ 0.1151, 1.0000], [ 0.3407, 1.0000]])
w_y = torch.tensor([3., 2])  # ground-truth weights: slope 3, intercept 2
w_y
# tensor([3., 2.])
# Targets: the true linear model plus uniform noise in [0, 1).
y = x @ w_y + torch.rand(n)
plt.scatter(x[:, 0], y)
def mse(y_hat, y):
    """Mean squared error between predictions ``y_hat`` and targets ``y``."""
    err = y_hat - y
    return (err * err).mean()
from mpl_toolkits import mplot3d
def loss_wrt_wgts(w1, w2):
    """Loss of the linear model with weights (w1, w2) on the dataset.

    Reads module globals ``x``, ``y`` and ``mse``.  NOTE: torch.Tensor
    (not torch.tensor) is deliberate — it forces float32 even when
    w1/w2 arrive as numpy float64 scalars from np.vectorize, keeping
    the dtype compatible with ``x``.
    """
    weights = torch.Tensor([w1, w2])
    return mse(x @ weights, y)
# Vectorize the scalar loss so it can be evaluated over a whole weight grid.
loss_wgts = np.vectorize(loss_wrt_wgts)
grid_axis = np.linspace(-20, 20, 50)
mesh = np.meshgrid(grid_axis, grid_axis)  # (w0, w1) grid, reused by later plots
loss_mesh = loss_wgts(*mesh)
# 3-D surface of the loss landscape over (w0, w1).
fig = plt.figure(figsize=(10, 10))
ax = plt.axes(projection='3d')
ax.plot_surface(*mesh, loss_mesh, cmap='rainbow', alpha=0.8)
ax.set_xlabel('w0')
ax.set_ylabel('w1')
ax.set_zlabel('Loss')
ax.view_init(20, 20)
def update(w):
    """Perform one gradient-descent step on ``w`` in place.

    Reads module globals ``x``, ``y``, ``mse`` and learning rate ``lr``.
    Returns a tuple of (copy of the weights *before* the step,
    scalar loss value at those weights).
    """
    y_hat = x @ w
    # (prediction, target) order matches mse's signature; mse is symmetric,
    # so this is purely a consistency fix.
    loss = mse(y_hat, y)
    # Snapshot the pre-step weights; detach() is the supported replacement
    # for the legacy .data attribute.
    prev_wgts = w.detach().clone()
    loss.backward()
    with torch.no_grad():
        w -= lr * w.grad
        w.grad.zero_()  # clear accumulated gradient for the next step
    return prev_wgts, loss.item()
lr = 0.5 #choose learning rate, will run while loss is greater than 0.1
# for example, depending on random data
# lr 0.01 798 epochs / lr 0.1 79 epochs / lr 0.70 9 epochs /
# lr 0.9 36 epochs / lr 0.91 43 epochs
# NOTE(review): if lr is too large the loss may never drop below 0.1 and
# this loop will not terminate — there is no epoch cap.
init_param = [-18., -18]
# recorder holds (weights-before-step, loss) tuples; the initial entry's
# loss of 1. is a sentinel (> 0.1) that guarantees the loop runs at least once.
recorder = [(torch.tensor(init_param), 1.)]
w = torch.tensor(init_param, requires_grad=True)
# Each update() mutates w in place and returns the pre-step state.
while (recorder[-1][-1]>0.1): recorder.append(update(w))
# Drop the sentinel entry, then split into parallel weight/loss histories.
rec_wgts, rec_loss = [list(o) for o in zip(*recorder[1:])]
rec_wgts = torch.stack(rec_wgts)
n_epochs=len(rec_loss)
from matplotlib import animation
from matplotlib.gridspec import GridSpec
# Render animations as HTML5 video in notebook output.
plt.rc('animation', html='html5')
fig = plt.figure(figsize=(8, 6))
# Two rows on the left (data plot, weights plot); the 3-D loss spans the
# full height of the wider right column.
gs = GridSpec(2, 2, width_ratios=[1, 2.5])
# plot data points & model
ax0 = fig.add_subplot(gs[0,0])
ax0.scatter(x[:,0], y, c='orange', label='Data')
ax0.set_title('Data & Model', fontsize=16)
# Empty line; animate() fills in the model's predictions per frame.
line0, = ax0.plot([], [], label='Model')
ax0.legend(loc='lower right')
# plot 3d loss
ax1 = fig.add_subplot(gs[:,1], projection='3d')
ax1.set_title('Loss', fontsize=16, pad=20)
ax1.plot_surface(*mesh, loss_mesh, cmap='rainbow', alpha=0.8)
# Marker at the true weights; plotted at z=0 (approximate — with the added
# noise the exact minimum loss is slightly above zero).
ax1.plot3D([w_y[0]], [w_y[1]], [0], c='r', marker='x', markersize=10,
label='Global minimum', linewidth=0)
# Empty 3-D line; animate() appends the descent path per frame.
line1, = ax1.plot3D([], [], [], c='r', marker='o', alpha=0.4, label='Loss')
ax1.set_xlabel('w0'); ax1.set_ylabel('w1'); ax1.set_zlabel('Loss')
ax1.view_init(30, 20)
ax1.legend()
# plot weights & loss
ax2 = fig.add_subplot(gs[1,0])
ax2.set_title('Weights & Loss', fontsize=16)
line2, = ax2.plot([],[], label='w0')
line3, = ax2.plot([],[], label='w1')
ax2.set_ylim(-20, 5)
ax2.set_xlim(0, n_epochs)
ax2.set_xlabel('epochs')
ax2.set_ylabel('weights')
# Second y-axis sharing the epochs x-axis, for the loss curve.
ax3 = ax2.twinx()
line4, = ax3.plot([],[], label='loss', c='r')
ax3.set_ylabel('loss')
ax3.set_ybound(0, 500)
# One combined legend for both y-axes of this subplot.
ax2.legend((line2, line3, line4), ('w0', 'w1', 'loss'), loc='center right')
ttl = fig.suptitle(f'lr: {lr} - Epoch: 0/{n_epochs}', fontsize=22)
fig.tight_layout()
# Leave headroom for the suptitle (must come after tight_layout).
fig.subplots_adjust(top=0.85)
# Suppress the static figure; only the animation below should display.
plt.close()
def animate(i):
    """Draw frame ``i``: model fit, 3-D descent path, and weight/loss traces.

    Mutates the module-level line artists (line0..line4) and the title ttl;
    returns them for blitting.
    """
    upto = i + 1
    # Current model's predictions over the data.
    line0.set_data(x[:, 0].numpy(), (x @ rec_wgts[i]).numpy())
    # Descent path on the 3-D loss surface, up to and including this step.
    path = rec_wgts[:upto]
    line1.set_data(path[:, 0].numpy(), path[:, 1].numpy())
    line1.set_3d_properties(rec_loss[:upto])
    # Weight and loss histories versus epoch.
    epochs = np.arange(upto)
    line2.set_data(epochs, path[:, 0].numpy())
    line3.set_data(epochs, path[:, 1].numpy())
    line4.set_data(epochs, rec_loss[:upto])
    ttl.set_text(f'lr: {lr} - Epoch: {i+1}/{n_epochs}')
    return line0, line1, line2, line3, line4, ttl
# Keep a reference: matplotlib holds FuncAnimation weakly, so an unbound
# animation is garbage-collected and stops.  The trailing bare `anim`
# preserves the original notebook behavior of displaying it (rendered as
# html5 per the rc setting above).
anim = animation.FuncAnimation(fig, animate, range(n_epochs), interval=100)
anim