In [ ]:
%matplotlib inline
from fastai.basics import *


In this part of the lecture we explain Stochastic Gradient Descent (SGD), an optimization method commonly used to train neural networks. We illustrate the concepts with concrete examples.

# Linear Regression problem

The goal of linear regression is to fit a line to a set of points.

In [ ]:
n=100

In [ ]:
x = torch.ones(n,2)
x[:,0].uniform_(-1.,1)
x[:5]

Out[ ]:
tensor([[-0.1957,  1.0000],
        [ 0.1826,  1.0000],
        [-0.1008,  1.0000],
        [-0.1449,  1.0000],
        [ 0.7091,  1.0000]])
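The second column of x is all ones so that x@a computes "slope times input plus intercept" in a single matrix product. Below is a minimal check in plain PyTorch. It assumes `torch.tensor` in place of fastai's `tensor` helper, and the seed and n=5 are illustrative choices only.

```python
import torch

torch.manual_seed(0)                   # illustrative seed, for reproducibility
n = 5
x = torch.ones(n, 2)                   # column 1 stays at 1.0 (the bias column)
x[:, 0].uniform_(-1., 1)               # column 0 holds the inputs in [-1, 1]
a = torch.tensor([3., 2.])             # slope 3, intercept 2

via_matmul = x @ a                     # one matrix-vector product
via_formula = a[0] * x[:, 0] + a[1]    # slope * input + intercept
print(torch.allclose(via_matmul, via_formula))  # True
```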
In [ ]:
a = tensor(3.,2); a

Out[ ]:
tensor([3., 2.])
In [ ]:
y = x@a + 0.25*torch.randn(n)
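Written per point, the cell above places each target on the line with slope 3 and intercept 2 and adds Gaussian noise with standard deviation 0.25 (using $x_{i,1} = 1$):

$$
y_i = x_{i,0}\,a_0 + x_{i,1}\,a_1 + 0.25\,\varepsilon_i = 3\,x_{i,0} + 2 + 0.25\,\varepsilon_i, \qquad \varepsilon_i \sim \mathcal{N}(0, 1)
$$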

In [ ]:
plt.scatter(x[:,0], y);


You want to find parameters (weights) a that minimize the error between the points and the line x@a. Note that here a is unknown. For a regression problem, the most common error function, or loss function, is the mean squared error.

In [ ]:
def mse(y_hat, y): return ((y_hat-y)**2).mean()
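Written out, mse averages the squared residuals over the n points. With $\hat y = x a$, the loss becomes a function of a alone:

$$
\mathrm{MSE}(a) = \frac{1}{n}\sum_{i=1}^{n} \left(\hat y_i - y_i\right)^2 = \frac{1}{n}\sum_{i=1}^{n} \left(x_{i,0}\,a_0 + x_{i,1}\,a_1 - y_i\right)^2
$$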


Suppose we guess a = (-1.0, 1.0). We can then compute y_hat, our prediction, and from it our error.

In [ ]:
a = tensor(-1.,1)

In [ ]:
y_hat = x@a
mse(y_hat, y)

Out[ ]:
tensor(7.9356)
In [ ]:
plt.scatter(x[:,0],y)
plt.scatter(x[:,0],y_hat);


So far we have specified the model (linear regression) and the evaluation criterion (the loss function). Now we need to handle optimization: how do we find the values of a that give the best-fitting line?

We would like to find the values of a that minimize the loss computed by mse.
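Because this loss is quadratic in a, its minimizer can also be computed directly by solving the normal equations (xᵀx)a = xᵀy. This is not the method this lesson teaches. It is only a sketch that gives a reference value to compare gradient descent against. It assumes plain PyTorch and freshly generated data following the same recipe as above, with an illustrative seed.

```python
import torch

torch.manual_seed(42)                         # illustrative seed
n = 100
x = torch.ones(n, 2)
x[:, 0].uniform_(-1., 1)
y = x @ torch.tensor([3., 2.]) + 0.25 * torch.randn(n)

# Normal equations: (x^T x) a = x^T y gives the exact least-squares minimizer
a_best = torch.linalg.solve(x.T @ x, x.T @ y)
print(a_best)   # close to (3, 2), but not exactly, because of the noise
```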

Gradient descent is an algorithm that minimizes functions. Given a function defined by a set of parameters, gradient descent starts with an initial set of parameter values and iteratively moves toward a set of parameter values that minimize the function. This iterative minimization is achieved by taking steps in the negative direction of the function gradient.
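For the mse loss defined above, the gradient has a closed form. One gradient-descent step with learning rate $\eta$ (lr in the code below) then reads:

$$
\nabla_a \mathrm{MSE}(a) = \frac{2}{n}\, x^\top (x a - y), \qquad a \leftarrow a - \eta\, \nabla_a \mathrm{MSE}(a)
$$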

Here is gradient descent implemented in PyTorch.

In [ ]:
a = nn.Parameter(a); a

Out[ ]:
Parameter containing:
tensor([-1.,  1.], requires_grad=True)
In [ ]:
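def update():
    y_hat = x@a                       # predictions for the current a
    loss = mse(y_hat, y)              # mean squared error
    if t % 10 == 0: print(loss)       # t is the global loop counter below
    loss.backward()                   # store d(loss)/da in a.grad
    with torch.no_grad():             # the update itself must not be tracked
        a.sub_(lr * a.grad)           # step against the gradient
        a.grad.zero_()                # reset, since gradients accumulate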
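loss.backward() fills a.grad with the derivative of the loss with respect to a. As a sanity check, here is a small sketch in plain PyTorch (illustrative seed) that compares autograd's result with the analytic MSE gradient (2/n)xᵀ(xa − y).

```python
import torch

torch.manual_seed(0)                              # illustrative seed
n = 100
x = torch.ones(n, 2)
x[:, 0].uniform_(-1., 1)
y = x @ torch.tensor([3., 2.]) + 0.25 * torch.randn(n)

a = torch.nn.Parameter(torch.tensor([-1., 1.]))
loss = ((x @ a - y) ** 2).mean()
loss.backward()                                   # autograd writes d(loss)/da into a.grad

with torch.no_grad():
    analytic = 2 / n * x.T @ (x @ a - y)          # closed-form MSE gradient
print(torch.allclose(a.grad, analytic, atol=1e-4))  # True
```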

In [ ]:
lr = 1e-1
for t in range(100): update()

tensor(7.9356, grad_fn=<MeanBackward1>)
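For reference, here is a self-contained sketch of the same training loop in plain PyTorch, without the fastai star import. The update step is written out inline, and the seed is an illustrative choice. With lr = 0.1 and 100 steps, a should end up close to the true (3, 2), and the loss should drop toward the noise level (0.25² ≈ 0.06).

```python
import torch

torch.manual_seed(0)                       # illustrative seed
n = 100
x = torch.ones(n, 2)
x[:, 0].uniform_(-1., 1)
y = x @ torch.tensor([3., 2.]) + 0.25 * torch.randn(n)

def mse(y_hat, y): return ((y_hat - y) ** 2).mean()

a = torch.nn.Parameter(torch.tensor([-1., 1.]))
lr = 1e-1
losses = []
for t in range(100):
    loss = mse(x @ a, y)
    losses.append(loss.item())
    loss.backward()                        # populate a.grad
    with torch.no_grad():
        a.sub_(lr * a.grad)                # gradient-descent step
        a.grad.zero_()                     # gradients accumulate otherwise
print(a.detach(), losses[0], losses[-1])
```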

In [ ]:
plt.scatter(x[:,0],y)
plt.scatter(x[:,0],x@a.detach());
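The loop above computes the gradient from all n points at every step, which makes it full-batch gradient descent. The "stochastic" in SGD refers to estimating the gradient from a randomly chosen subset of points (a mini-batch) at each step instead. Below is a sketch of that variant. It assumes plain PyTorch, and the batch size of 10, the 30 passes over the data, and the seed are illustrative choices. The noisier gradients make progress less smooth, but a should still approach (3, 2).

```python
import torch

torch.manual_seed(0)                        # illustrative seed
n, bs = 100, 10                             # bs: hypothetical mini-batch size
x = torch.ones(n, 2)
x[:, 0].uniform_(-1., 1)
y = x @ torch.tensor([3., 2.]) + 0.25 * torch.randn(n)

a = torch.nn.Parameter(torch.tensor([-1., 1.]))
lr = 1e-1
for epoch in range(30):
    perm = torch.randperm(n)                # reshuffle the points every epoch
    for i in range(0, n, bs):
        idx = perm[i:i + bs]                # indices of one mini-batch
        loss = ((x[idx] @ a - y[idx]) ** 2).mean()
        loss.backward()                     # gradient estimated from this batch only
        with torch.no_grad():
            a.sub_(lr * a.grad)
            a.grad.zero_()
print(a.detach())
```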


## Animate it!

In [ ]:
from matplotlib import animation, rc
rc('animation', html='jshtml')

In [ ]:
a = nn.Parameter(tensor(-1.,1))

fig = plt.figure()
plt.scatter(x[:,0], y, c='orange')
line, = plt.plot(x[:,0], x@a.detach())
plt.close()

def animate(i):
    update()                          # one gradient-descent step per frame
    line.set_ydata(x@a.detach())      # redraw the fitted line
    return line,

animation.FuncAnimation(fig, animate, np.arange(0, 100), interval=20)

Out[ ]:
