In this chapter, we will cover tensors and the autograd system.

"Whether we are based on carbon or silicon makes no fundamental difference; we should each be treated with appropriate respect." (Arthur C. Clarke)
We're now going to transition to using a framework. The networks we'll train next, Long Short-Term Memory (LSTM) networks, are very complex, and NumPy code implementing them is difficult to read, use, and debug, because gradients flow everywhere.
Deep learning frameworks were created precisely to mitigate this kind of code complexity, especially when we want to train our models on a GPU (10-100x faster training). A thorough understanding of a deep learning framework is essential on our journey toward becoming users or researchers in deep learning. But we won't jump straight into a well-known framework, because that would keep us from learning what complex models (such as LSTMs) do under the hood. Instead, we'll build a lightweight deep learning framework that follows the latest trends in framework development. This way, we'll know exactly what a DL framework does when we use it for complex architectures.
Building our own framework will also make the move to real deep learning frameworks smooth, because we'll already know the API and the functionality underneath it. The most useful feature of deep learning frameworks is automatic backpropagation and optimization. We specify only the forward propagation logic, and the framework handles the rest.
Recall that a matrix is a list of vectors, and a vector is a list of scalars. A tensor generalizes this idea: it is the abstract form of scalars, vectors, matrices, and arrays with any number of dimensions.
The first step in building a new deep learning framework is defining this new type, the Tensor. Let's implement it:
```python
import numpy as np

class Tensor(object):
    def __init__(self, data):
        # store the tensor's values in `self.data` as a NumPy array
        self.data = np.array(data)

    def __add__(self, other):
        return Tensor(self.data + other.data)

    def __sub__(self, other):
        return Tensor(self.data - other.data)

    def __mul__(self, other):
        # element-wise product (not matrix multiplication)
        return Tensor(self.data * other.data)

    def __repr__(self):
        return str(self.data.__repr__())

    def __str__(self):
        return str(self.data.__str__())
```
By the way, what's the difference between `__repr__()` and `__str__()` in Python? The goal of `__repr__()` is to be unambiguous. It is invoked when we simply inspect the object in the console. The goal of `__str__()` is to be readable. It is invoked by `print(object)`.

```python
>>> x = Tensor([1,2,3,4,5])
>>> x  # invokes __repr__()
array([1, 2, 3, 4, 5])
>>> print(x)  # invokes __str__()
[1 2 3 4 5]
>>> y = x + x
>>> print(y)
[ 2  4  6  8 10]
```
This is the first version of this basic data structure. Note that it stores all of its numerical information in a NumPy array (`self.data`) and supports element-wise operations. Adding more operations is relatively simple: we add more methods with the appropriate functionality to the Tensor class.
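To show how little it takes to add an operation, here is a stripped-down copy of the class with one extra element-wise method. The `__truediv__` method is our own hypothetical addition for illustration. It is not part of the chapter's `Tensor`.

```python
import numpy as np

class Tensor(object):
    def __init__(self, data):
        # store values as a NumPy array
        self.data = np.array(data)

    def __add__(self, other):
        return Tensor(self.data + other.data)

    # hypothetical extra operation: element-wise division
    def __truediv__(self, other):
        return Tensor(self.data / other.data)

    def __repr__(self):
        return str(self.data.__repr__())

x = Tensor([2.0, 4.0, 6.0])
y = Tensor([1.0, 2.0, 3.0])
z = x / y
print(z.data)  # [2. 2. 2.]
```

Every new element-wise operation follows the same pattern: compute the NumPy result and wrap it in a new Tensor.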
Previously, we've computed derivatives by hand for each network we trained. Recall that this is done by moving backwards through the neural network:
.. and so on, until all weights in the architecture have the correct gradients. This logic for computing gradients can also be added to the Tensor object:
```python
import numpy as np

class Tensor(object):
    def __init__(self, data, creators=None, creation_op=None):
        self.data = np.array(data)
        self.creation_op = creation_op  # the operation that produced this tensor
        self.creators = creators        # the tensors that produced this tensor
        self.grad = None

    def __add__(self, other):
        return Tensor(self.data + other.data,
                      creators=[self, other],
                      creation_op="+")

    def __sub__(self, other):
        return Tensor(self.data - other.data,
                      creators=[self, other],
                      creation_op="-")

    def __mul__(self, other):
        return Tensor(self.data * other.data,
                      creators=[self, other],
                      creation_op="*")

    def __repr__(self):
        return str(self.data.__repr__())

    def __str__(self):
        return str(self.data.__str__())

    def backward(self, grad):
        self.grad = grad
        if (self.creation_op == "+"):
            self.creators[0].backward(grad)
            self.creators[1].backward(grad)
        elif (self.creation_op == "-"):
            # d(a - b)/db = -1; this branch assumes `grad` is a Tensor
            self.creators[0].backward(grad)
            self.creators[1].backward(Tensor([-1]) * grad)
        # no backward rule for "*" yet; it is added in a later version
```
Now every Tensor carries a gradient. Let's experiment with the new functionality:
```python
>>> a = Tensor([1,2,3])
>>> print(a.creators, a.creation_op, a.grad)
None None None
>>> b = Tensor([4,5,6])
>>> c = a + b
>>> print(c.creators, " | ", c.creation_op, " | ", c.grad)
[array([1, 2, 3]), array([4, 5, 6])]  |  +  |  None
```
When we call `.backward()` on a tensor that results from adding two Tensors, we should pass the same incoming gradient, unchanged, to both of its parents. The reason is that the derivative of a sum with respect to each operand is 1, so addition simply routes the gradient backward to each creator.
```python
>>> x = Tensor([1,2,3,4,5])
>>> y = Tensor([2,2,2,2,2])
>>> z = x + y
```

Now we calculate $\nabla x$ and $\nabla y$ given $\nabla z$. Since we are dealing with the `+` operator, $\frac{\partial z}{\partial x} = 1$ and $\frac{\partial z}{\partial y} = 1$:

```python
>>> z.backward([1,1,1,1,1])
>>> x.grad, y.grad
([1, 1, 1, 1, 1], [1, 1, 1, 1, 1])
```
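We can check $\frac{\partial z}{\partial x} = 1$ numerically with central finite differences. This is a standalone NumPy sketch, independent of the Tensor class. The helper `numerical_grad` is our own illustrative function, not part of the chapter's code.

```python
import numpy as np

def numerical_grad(f, x, eps=1e-6):
    """Central finite-difference gradient of a scalar-valued function f at array x."""
    grad = np.zeros_like(x, dtype=float)
    for i in range(x.size):
        step = np.zeros_like(x, dtype=float)
        step.flat[i] = eps
        grad.flat[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = np.array([2.0, 2.0, 2.0, 2.0, 2.0])
upstream = np.ones(5)  # plays the role of the gradient passed to z.backward()

# L = sum(upstream * z) with z = x + y, so dL/dx is the gradient that reaches x
grad_x = numerical_grad(lambda v: np.sum(upstream * (v + y)), x)
grad_y = numerical_grad(lambda v: np.sum(upstream * (x + v)), y)
print(np.round(grad_x, 6), np.round(grad_y, 6))
```

Both gradients come out as (numerically) all ones, which matches passing the incoming gradient through unchanged.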
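Each Tensor gets 3 new attributes: `creators` (the list of tensors used to create it, or `None`), `creation_op` (the operation that created it, such as `+`, `-`, or `*`), and `grad` (its gradient, `None` until backpropagation reaches it).

Performing `z = x + y` creates a computation graph with 3 nodes (`x`, `y`, and `z`) and 2 edges (`z -> x` and `z -> y`). Each edge is labeled by the `creation_op` `+`. This graph allows us to backpropagate gradients recursively.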
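The graph described above can be sketched with a minimal stand-in class that keeps only the graph bookkeeping. `Node` and `edges` are hypothetical names used for illustration. The recursive walk mirrors how `.backward()` visits `creators`.

```python
class Node(object):
    """Hypothetical minimal stand-in for Tensor: only graph bookkeeping, no data."""
    def __init__(self, name, creators=None, creation_op=None):
        self.name = name
        self.creators = creators
        self.creation_op = creation_op

    def __add__(self, other):
        return Node(self.name + "+" + other.name,
                    creators=[self, other], creation_op="+")

def edges(node):
    """Recursively list (child, parent, op) edges from `node` back to its leaves."""
    result = []
    if node.creators is not None:
        for parent in node.creators:
            result.append((node.name, parent.name, node.creation_op))
            result.extend(edges(parent))
    return result

x, y = Node("x"), Node("y")
z = x + y
for child, parent, op in edges(z):
    print(child, "->", parent, "labeled", op)
```

If we keep operating on `z` (e.g. `z + x`), the walk simply goes deeper: the new node adds its own two edges on top of the two edges already below `z`.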
The first new concept in this implementation is the automatic creation of a graph whenever we perform an operation. If we took `z` and performed further operations, the graph would keep growing. The second new concept is automatic, recursive gradient calculation. It lets us compute the gradient of the final tensor with respect to every node connected to it.
Perhaps the most elegant part of this form of autograd is that it works recursively, because each node calls `.backward()` on all of its `self.creators`.
```python
# Tensor definition
a = Tensor([1,2,3,4,5])
b = Tensor([2,2,2,2,2])
c = Tensor([5,4,3,2,1])
d = Tensor([-1,-2,-3,-4,-5])

# Compute graph creation & forward propagation
e = a + b
f = c + d
g = e + f

# Backpropagation
g.backward(Tensor([1,1,1,1,1]))
a.grad
```

```
array([1, 1, 1, 1, 1])
```
The previous implementation does nothing new compared with what we've already been doing. Previously, we hard-coded the forward and backpropagation steps. Now it's time to automate and generalize these processes.
A graph that is built during forward propagation is called a dynamic computation graph, because it's built on the fly during the forward pass. This is the type of autograd found in newer deep learning frameworks such as DyNet and PyTorch. Older frameworks such as Theano and TensorFlow (1.x) use a static computation graph, which is specified before forward propagation even begins.
In general, dynamic computation graphs are easier to architect and experiment with. Static computation graphs can be faster at runtime because the framework can optimize the whole graph ahead of time. Over time, dynamic and static frameworks have been converging toward the middle. In this book, we'll stick with dynamic graphs.
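In a dynamic graph, ordinary Python control flow decides which operations get recorded during the forward pass. The sketch below shows this. `Node`, `count_nodes`, and `forward` are hypothetical names, and the classes are stripped down to graph bookkeeping only.

```python
class Node(object):
    """Hypothetical minimal node that records its creators as operations run."""
    def __init__(self, value, creators=None):
        self.value = value
        self.creators = creators

    def __add__(self, other):
        return Node(self.value + other.value, creators=[self, other])

def count_nodes(node):
    """Count the distinct nodes reachable from `node` (leaves included)."""
    seen = set()
    stack = [node]
    while stack:
        n = stack.pop()
        if id(n) in seen:
            continue
        seen.add(id(n))
        if n.creators is not None:
            stack.extend(n.creators)
    return len(seen)

def forward(x, n_steps):
    # a plain Python loop decides how many operations end up in the graph
    h = x
    for _ in range(n_steps):
        h = h + x
    return h

x = Node(1.0)
short_graph = forward(x, 1)
long_graph = forward(x, 3)
print(count_nodes(short_graph), count_nodes(long_graph))
```

The same code produces differently sized graphs depending on `n_steps`. A static-graph framework would instead need the graph (including any loop) declared before running it.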
Debugging these frameworks can be extremely difficult at times, because most bugs don't raise an error: the model seems to be training, but it's not.
The current version of `Tensor` supports backpropagating into a variable only once. But sometimes during forward propagation we use the same tensor multiple times, so multiple parts of the graph backpropagate gradients into the same tensor.
Here is an example:
```python
a = Tensor([1,2,3,4,5])
b = Tensor([2,2,2,2,2])
c = Tensor([5,4,3,2,1])

d = a + b
e = b + c
f = d + e

f.backward(Tensor([1,1,1,1,1]))
print(b.grad.data == [2,2,2,2,2])
```

```
[False False False False False]
```
The variable `b` is used twice in the creation of `f`, so its gradient should be the sum of the two derivatives: `[2,2,2,2,2]`. Instead, the second gradient overwrites the first.
We need to fix our `Tensor` implementation so that it accumulates incoming gradients instead of overwriting them.
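A finite-difference check confirms that the correct gradient for `b` is the sum of the contributions arriving through `d` and `e`. This is a standalone NumPy sketch, independent of the Tensor class, and `total` is a hypothetical helper.

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
b = np.array([2.0, 2.0, 2.0, 2.0, 2.0])
c = np.array([5.0, 4.0, 3.0, 2.0, 1.0])

def total(b_value):
    # f = (a + b) + (b + c), reduced with an upstream gradient of ones
    d = a + b_value
    e = b_value + c
    return np.sum(d + e)

eps = 1e-6
basis = np.eye(5)
grad_b = np.array([
    (total(b + eps * basis[i]) - total(b - eps * basis[i])) / (2 * eps)
    for i in range(5)
])
print(grad_b)  # every entry is (numerically) 2: one contribution via d, one via e
```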
Let's revisit how gradients flow through this example, using an updated `Tensor` that counts and accumulates incoming gradients:
```python
import numpy as np

class Tensor(object):
    def __init__(self, data, autograd=False, parents=None, creation_op=None, id=None):
        self.data = np.array(data)
        self.autograd = autograd
        self.parents = parents
        self.children = {}  # child Tensor -> number of gradients still expected from it
        self.creation_op = creation_op
        self.grad = None
        if (id is None): id = np.random.randint(100)
        self.id = id
        # Update the parents' children.
        # For your library, don't do this until you need it.
        # A simpler solution is to just keep a counter of the grads that must be passed during backpropagation.
        if (parents is not None):
            for parent in parents:
                if (self not in parent.children):
                    # 1 is the number of grads to be passed from self to parent
                    parent.children[self] = 1
                else:
                    parent.children[self] += 1

    def all_grads_propagated(self):
        for _, grads_count in self.children.items():
            if (grads_count != 0): return False
        return True

    def __add__(self, other):
        if (self.autograd and other.autograd):
            return Tensor(self.data + other.data,
                          autograd=True,
                          parents=[self, other],
                          creation_op="+")
        return Tensor(self.data + other.data)

    def __repr__(self):
        return str('Tensor(' + self.id.__repr__() + ')')

    def __str__(self):
        return str(self.data.__str__())

    def backward(self, grad=None, grad_origin=None):
        if (self.autograd):
            if (grad_origin is not None):
                # make sure this child may still send a gradient, then decrement its counter
                if (self.children[grad_origin] == 0):
                    raise Exception("cannot backprop more than once")
                else:
                    self.children[grad_origin] -= 1
            # accumulate gradients instead of overwriting them
            if (self.grad is None):
                self.grad = grad
            else:
                self.grad += grad
            # propagate once all children have reported (or when backward() is called directly)
            if ((self.parents is not None) and (self.all_grads_propagated() or grad_origin is None)):
                if (self.creation_op == "+"):
                    # begins actual backpropagation
                    self.parents[0].backward(self.grad, grad_origin=self)
                    self.parents[1].backward(self.grad, grad_origin=self)
```

```python
a = Tensor([1,2,3,4,5], autograd=True)
b = Tensor([2,2,2,2,2], autograd=True)
c = Tensor([5,4,3,2,1], autograd=True)

d = a + b
e = b + c
f = d + e

f.backward(Tensor([1,1,1,1,1]))
```

```python
>>> f.grad.data
array([1, 1, 1, 1, 1])
>>> e.grad.data, d.grad.data
(array([1, 1, 1, 1, 1]), array([1, 1, 1, 1, 1]))
>>> a.grad.data, b.grad.data, c.grad.data
(array([1, 1, 1, 1, 1]), array([2, 2, 2, 2, 2]), array([1, 1, 1, 1, 1]))
```
Let's also implement a `Tensor` without the book's quirky per-child "connections/grads" logic. Instead, we'll use a simple counter to make sure all gradients have arrived before passing the local gradient on to the parent nodes:
```python
import numpy as np

class Tensor(object):
    def __init__(self, data, autograd=False, parents=None, creation_op=None, id=None):
        # make all tensors at least 2-D for matrix multiplication
        self.data = np.atleast_2d(np.array(data))
        self.autograd = autograd
        self.parents = parents if parents is not None else []  # avoid a shared mutable default
        self.children = list()
        self.creation_op = creation_op
        self.grad = None
        if (self.autograd): self.required_grads = 0
        if id is None: id = np.random.randint(100)
        self.id = id
        # register this tensor as a child of each parent & increment the parent's required_grads
        for parent in self.parents:
            parent.children.append(self)
            if (self.autograd and parent.autograd):
                parent.required_grads += 1

    def __add__(self, other):
        if (self.autograd and other.autograd):
            return Tensor(self.data + other.data,
                          autograd=True,
                          parents=[self, other],
                          creation_op='+')
        return Tensor(self.data + other.data)

    def __repr__(self):
        return str('Tensor(' + self.id.__repr__() + ')')

    def __str__(self):
        return str(self.data.__str__())

    def backward(self, grad=None, grad_origin=None):
        if (self.autograd):
            if (self.grad is None):
                self.grad = grad
            else:
                self.grad += grad
            self.required_grads -= 1
            # propagate once every child has reported, or immediately if backward() was called directly
            if self.parents and (self.required_grads == 0 or grad_origin is None):
                if (self.creation_op == '+'):
                    self.parents[0].backward(self.grad, grad_origin=self)
                    self.parents[1].backward(self.grad, grad_origin=self)
```

```python
>>> a = Tensor([1,2,3,4,5], autograd=True)
>>> b = Tensor([2,2,2,2,2], autograd=True)
>>> c = Tensor([5,4,3,2,1], autograd=True)
>>> a, b, c
(Tensor(48), Tensor(95), Tensor(94))
>>> d = a + b
>>> e = b + c
>>> d, d.parents, d.required_grads
(Tensor(86), [Tensor(48), Tensor(95)], 0)
>>> a, a.children, a.required_grads
(Tensor(48), [Tensor(86)], 1)
>>> b, b.children, b.required_grads
(Tensor(95), [Tensor(86), Tensor(60)], 2)
>>> f = d + e
>>> f.backward(Tensor([1,1,1,1,1]))
>>> b.grad.data
array([[2, 2, 2, 2, 2]])
```
Let's go back to the book's `Tensor` implementation. There, `self.children` tracks how many gradients are still expected from each child during backpropagation. This also prevents a variable from accidentally receiving a gradient from the same child twice (attempting to do so raises an exception).
Previously, whenever we called `.backward()`, the object immediately called `.backward()` on its parents. Now, we want a tensor to first receive all of its gradients before backpropagating to its parents. None of these concepts are new from a deep learning theory perspective. They are engineering challenges that deep learning frameworks are built to handle, and we'll run into them again when debugging neural networks in a standard framework.
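We can now add support for arbitrary operations by adding a method that computes the forward result to the `Tensor` class and a matching rule to the `.backward()` method. Backpropagation is skipped if the variable has `autograd` turned off. Here is negation as an example: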
```python
import numpy as np

class Tensor(object):
    def __init__(self, data, autograd=False, parents=None, creation_op=None, id=None):
        self.data = np.array(data)
        self.autograd = autograd
        self.parents = parents
        self.children = {}  # child Tensor -> number of gradients still expected from it
        self.creation_op = creation_op
        self.grad = None
        if (id is None): id = np.random.randint(100)
        self.id = id
        # Update the parents' children.
        # For your library, don't do this until you need it.
        # A simpler solution is to just keep a counter of the grads that must be passed during backpropagation.
        if (parents is not None):
            for parent in parents:
                if (self not in parent.children):
                    # 1 is the number of grads to be passed from self to parent
                    parent.children[self] = 1
                else:
                    parent.children[self] += 1

    def all_grads_propagated(self):
        for _, grads_count in self.children.items():
            if (grads_count != 0): return False
        return True

    def __add__(self, other):
        if (self.autograd and other.autograd):
            return Tensor(self.data + other.data,
                          autograd=True,
                          parents=[self, other],
                          creation_op="+")
        return Tensor(self.data + other.data)

    def __neg__(self):
        if (self.autograd):
            return Tensor(self.data * -1,
                          autograd=True,
                          parents=[self],
                          creation_op="neg")
        return Tensor(self.data * -1)

    def __repr__(self):
        return str('Tensor(' + self.id.__repr__() + ')')

    def __str__(self):
        return str(self.data.__str__())

    def backward(self, grad=None, grad_origin=None):
        if (self.autograd):
            if (grad_origin is not None):
                # make sure this child may still send a gradient, then decrement its counter
                if (self.children[grad_origin] == 0):
                    raise Exception("cannot backprop more than once")
                else:
                    self.children[grad_origin] -= 1
            if (self.grad is None):
                self.grad = grad
            else:
                self.grad += grad
            if ((self.parents is not None) and (self.all_grads_propagated() or grad_origin is None)):
                if (self.creation_op == "+"):
                    self.parents[0].backward(self.grad, grad_origin=self)
                    self.parents[1].backward(self.grad, grad_origin=self)
                if (self.creation_op == "neg"):
                    self.parents[0].backward(self.grad.__neg__(), grad_origin=self)
```

```python
a = Tensor([1,2,3,4,5], autograd=True)
b = Tensor([2,2,2,2,2], autograd=True)
c = Tensor([5,4,3,2,1], autograd=True)

d = a + (-b)
e = (-b) + c
f = d + e

f.backward(Tensor([1,1,1,1,1]))
b.grad.data
```

```
array([-2, -2, -2, -2, -2])
```
Let's add some more:
```python
import numpy as np

class Tensor(object):
    def __init__(self, data, autograd=False, parents=None, creation_op=None, id=None):
        self.data = np.array(data)
        self.autograd = autograd
        self.parents = parents
        self.children = {}  # child Tensor -> number of gradients still expected from it
        self.creation_op = creation_op
        self.grad = None
        if (id is None): id = np.random.randint(100)
        self.id = id
        if (parents is not None):
            for parent in parents:
                if (self not in parent.children):
                    parent.children[self] = 1
                else:
                    parent.children[self] += 1

    def all_grads_propagated(self):
        for _, grads_count in self.children.items():
            if (grads_count != 0): return False
        return True

    def __add__(self, other):
        if (self.autograd and other.autograd):
            return Tensor(self.data + other.data,
                          autograd=True,
                          parents=[self, other],
                          creation_op="+")
        return Tensor(self.data + other.data)

    def __sub__(self, other):
        if (self.autograd and other.autograd):
            return Tensor(self.data - other.data,
                          autograd=True,
                          parents=[self, other],
                          creation_op="-")
        return Tensor(self.data - other.data)

    def __mul__(self, other):
        if (self.autograd and other.autograd):
            return Tensor(self.data * other.data,
                          autograd=True,
                          parents=[self, other],
                          creation_op="*")
        return Tensor(self.data * other.data)

    def sum(self, dim):
        if (self.autograd):
            return Tensor(self.data.sum(dim),
                          autograd=True,
                          parents=[self],
                          creation_op="sum_" + str(dim))
        return Tensor(self.data.sum(dim))

    def __neg__(self):
        if (self.autograd):
            return Tensor(self.data * -1,
                          autograd=True,
                          parents=[self],
                          creation_op="neg")
        return Tensor(self.data * -1)

    def __repr__(self):
        return str('Tensor(' + self.id.__repr__() + ')')

    def __str__(self):
        return str(self.data.__str__())

    def expand(self, dim, copies):
        # insert a new axis at position `dim` holding `copies` copies of the data
        trans_cmd = list(range(0, len(self.data.shape)))
        trans_cmd.insert(dim, len(self.data.shape))
        new_shape = list(self.data.shape) + [copies]
        new_data = self.data.repeat(copies).reshape(new_shape)
        new_data = new_data.transpose(trans_cmd)
        if (self.autograd):
            return Tensor(new_data,
                          autograd=True,
                          parents=[self],
                          creation_op="expand_" + str(dim))
        return Tensor(new_data)

    def transpose(self):
        if (self.autograd):
            return Tensor(self.data.transpose(),
                          autograd=True,
                          parents=[self],
                          creation_op="T")
        return Tensor(self.data.transpose())

    def mm(self, x):
        if (self.autograd):
            return Tensor(self.data.dot(x.data),
                          autograd=True,
                          parents=[self, x],
                          creation_op="mm")
        return Tensor(self.data.dot(x.data))

    def backward(self, grad=None, grad_origin=None):
        if (self.autograd):
            if (grad is None):
                grad = Tensor(np.ones_like(self.data))
            if (grad_origin is not None):
                if (self.children[grad_origin] == 0):
                    raise Exception("cannot backprop more than once")
                else:
                    self.children[grad_origin] -= 1
            if (self.grad is None):
                self.grad = grad
            else:
                self.grad += grad
            if ((self.parents is not None) and (self.all_grads_propagated() or grad_origin is None)):
                if (self.creation_op == "+"):
                    self.parents[0].backward(self.grad, grad_origin=self)
                    self.parents[1].backward(self.grad, grad_origin=self)
                if (self.creation_op == "neg"):
                    self.parents[0].backward(self.grad.__neg__(), grad_origin=self)
                if (self.creation_op == "-"):
                    self.parents[0].backward(self.grad, grad_origin=self)
                    self.parents[1].backward(self.grad.__neg__(), grad_origin=self)
                if (self.creation_op == "*"):
                    self.parents[0].backward(self.grad * self.parents[1], grad_origin=self)
                    self.parents[1].backward(self.grad * self.parents[0], grad_origin=self)
                if (self.creation_op == "mm"):
                    activation = self.parents[0]  # usually the output of a previous layer
                    weights = self.parents[1]     # usually a weights matrix
                    activation.backward(self.grad.mm(weights.transpose()), grad_origin=self)
                    weights.backward(self.grad.transpose().mm(activation).transpose(), grad_origin=self)
                if (self.creation_op == "T"):
                    self.parents[0].backward(self.grad.transpose(), grad_origin=self)
                if ("sum" in self.creation_op):
                    dim = int(self.creation_op.split("_")[1])
                    ds = self.parents[0].data.shape[dim]
                    self.parents[0].backward(self.grad.expand(dim, ds), grad_origin=self)
                if ("expand" in self.creation_op):
                    dim = int(self.creation_op.split("_")[1])
                    self.parents[0].backward(self.grad.sum(dim), grad_origin=self)
```
Remember that `sum()` removes a dimension and `expand()` adds one.
If we expand along the last dimension, each value is copied along that new dimension, so every entry of the original tensor becomes a list of entries. Thus, when we perform `.sum(dim=1)` on a tensor with four entries in that dimension, we need to perform `.expand(dim=1, copies=4)` on the gradient when we backpropagate it.
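To see the shapes involved, here is the same `expand` logic applied to a plain NumPy array. This is a standalone sketch: the function below copies the method body above, and the comparison with `np.broadcast_to` is our own addition.

```python
import numpy as np

def expand(data, dim, copies):
    """Same logic as Tensor.expand above, applied to a plain NumPy array."""
    trans_cmd = list(range(0, len(data.shape)))
    trans_cmd.insert(dim, len(data.shape))
    new_shape = list(data.shape) + [copies]
    new_data = data.repeat(copies).reshape(new_shape)
    return new_data.transpose(trans_cmd)

x = np.arange(6).reshape(2, 3)
y = expand(x, dim=1, copies=4)
print(y.shape)         # (2, 4, 3): a new axis of size 4 inserted at position 1
print(y.sum(1).shape)  # (2, 3): summing over that axis removes it again

# the same values via NumPy broadcasting
same = np.array_equal(y, np.broadcast_to(x[:, None, :], (2, 4, 3)))
```

Every slice `y[:, k, :]` is a copy of `x`. That is why the gradient of `sum(dim)` is the incoming gradient expanded back along `dim`, and vice versa.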
We should understand how to take derivatives of matrix multiplication. As a starting point, recall that the gradients start at the end of the network. The following figure explains how backpropagation works for fully connected (FC) layers:
In the hand-written NumPy version, we had to forward propagate in such a way that `layer_1`, `layer_2`, and `diff` existed as variables, because we needed them later. We then had to backpropagate each gradient to its appropriate weight matrix and perform the weight updates ourselves.
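For $Y = XW$ with an upstream gradient $G = \partial L / \partial Y$, the rules are $\partial L / \partial X = G W^\top$ and $\partial L / \partial W = X^\top G$. The `mm` branch of `backward()` computes the second one as `(G^T X)^T`, which is the same matrix. The standalone NumPy sketch below checks both rules against finite differences. The shapes and random data are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 2))   # activations: batch of 4, 2 features
W = rng.standard_normal((2, 3))   # weights
G = rng.standard_normal((4, 3))   # upstream gradient dL/dY, where Y = X @ W

def loss(X_, W_):
    # a scalar loss whose gradient with respect to Y is exactly G
    return np.sum((X_ @ W_) * G)

# analytic gradients, mirroring the 'mm' branch of backward()
grad_X = G @ W.T           # self.grad.mm(weights.transpose())
grad_W = (G.T @ X).T       # self.grad.transpose().mm(activation).transpose()

def numerical(f, A, eps=1e-6):
    out = np.zeros_like(A)
    for idx in np.ndindex(A.shape):
        step = np.zeros_like(A)
        step[idx] = eps
        out[idx] = (f(A + step) - f(A - step)) / (2 * eps)
    return out

num_X = numerical(lambda A: loss(A, W), X)
num_W = numerical(lambda A: loss(X, A), W)
print(np.allclose(grad_X, num_X, atol=1e-6), np.allclose(grad_W, num_W, atol=1e-6))
```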
```python
import numpy as np
np.random.seed(0)

x = Tensor(np.array([[0,0], [0,1], [1,0], [1,1]]), autograd=True)
y = Tensor(np.array([[0], [1], [0], [1]]), autograd=True)

weights = list()
weights.append(Tensor(np.random.rand(2,3), autograd=True))
weights.append(Tensor(np.random.rand(3,1), autograd=True))

for i in range(10):  # epochs
    y_hat = x.mm(weights[0]).mm(weights[1])  # predict
    loss = ((y_hat - y)*(y_hat - y)).sum(0)  # compare
    loss.backward(Tensor(np.ones_like(loss.data)))  # learn, feeding an initial gradient of 1 to the loss
    for weight in weights:
        weight.data -= weight.grad.data * 0.1
        weight.grad.data *= 0
    print(loss)
```

```
[0.58128304]
[0.48988149]
[0.41375111]
[0.34489412]
[0.28210124]
[0.2254484]
[0.17538853]
[0.1324231]
[0.09682769]
[0.06849361]
```
We forward propagate to build the loss's computation graph, then backpropagate by feeding the loss an initial gradient of 1. With the new autograd system, the code is much simpler.
When we have an autograd system, stochastic gradient descent becomes trivial to implement.
Let's try making it its own class as well:
```python
class SGD(object):
    def __init__(self, parameters, alpha):
        self.parameters = parameters
        self.alpha = alpha

    def zero(self):
        for p in self.parameters:
            p.grad.data *= 0

    def step(self, zero=True):
        for p in self.parameters:
            p.data -= p.grad.data * self.alpha
            if (zero):
                p.grad.data *= 0
```
The previous neural network is further simplified as follows, with exactly the same results as before:
```python
import numpy as np
np.random.seed(0)

x = Tensor(np.array([[0,0], [0,1], [1,0], [1,1]]), autograd=True)
y = Tensor(np.array([[0], [1], [0], [1]]), autograd=True)

weights = list()
weights.append(Tensor(np.random.rand(2,3), autograd=True))
weights.append(Tensor(np.random.rand(3,1), autograd=True))

optimizer = SGD(weights, 0.1)

for i in range(10):  # epochs
    y_hat = x.mm(weights[0]).mm(weights[1])  # forward propagation
    loss = ((y_hat - y)*(y_hat - y)).sum(0)  # compare
    loss.backward(Tensor(np.ones_like(loss.data)))  # backpropagation, feeding an initial gradient of 1 to the loss
    optimizer.step()  # learn
    print(loss)
```

```
[0.58128304]
[0.48988149]
[0.41375111]
[0.34489412]
[0.28210124]
[0.2254484]
[0.17538853]
[0.1324231]
[0.09682769]
[0.06849361]
```
Probably the most common abstraction across all deep learning frameworks is the layer. A layer is a collection of commonly used forward propagation techniques packaged into a simple API, with some kind of `forward()` method to call them.
Here is an example of a simple Linear Layer:
```python
class Layer(object):
    def __init__(self):
        self.parameters = list()

    def get_parameters(self):
        return self.parameters


class Linear(Layer):
    def __init__(self, n_inputs, n_outputs):
        super().__init__()
        W = np.random.randn(n_inputs, n_outputs) * np.sqrt(2.0 / n_inputs)
        self.weight = Tensor(W, autograd=True)
        self.bias = Tensor(np.zeros(n_outputs), autograd=True)
        self.parameters.append(self.weight)
        self.parameters.append(self.bias)

    def forward(self, input):
        # expand the bias along the batch dimension (manual broadcasting)
        return input.mm(self.weight) + self.bias.expand(0, len(input.data))
```
The weights are organized into a class, and we add a bias vector because this is a true linear layer. We can initialize all the layers together, so that the weights and biases get the correct sizes and the correct forward propagation logic is always used.
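The `Linear` initializer scales standard-normal samples by $\sqrt{2 / n_{inputs}}$, a scheme commonly known as He (Kaiming) initialization. Its purpose is to keep the variance of activations roughly stable from layer to layer in ReLU-style networks. The standalone sketch below checks that the sampled weights have variance close to $2 / n_{inputs}$. The layer sizes are arbitrary choices.

```python
import numpy as np

np.random.seed(0)
n_inputs, n_outputs = 512, 256
W = np.random.randn(n_inputs, n_outputs) * np.sqrt(2.0 / n_inputs)

target_var = 2.0 / n_inputs
print(W.var(), target_var)  # the sample variance lands close to 2 / n_inputs
```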
We also created an abstract `Layer` class, which allows for more complicated layers (for example, layers that contain other layers).
```python
class Sequential(Layer):
    def __init__(self, layers=None):
        super().__init__()
        # avoid a mutable default list shared between Sequential instances
        self.layers = layers if layers is not None else []

    def add(self, layer):
        self.layers.append(layer)

    def forward(self, input):
        for layer in self.layers:
            input = layer.forward(input)
        return input

    def get_parameters(self):
        params = list()
        for layer in self.layers:
            params += layer.get_parameters()
        return params
```

```python
x = Tensor(np.array([[0, 0], [0,1], [1,0], [1,1]]), autograd=True)
y = Tensor(np.array([[1], [0], [1], [0]]), autograd=True)

model = Sequential([Linear(2,3), Linear(3,1)])
optimizer = SGD(model.get_parameters(), alpha=0.05)

for i in range(10):  # epochs
    y_hat = model.forward(x)  # forward propagation
    loss = ((y_hat - y)*(y_hat - y)).sum(0)  # loss
    loss.backward(Tensor(np.ones_like(loss.data)))  # backpropagation
    optimizer.step()
    print(loss)
```

```
[1.39435371]
[1.03442471]
[0.80333761]
[0.60197476]
[0.46415449]
[0.34221874]
[0.25943595]
[0.1908049]
[0.14431529]
[0.10740191]
```
This is very similar to PyTorch's API. Amazing!
Let's also implement loss functions as layers:
```python
class MSELoss(Layer):
    def __init__(self):
        super().__init__()

    def forward(self, y_hat, y):
        return ((y_hat - y) * (y_hat - y)).sum(0)
```

```python
import numpy as np
np.random.seed(0)

x = Tensor(np.array([[0, 0], [0,1], [1,0], [1,1]]), autograd=True)
y = Tensor(np.array([[1], [0], [1], [0]]), autograd=True)

model = Sequential([Linear(2,3), Linear(3,1)])
loss = MSELoss()
optimizer = SGD(parameters=model.get_parameters(), alpha=0.05)

for i in range(10):
    y_hat = model.forward(x)
    l = loss.forward(y_hat, y)
    l.backward(Tensor(np.ones_like(l.data)))
    optimizer.step()
    print(l)
```

```
[1.6813686]
[0.95192748]
[0.72454581]
[0.57489823]
[0.45840608]
[0.36465316]
[0.28883237]
[0.22760439]
[0.17835522]
[0.13895393]
```
Autograd handles all of the backpropagation, while the forward propagation steps are organized into classes so that they compose smoothly.
Autograd ensures that we can piece together different types of layers without losing sight of the underlying relationships and flowing gradients. This is the main feature of modern frameworks: they eliminate the need to handwrite every mathematical operation for forward and backward propagation.
Viewing a framework as merely an autograd system plus a list of layers, loss functions, and optimizers will help us learn it. We should take a moment to read through the lists of layers and optimizers in the different frameworks we use.
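We've added a quick hack so that we can call `.backward()` on the loss without passing a Tensor of 1s every time. The version below also adds `sigmoid()`, `tanh()`, `index_select()`, and `cross_entropy()`, along with their backward rules: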
```python
import numpy as np

class Tensor(object):
    def __init__(self, data, autograd=False, parents=None, creation_op=None, id=None):
        self.data = np.array(data)
        self.autograd = autograd
        self.parents = parents
        self.children = {}  # child Tensor -> number of gradients still expected from it
        self.creation_op = creation_op
        self.grad = None
        if (id is None): id = np.random.randint(100)
        self.id = id
        if (parents is not None):
            for parent in parents:
                if (self not in parent.children):
                    parent.children[self] = 1
                else:
                    parent.children[self] += 1

    def all_grads_propagated(self):
        for _, grads_count in self.children.items():
            if (grads_count != 0): return False
        return True

    def __add__(self, other):
        if (self.autograd and other.autograd):
            return Tensor(self.data + other.data, autograd=True,
                          parents=[self, other], creation_op="+")
        return Tensor(self.data + other.data)

    def __sub__(self, other):
        if (self.autograd and other.autograd):
            return Tensor(self.data - other.data, autograd=True,
                          parents=[self, other], creation_op="-")
        return Tensor(self.data - other.data)

    def __mul__(self, other):
        if (self.autograd and other.autograd):
            return Tensor(self.data * other.data, autograd=True,
                          parents=[self, other], creation_op="*")
        return Tensor(self.data * other.data)

    def sum(self, dim):
        if (self.autograd):
            return Tensor(self.data.sum(dim), autograd=True,
                          parents=[self], creation_op="sum_" + str(dim))
        return Tensor(self.data.sum(dim))

    def __neg__(self):
        if (self.autograd):
            return Tensor(self.data * -1, autograd=True,
                          parents=[self], creation_op="neg")
        return Tensor(self.data * -1)

    def __repr__(self):
        return str('Tensor(' + self.id.__repr__() + ')')

    def __str__(self):
        return str(self.data.__str__())

    def expand(self, dim, copies):
        # insert a new axis at position `dim` holding `copies` copies of the data
        trans_cmd = list(range(0, len(self.data.shape)))
        trans_cmd.insert(dim, len(self.data.shape))
        new_shape = list(self.data.shape) + [copies]
        new_data = self.data.repeat(copies).reshape(new_shape)
        new_data = new_data.transpose(trans_cmd)
        if (self.autograd):
            return Tensor(new_data, autograd=True,
                          parents=[self], creation_op="expand_" + str(dim))
        return Tensor(new_data)

    def transpose(self):
        if (self.autograd):
            return Tensor(self.data.transpose(), autograd=True,
                          parents=[self], creation_op="T")
        return Tensor(self.data.transpose())

    def mm(self, x):
        if (self.autograd):
            return Tensor(self.data.dot(x.data), autograd=True,
                          parents=[self, x], creation_op="mm")
        return Tensor(self.data.dot(x.data))

    def sigmoid(self):
        if (self.autograd):
            return Tensor(1 / (1 + np.exp(-self.data)), autograd=True,
                          parents=[self], creation_op="sigmoid")
        return Tensor(1 / (1 + np.exp(-self.data)))

    def tanh(self):
        if (self.autograd):
            return Tensor(np.tanh(self.data), autograd=True,
                          parents=[self], creation_op="tanh")
        return Tensor(np.tanh(self.data))

    def index_select(self, indices):
        if (self.autograd):
            new = Tensor(self.data[indices.data], autograd=True,
                         parents=[self], creation_op="index_select")
            new.index_select_indices = indices  # remembered for the backward pass
            return new
        return Tensor(self.data[indices.data])

    def cross_entropy(self, target_indices):
        temp = np.exp(self.data)
        softmax_output = temp / np.sum(temp, axis=len(self.data.shape) - 1, keepdims=True)
        t = target_indices.data.flatten()
        p = softmax_output.reshape(len(t), -1)
        target_dist = np.eye(p.shape[1])[t]
        loss = -(np.log(p) * (target_dist)).sum(1).mean()
        if (self.autograd):
            out = Tensor(loss, autograd=True,
                         parents=[self], creation_op="cross_entropy")
            out.softmax_output = softmax_output
            out.target_dist = target_dist
            return out
        return Tensor(loss)

    def backward(self, grad=None, grad_origin=None):
        if (self.autograd):
            if (grad is None):
                # quick hack: default to a gradient of ones (e.g. when called on the loss)
                grad = Tensor(np.ones_like(self.data))
            if (grad_origin is not None):
                if (self.children[grad_origin] == 0):
                    raise Exception("cannot backprop more than once")
                else:
                    self.children[grad_origin] -= 1
            if (self.grad is None):
                self.grad = grad
            else:
                self.grad += grad
            if ((self.parents is not None) and (self.all_grads_propagated() or grad_origin is None)):
                if (self.creation_op == "+"):
                    self.parents[0].backward(self.grad, grad_origin=self)
                    self.parents[1].backward(self.grad, grad_origin=self)
                if (self.creation_op == "neg"):
                    self.parents[0].backward(self.grad.__neg__(), grad_origin=self)
                if (self.creation_op == "-"):
                    self.parents[0].backward(self.grad, grad_origin=self)
                    self.parents[1].backward(self.grad.__neg__(), grad_origin=self)
                if (self.creation_op == "*"):
                    self.parents[0].backward(self.grad * self.parents[1], grad_origin=self)
                    self.parents[1].backward(self.grad * self.parents[0], grad_origin=self)
                if (self.creation_op == "mm"):
                    activation = self.parents[0]  # usually the output of a previous layer
                    weights = self.parents[1]     # usually a weights matrix
                    activation.backward(self.grad.mm(weights.transpose()), grad_origin=self)
                    weights.backward(self.grad.transpose().mm(activation).transpose(), grad_origin=self)
                if (self.creation_op == "T"):
                    self.parents[0].backward(self.grad.transpose(), grad_origin=self)
                if ("sum" in self.creation_op):
                    dim = int(self.creation_op.split("_")[1])
                    ds = self.parents[0].data.shape[dim]
                    self.parents[0].backward(self.grad.expand(dim, ds), grad_origin=self)
                if ("expand" in self.creation_op):
                    dim = int(self.creation_op.split("_")[1])
                    self.parents[0].backward(self.grad.sum(dim), grad_origin=self)
                if (self.creation_op == "sigmoid"):
                    ones = Tensor(np.ones_like(self.grad.data))
                    self.parents[0].backward(self.grad * (self * (ones - self)), grad_origin=self)
                if (self.creation_op == "tanh"):
                    ones = Tensor(np.ones_like(self.grad.data))
                    self.parents[0].backward(self.grad * (ones - (self * self)), grad_origin=self)
                if (self.creation_op == "index_select"):
                    # scatter-add the accumulated gradient into the selected rows
                    new_grad = np.zeros_like(self.parents[0].data)
                    indices_ = self.index_select_indices.data.flatten()
                    grad_ = self.grad.data.reshape(len(indices_), -1)
                    for i in range(len(indices_)):
                        new_grad[indices_[i]] += grad_[i]
                    self.parents[0].backward(Tensor(new_grad), grad_origin=self)
                if (self.creation_op == "cross_entropy"):
                    dx = self.softmax_output - self.target_dist
                    self.parents[0].backward(Tensor(dx), grad_origin=self)
```
Hopefully, this feels fairly routine:
class Tanh(Layer):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input.tanh()

class Sigmoid(Layer):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input.sigmoid()
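The sigmoid and tanh branches of backward() compute their derivatives from the layer's own output: for s = sigmoid(x), ds/dx = s(1 - s), and for t = tanh(x), dt/dx = 1 - t². As a sanity check independent of the Tensor class, a central finite-difference comparison in plain NumPy should agree with both formulas (the helper names here are illustrative):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def numerical_derivative(f, x, eps=1e-6):
    # central finite difference, applied elementwise
    return (f(x + eps) - f(x - eps)) / (2 * eps)

x = np.linspace(-3.0, 3.0, 13)
s = sigmoid(x)
t = np.tanh(x)

# analytic derivatives written in terms of the forward outputs,
# mirroring the 'sigmoid' and 'tanh' branches of backward()
sigmoid_grad = s * (1 - s)
tanh_grad = 1 - t * t

print(np.allclose(sigmoid_grad, numerical_derivative(sigmoid, x), atol=1e-6))  # True
print(np.allclose(tanh_grad, numerical_derivative(np.tanh, x), atol=1e-6))     # True
```

Reusing the forward output this way means backward() never has to recompute the exponentials.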
Let's try out the new nonlinearities:
x = Tensor(np.array([[0, 0], [0,1], [1,0], [1,1]]), autograd=True)
y = Tensor(np.array([[1], [0], [1], [0]]), autograd=True)
model = Sequential([Linear(2,3), Tanh(), Linear(3,1), Sigmoid()])
loss = MSELoss()
optimizer = SGD(parameters=model.get_parameters(), alpha=1)
for i in range(10):
    y_hat = model.forward(x)
    l = loss.forward(y_hat, y)
    l.backward(Tensor(np.ones_like(l.data)))
    optimizer.step()
    print(l)
[1.10815212]
[0.54905107]
[0.31290284]
[0.18050833]
[0.11220714]
[0.07893695]
[0.06391673]
[0.05358047]
[0.04600529]
[0.04022943]
As we can see, we can drop the new Tanh() and Sigmoid() nonlinearities into Sequential, and the network knows exactly what to do with them. Next, we'll abstract out and implement RNN layers in our framework. To do that, we need three new layer types: an embedding layer, a cross-entropy loss layer, and a recurrent (RNN) cell.
Word embeddings are vectors mapped to words that we can forward propagate into a neural network. If we have a vocabulary of 200 words, we'll have 200 embeddings, one per word.
First, let's initialize a weight matrix of the right shape to hold the word embeddings:
class Embedding(Layer):
    def __init__(self, vocab_size, dim):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        # this initialization style is a convention from word2vec
        weight = (np.random.rand(vocab_size, dim) - 0.5) / dim
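The weight matrix has one row (vector) for each unique word in the vocabulary. Forward propagation always starts with the question "How will the inputs be encoded?" We can't pass in the words themselves, because words don't tell us which rows of the weight matrix to use. Instead, we forward propagate word indices and use them to select rows, which is exactly what NumPy's integer-array indexing does: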
identity = np.eye(5)
identity
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]])
identity[np.array([[1,2,3,4], [2,3,4,0]])]
array([[[0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]],

       [[0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0.]]])
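Indexing the identity matrix with an array of indices yields their one-hot encodings, and multiplying a one-hot vector by a weight matrix just picks out one row. So an embedding lookup computes the same thing as "one-hot vector times weight matrix", without ever building the mostly-zero one-hot array. A small NumPy check (sizes and seed are arbitrary choices for illustration):

```python
import numpy as np

np.random.seed(0)
vocab_size, dim = 5, 3
weight = (np.random.rand(vocab_size, dim) - 0.5) / dim  # embedding matrix

indices = np.array([[1, 2, 3, 4], [2, 3, 4, 0]])        # (batch, seq_len)

# 1) direct lookup: select rows of the weight matrix
looked_up = weight[indices]                               # (2, 4, 3)

# 2) equivalent but wasteful route: one-hot encode, then matrix-multiply
one_hot = np.eye(vocab_size)[indices]                     # (2, 4, 5)
via_matmul = one_hot @ weight                             # (2, 4, 3)

print(looked_up.shape)                       # (2, 4, 3)
print(np.allclose(looked_up, via_matmul))    # True
```

This equivalence is why the gradient of an embedding lookup only touches the rows that were selected.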
Before doing anything with the embedding layer, autograd must support indexing. During backpropagation, the gradients must be placed in the same rows that were indexed during forward propagation. This requires that we keep around whatever indices we passed in, so that we can place each gradient in the appropriate location during backpropagation with a simple for loop:
# Added to the Tensor class
def index_select(self, indices):
    if (self.autograd):
        new = Tensor(self.data[indices.data],
                     autograd=True,
                     parents=[self],
                     creation_op="index_select")
        # remember the indices so backward() can route gradients to the right rows
        new.index_select_indices = indices
        return new
    return Tensor(self.data[indices.data])
Then, during .backward():
1. Initialize a new gradient of the correct size.
2. Flatten the indices so we can iterate through them.
3. Collapse grad_ to a simple list of rows.
4. Iterate through each index, add its gradient row into the corresponding row of the new gradient, and finally backpropagate the new gradient into parents[0].
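The four steps can be tried in isolation with plain NumPy. The explicit loop matters: when the same row is selected more than once, its gradients must be summed. NumPy's np.add.at performs the same unbuffered accumulation in one call, while a plain fancy-indexed += silently keeps only one contribution per repeated index. The shapes below are arbitrary:

```python
import numpy as np

vocab_size, dim = 5, 3
indices = np.array([[1, 2, 3], [2, 3, 4]])        # rows selected in the forward pass
grad = np.ones(indices.shape + (dim,))            # upstream gradient, shape (2, 3, 3)

flat_indices = indices.flatten()                   # step 2
grad_rows = grad.reshape(len(flat_indices), -1)    # step 3

# step 4: accumulate each gradient row into the row it came from
loop_grad = np.zeros((vocab_size, dim))
for i in range(len(flat_indices)):
    loop_grad[flat_indices[i]] += grad_rows[i]

# np.add.at does the same unbuffered accumulation
add_at_grad = np.zeros((vocab_size, dim))
np.add.at(add_at_grad, flat_indices, grad_rows)

# a plain fancy-index += is NOT equivalent: repeated indices
# (2 and 3 here) are written once instead of being summed
buffered_grad = np.zeros((vocab_size, dim))
buffered_grad[flat_indices] += grad_rows

print(loop_grad[:, 0])                          # [0. 1. 2. 2. 1.]
print(np.array_equal(loop_grad, add_at_grad))   # True
print(buffered_grad[:, 0])                      # [0. 1. 1. 1. 1.]
```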
Example:
x = Tensor(np.eye(5), autograd=True)
x.data
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]])
x.index_select(Tensor([[1,2,3], [2,3,4]])).backward()
x.grad.data
array([[0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1.],
       [2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2.],
       [1., 1., 1., 1., 1.]])
class Embedding(Layer):
    def __init__(self, vocab_size, dim):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        # this initialization style is a convention from word2vec
        weight = (np.random.rand(vocab_size, dim) - 0.5) / dim
        self.weight = Tensor(weight, autograd=True)
        self.parameters.append(self.weight)

    def forward(self, input):
        # input is word indices
        return self.weight.index_select(input)
x = Tensor(np.array([1,2,1,2]), autograd=True)
y = Tensor(np.array([[0], [1], [0], [1]]), autograd=True)
embed = Embedding(vocab_size=5, dim=3)
model = Sequential([embed, Tanh(), Linear(3,1), Sigmoid()])
loss = MSELoss()
optimizer = SGD(parameters=model.get_parameters(), alpha=0.5)
for i in range(10): # epochs
    y_hat = model.forward(x)
    l = loss.forward(y_hat, y)
    l.backward(Tensor(np.ones_like(l.data)))
    optimizer.step()
    print(l)
[1.15164466]
[0.40309683]
[0.20199962]
[0.12654533]
[0.08969294]
[0.06850048]
[0.0549488]
[0.04562442]
[0.03885794]
[0.03374584]
In this neural network, we learn to correlate inputs 1 and 2 with predictions 0 and 1. In theory, indices 1 and 2 could correspond to tokens such as words, characters, or objects.
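Next, let's add cross-entropy to the Tensor class. It applies a softmax to the input and computes the negative log-likelihood of the target indices in a single step:
# added to `Tensor`
def cross_entropy(self, target_indices):
    # softmax over the last axis; subtracting the row max avoids overflow in np.exp
    temp = np.exp(self.data - np.max(self.data, axis=-1, keepdims=True))
    softmax_output = temp / np.sum(temp, axis=-1, keepdims=True)
    t = target_indices.data.flatten()
    p = softmax_output.reshape(len(t), -1)
    target_dist = np.eye(p.shape[1])[t]  # one-hot targets
    loss = -(np.log(p) * target_dist).sum(1).mean()
    if (self.autograd):
        out = Tensor(loss,
                     autograd=True,
                     parents=[self],
                     creation_op="cross_entropy")
        # cache what backward() needs for the fused gradient
        out.softmax_output = softmax_output
        out.target_dist = target_dist
        return out
    return Tensor(loss)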
# Cross Entropy Layer
class CrossEntropyLoss(object):
    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        return input.cross_entropy(target)
An example:
import numpy as np
np.random.seed(0)
x = Tensor(np.array([1,2,1,2]), autograd=True)
y = Tensor(np.array([0,1,0,1]), autograd=True)
model = Sequential([Embedding(3,3), Tanh(), Linear(3,4)])
loss = CrossEntropyLoss()
optimizer = SGD(parameters=model.get_parameters(), alpha=0.1)
for i in range(10):
    y_hat = model.forward(x)
    l = loss.forward(y_hat, y)
    l.backward(Tensor(np.ones_like(l.data)))
    optimizer.step()
    print(l)
0.1885620377278218
0.16792356657516766
0.15086902981053774
0.13661027751993166
0.1245597425659242
0.11427427141364929
0.10541566875575809
0.09772282456212722
0.0909918671530601
0.08506189900243766
One noticeable difference between this loss and the others is that both the final softmax and the computation of the loss live inside the loss class.
When we design a network to be trained using cross-entropy, we can leave the softmax out of the forward propagation step and call a cross-entropy class that automatically performs the softmax as part of the loss function. It's much faster to calculate the gradient of softmax and negative log-likelihood together in a cross-entropy function than to forward propagate and backpropagate them separately in two different modules.
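The fused version is cheap because, for logits z and target index t, the gradient of -log softmax(z)[t] with respect to z is simply softmax(z) - onehot(t): the softmax_output - target_dist expression in the cross_entropy branch of backward(). A finite-difference check in plain NumPy, with arbitrary random logits and targets, should confirm it. Note that this is the gradient of the loss summed over the batch; since cross_entropy() reports the mean, the exact gradient of that mean would additionally be divided by the batch size, a constant factor that effectively rescales the learning rate.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))  # shift for numerical stability
    return e / e.sum(axis=-1, keepdims=True)

def summed_nll(z, targets):
    # negative log-likelihood of the targets under softmax(z), summed over the batch
    p = softmax(z)
    return -np.log(p[np.arange(len(targets)), targets]).sum()

np.random.seed(0)
logits = np.random.randn(4, 5)           # 4 examples, 5 classes
targets = np.array([0, 3, 1, 4])

# analytic gradient of the fused softmax + NLL
analytic = softmax(logits) - np.eye(5)[targets]

# numerical gradient, perturbing one logit at a time
numeric = np.zeros_like(logits)
eps = 1e-6
for idx in np.ndindex(logits.shape):
    plus, minus = logits.copy(), logits.copy()
    plus[idx] += eps
    minus[idx] -= eps
    numeric[idx] = (summed_nll(plus, targets) - summed_nll(minus, targets)) / (2 * eps)

print(np.allclose(analytic, numeric, atol=1e-5))  # True
```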
Let's create one more layer that's a composition of several smaller layer types: the recurrent layer. We'll construct it from three Linear layers. Its .forward() method takes both the output from the previous hidden state and the input from the current training data.
class RNNCell(Layer):
    def __init__(self, n_inputs, n_hidden, n_output, activation='sigmoid'):
        super().__init__()
        self.n_inputs = n_inputs
        self.n_hidden = n_hidden
        self.n_output = n_output
        if (activation == 'sigmoid'):
            self.activation = Sigmoid()
        elif (activation == 'tanh'):
            self.activation = Tanh()
        else:
            raise Exception("Non-Linearity not found")
        self.w_ih = Linear(n_inputs, n_hidden)
        self.w_hh = Linear(n_hidden, n_hidden)
        self.w_ho = Linear(n_hidden, n_output)
        self.parameters += self.w_ih.get_parameters()
        self.parameters += self.w_hh.get_parameters()
        self.parameters += self.w_ho.get_parameters()

    def forward(self, input, hidden):
        from_prev_hidden = self.w_hh.forward(hidden)
        combined = self.w_ih.forward(input) + from_prev_hidden
        new_hidden = self.activation.forward(combined)
        output = self.w_ho.forward(new_hidden)
        return output, new_hidden

    def init_hidden(self, batch_size=1):
        return Tensor(np.zeros((batch_size, self.n_hidden)), autograd=True)
RNNs have a state vector that passes from timestep to timestep. In this case, it's the variable hidden, which is both an input parameter and an output variable of the forward function.
RNNs also have several weight matrices. RNNCell uses three: w_ih maps the input to the hidden state, w_hh maps the previous hidden state to the new one, and w_ho maps the hidden state to the output.
An activation input parameter defines which nonlinearity is applied to the hidden vectors at each timestep.
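To see the recurrence without the framework, here is a minimal NumPy sketch of RNNCell.forward unrolled over a sequence. It assumes each Linear layer computes x @ W + b (the biases and all sizes are assumptions for illustration), and rnn_cell_step is a hypothetical helper name:

```python
import numpy as np

np.random.seed(0)
batch_size, n_inputs, n_hidden, n_output, seq_len = 2, 4, 3, 5, 6

# three weight matrices (plus assumed biases), named after the RNNCell attributes
W_ih, b_ih = np.random.randn(n_inputs, n_hidden) * 0.1, np.zeros(n_hidden)
W_hh, b_hh = np.random.randn(n_hidden, n_hidden) * 0.1, np.zeros(n_hidden)
W_ho, b_ho = np.random.randn(n_hidden, n_output) * 0.1, np.zeros(n_output)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def rnn_cell_step(x, hidden):
    # combine the current input with the previous hidden state
    combined = (x @ W_ih + b_ih) + (hidden @ W_hh + b_hh)
    new_hidden = sigmoid(combined)
    output = new_hidden @ W_ho + b_ho
    return output, new_hidden

inputs = np.random.randn(seq_len, batch_size, n_inputs)  # one input per timestep
hidden = np.zeros((batch_size, n_hidden))                # like init_hidden()
for t in range(seq_len):
    output, hidden = rnn_cell_step(inputs[t], hidden)

print(output.shape, hidden.shape)  # (2, 5) (2, 3)
```

The same three matrices are reused at every timestep; only hidden changes as it is threaded through the loop.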
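Let's train the network. First, we load the training data and convert every sentence into a fixed-length list of word indices: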
import sys, random, math
from collections import Counter
import numpy as np
with open('static/data/tasksv11/en/qa1_single-supporting-fact_train.txt', 'r') as f:
    raw = f.readlines()

# drop the leading line number and tokenize on spaces
tokens = list()
for line in raw[0:1000]:
    tokens.append(line.lower().replace("\n", "").split(" ")[1:])

# left-pad every sentence with '-' so all sentences have 6 tokens
new_tokens = list()
for line in tokens:
    new_tokens.append(['-'] * (6 - len(line)) + line)
tokens = new_tokens

vocab = set()
for sent in tokens:
    for word in sent:
        vocab.add(word)
vocab = list(vocab)

word2index = {}
for i, word in enumerate(vocab):
    word2index[word] = i

def words2indices(sentence):
    idx = list()
    for word in sentence:
        idx.append(word2index[word])
    return idx

indices = list()
for line in tokens:
    indices.append(words2indices(line))
data = np.array(indices)
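To see what this preprocessing produces, here is the same pipeline applied to two hypothetical lines in the layout the code expects (a leading line number, then the sentence). One deliberate change: the vocabulary is sorted so the indices are reproducible, whereas list(set(...)) yields an arbitrary order:

```python
import numpy as np

# hypothetical lines in the "<line number> <sentence>" layout
raw = ["1 Mary moved to the bathroom.\n",
       "2 John went to the hallway.\n"]

tokens = [line.lower().replace("\n", "").split(" ")[1:] for line in raw]
# left-pad every sentence with '-' to a fixed length of 6 tokens
tokens = [['-'] * (6 - len(line)) + line for line in tokens]

vocab = sorted({word for sent in tokens for word in sent})
word2index = {word: i for i, word in enumerate(vocab)}
data = np.array([[word2index[w] for w in line] for line in tokens])

print(tokens[0])   # ['-', 'mary', 'moved', 'to', 'the', 'bathroom.']
print(data.shape)  # (2, 6)
```

Because every row has the same length, data can be sliced by column (data[:, t]) to get the batch of tokens at timestep t.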
Now we can initialize the recurrent layer with an embedding input and train a network to solve the same task as in the previous chapter. We should note that this network is slightly more complex:
embed = Embedding(vocab_size=len(vocab), dim=16)
model = RNNCell(n_inputs=16, n_hidden=16, n_output=len(vocab))
loss = CrossEntropyLoss()
params = model.get_parameters() + embed.get_parameters()
optimizer = SGD(parameters=params, alpha=0.05)
Cell is the conventional name for an RNN layer that implements a single recurrence. If we created another layer that could stack an arbitrary number of cells together, it would be called an RNN, and n_layers would be an input parameter.
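As an illustration of that idea, here is a minimal NumPy sketch of a stacked RNN in which each cell's new hidden state becomes the input of the cell above it. TinyRNNCell and StackedRNN are hypothetical names, biases and the output layer (w_ho) are omitted for brevity, and this is a sketch rather than an addition to our framework:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

class TinyRNNCell:
    """Hypothetical NumPy stand-in for RNNCell (no biases, no output layer)."""
    def __init__(self, n_inputs, n_hidden, rng):
        self.W_ih = rng.normal(0, 0.1, (n_inputs, n_hidden))
        self.W_hh = rng.normal(0, 0.1, (n_hidden, n_hidden))

    def forward(self, x, hidden):
        return sigmoid(x @ self.W_ih + hidden @ self.W_hh)

class StackedRNN:
    """Hypothetical multi-layer RNN: each cell feeds its hidden state upward."""
    def __init__(self, n_inputs, n_hidden, n_layers, seed=0):
        rng = np.random.default_rng(seed)
        # the first cell reads the input; the others read the hidden state below
        sizes = [n_inputs] + [n_hidden] * (n_layers - 1)
        self.cells = [TinyRNNCell(size, n_hidden, rng) for size in sizes]
        self.n_hidden = n_hidden

    def init_hidden(self, batch_size):
        # one hidden state per layer
        return [np.zeros((batch_size, self.n_hidden)) for _ in self.cells]

    def forward(self, x, hiddens):
        new_hiddens = []
        layer_input = x
        for cell, hidden in zip(self.cells, hiddens):
            layer_input = cell.forward(layer_input, hidden)
            new_hiddens.append(layer_input)
        return layer_input, new_hiddens

rnn = StackedRNN(n_inputs=16, n_hidden=8, n_layers=3)
hiddens = rnn.init_hidden(batch_size=4)
top, hiddens = rnn.forward(np.zeros((4, 16)), hiddens)
print(top.shape, len(hiddens))  # (4, 8) 3
```

Returning to the single RNNCell defined earlier, the training loop is: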
for iter in range(1000):
    batch_size = 100
    total_loss = 0
    hidden = model.init_hidden(batch_size=batch_size)
    # feed the first five tokens of each sentence, one timestep at a time
    for t in range(5):
        input = Tensor(data[0:batch_size, t], autograd=True)
        rnn_input = embed.forward(input=input)
        output, hidden = model.forward(input=rnn_input, hidden=hidden)
    # the loss scores only the prediction of the final (sixth) token
    target = Tensor(data[0:batch_size, t+1], autograd=True)
    l = loss.forward(output, target)
    l.backward()
    optimizer.step()
    total_loss += l.data
    if (iter % 200 == 0):
        p_correct = (target.data == np.argmax(output.data, axis=1)).mean()
        print_loss = total_loss / (len(data) / batch_size)
        print("Loss:", print_loss, "% Correct: ", p_correct)
Loss: 0.4210781038536967 % Correct: 0.0
Loss: 0.17030485013740324 % Correct: 0.27
Loss: 0.14961780443925604 % Correct: 0.36
Loss: 0.1390113005612828 % Correct: 0.36
Loss: 0.13628065998222028 % Correct: 0.35
Let's try to predict using the trained model:
batch_size = 1
hidden = model.init_hidden(batch_size=batch_size)
for t in range(5):
    input = Tensor(data[0:batch_size, t], autograd=True)
    rnn_input = embed.forward(input=input)
    output, hidden = model.forward(input=rnn_input, hidden=hidden)
target = Tensor(data[0:batch_size, t+1], autograd=True)
l = loss.forward(output, target)

ctx = ""
for idx in data[0:batch_size][0][0:-1]:
    ctx += vocab[idx] + " "
print("Context: ", ctx)
print("Pred: ", vocab[output.data.argmax()])
Context: - mary moved to the
Pred: office.
The neural network learns to predict the final word of the first 100 examples of the training dataset with an accuracy of around 36% (per the log above). Here, it predicts a plausible location for Mary to be moving toward.
Frameworks can make our code more readable, faster to write, and faster to execute (through built-in optimizations).
This chapter will prepare us to use and extend industry-standard frameworks like PyTorch or TensorFlow. The skills we've learned in this chapter will be the most valuable ones from this book. We highly recommend diving into PyTorch after finishing this book.