This notebook was part of Lesson 7 of the Practical Deep Learning for Coders course.
We were using RNNs as part of our language model in the previous lesson. Today, we will dive into more details of what RNNs are and how they work. We will do this using the problem of trying to predict the English word version of numbers.
Let's predict what should come next in this sequence:
eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve...
Jeremy created this synthetic dataset to have a better way to check if things are working, to debug, and to understand what was going on. When experimenting with new ideas, it can be nice to have a smaller dataset so you can quickly get a sense of whether your ideas are promising (for other examples, see Imagenette and Imagewoof). This English word numbers dataset will serve us well for learning about RNNs. Our task today will be to predict which word comes next when counting.
Parameters are numbers that are learned. Activations are numbers that are calculated (by affine functions & element-wise non-linearities).
When you learn about any new concept in deep learning, ask yourself: is this a parameter or an activation?
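As a small, hypothetical illustration of the distinction (the names and sizes below are made up): the weight matrix is a parameter, updated by the optimizer during training, while the layer output is an activation, recomputed from the inputs on every forward pass.

```python
import numpy as np

rng = np.random.default_rng(0)

# Parameters: numbers that are *learned* (updated by the optimizer).
W = rng.standard_normal((3, 4))   # weight matrix of a tiny linear layer
b = np.zeros(4)                   # bias

# Input batch of 2 examples with 3 features each.
x = rng.standard_normal((2, 3))

# Activations: numbers that are *calculated* by an affine function
# followed by an element-wise non-linearity (ReLU here).
a = np.maximum(x @ W + b, 0)

print(a.shape)  # (2, 4)
```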
Note to self: Point out the hidden state, going from the version without a for-loop to the for loop. This is the step where people get confused.
from fastai.text import *
bs=64
path = untar_data(URLs.HUMAN_NUMBERS)
path.ls()
[PosixPath('/home/racheltho/.fastai/data/human_numbers/models'), PosixPath('/home/racheltho/.fastai/data/human_numbers/valid.txt'), PosixPath('/home/racheltho/.fastai/data/human_numbers/train.txt')]
def readnums(d): return [', '.join(o.strip() for o in open(path/d).readlines())]
train.txt gives us a sequence of numbers written out as English words:
train_txt = readnums('train.txt'); train_txt[0][:80]
'one, two, three, four, five, six, seven, eight, nine, ten, eleven, twelve, thirt'
valid_txt = readnums('valid.txt'); valid_txt[0][-80:]
' nine thousand nine hundred ninety eight, nine thousand nine hundred ninety nine'
train = TextList(train_txt, path=path)
valid = TextList(valid_txt, path=path)
src = ItemLists(path=path, train=train, valid=valid).label_for_lm()
data = src.databunch(bs=bs)
train[0].text[:80]
'xxbos one , two , three , four , five , six , seven , eight , nine , ten , eleve'
len(data.valid_ds[0][0].data)
13017
bptt stands for back-propagation through time. This tells us how many steps of history we are considering.
data.bptt, len(data.valid_dl)
(70, 3)
We have 3 batches in our validation set:
13017 tokens, with about 70 tokens per line of text, and 64 lines of text per batch.
13017/70/bs
2.905580357142857
We will store each batch in a separate variable, so we can walk through this to understand better what the RNN does at each step:
it = iter(data.valid_dl)
x1,y1 = next(it)
x2,y2 = next(it)
x3,y3 = next(it)
it.close()
x1
tensor([[ 2, 19, 11, ..., 36, 9, 19], [ 9, 19, 11, ..., 24, 20, 9], [11, 27, 18, ..., 9, 19, 11], ..., [20, 11, 20, ..., 11, 20, 10], [20, 11, 20, ..., 24, 9, 20], [20, 10, 26, ..., 20, 11, 20]], device='cuda:0')
numel() is a PyTorch method that returns the number of elements in a tensor:
x1.numel()+x2.numel()+x3.numel()
13440
x1.shape, y1.shape
(torch.Size([64, 70]), torch.Size([64, 70]))
x2.shape, y2.shape
(torch.Size([64, 70]), torch.Size([64, 70]))
x3.shape, y3.shape
(torch.Size([64, 70]), torch.Size([64, 70]))
v = data.valid_ds.vocab
v.itos
['xxunk', 'xxpad', 'xxbos', 'xxeos', 'xxfld', 'xxmaj', 'xxup', 'xxrep', 'xxwrep', ',', 'hundred', 'thousand', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'twenty', 'thirty', 'forty', 'fifty', 'sixty', 'seventy', 'eighty', 'ninety', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen']
x1[:,0]
tensor([ 2, 9, 11, 12, 13, 11, 10, 9, 10, 14, 19, 25, 19, 15, 16, 11, 19, 9, 10, 9, 19, 25, 19, 11, 19, 11, 10, 9, 19, 20, 11, 26, 20, 23, 20, 20, 24, 20, 11, 14, 11, 11, 9, 14, 9, 20, 10, 20, 35, 17, 11, 10, 9, 17, 9, 20, 10, 20, 11, 20, 11, 20, 20, 20], device='cuda:0')
y1[:,0]
tensor([19, 19, 27, 10, 9, 12, 32, 19, 26, 10, 11, 15, 11, 10, 9, 15, 11, 19, 26, 19, 11, 18, 11, 18, 9, 18, 21, 19, 10, 10, 20, 9, 11, 16, 11, 11, 13, 11, 13, 9, 13, 14, 20, 10, 20, 11, 24, 11, 9, 9, 16, 17, 20, 10, 20, 11, 24, 11, 19, 9, 19, 11, 11, 10], device='cuda:0')
v.itos[9], v.itos[11], v.itos[12], v.itos[13], v.itos[10]
(',', 'thousand', 'one', 'two', 'hundred')
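Notice that y1 is x1 offset by one position: each target is simply the next token in the stream. A minimal, hypothetical sketch of this batching scheme on a synthetic token stream (real fastai batching additionally splits the stream into bs parallel rows first):

```python
# Hypothetical sketch: build (x, y) pairs for language modelling, where the
# target at each position is the *next* token in the stream.
stream = list(range(100))          # a stand-in for the 13017-token stream
bptt = 5

def lm_batches(stream, bptt):
    """Yield (x, y) chunks with y shifted one token ahead of x."""
    for i in range(0, len(stream) - bptt, bptt):
        x = stream[i : i + bptt]
        y = stream[i + 1 : i + bptt + 1]
        yield x, y

x0, y0 = next(lm_batches(stream, bptt))
print(x0, y0)  # [0, 1, 2, 3, 4] [1, 2, 3, 4, 5]
```

Within each chunk, x0[1:] == y0[:-1] — the same relationship you can see between x1 and y1 above.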
v.textify(x1[0])
'xxbos eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight'
v.textify(y1[0])
'eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight thousand'
v.textify(x2[0])
'thousand eighteen , eight thousand nineteen , eight thousand twenty , eight thousand twenty one , eight thousand twenty two , eight thousand twenty three , eight thousand twenty four , eight thousand twenty five , eight thousand twenty six , eight thousand twenty seven , eight thousand twenty eight , eight thousand twenty nine , eight thousand thirty , eight thousand thirty one , eight thousand thirty two ,'
v.textify(x3[0])
'eight thousand thirty three , eight thousand thirty four , eight thousand thirty five , eight thousand thirty six , eight thousand thirty seven , eight thousand thirty eight , eight thousand thirty nine , eight thousand forty , eight thousand forty one , eight thousand forty two , eight thousand forty three , eight thousand forty four , eight thousand forty five , eight thousand forty six , eight'
v.textify(x1[1])
', eight thousand forty six , eight thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand fifty nine ,'
v.textify(x2[1])
'eight thousand sixty , eight thousand sixty one , eight thousand sixty two , eight thousand sixty three , eight thousand sixty four , eight thousand sixty five , eight thousand sixty six , eight thousand sixty seven , eight thousand sixty eight , eight thousand sixty nine , eight thousand seventy , eight thousand seventy one , eight thousand seventy two , eight thousand seventy three , eight thousand'
v.textify(x3[1])
'seventy four , eight thousand seventy five , eight thousand seventy six , eight thousand seventy seven , eight thousand seventy eight , eight thousand seventy nine , eight thousand eighty , eight thousand eighty one , eight thousand eighty two , eight thousand eighty three , eight thousand eighty four , eight thousand eighty five , eight thousand eighty six , eight thousand eighty seven , eight thousand eighty'
v.textify(x3[-1])
'ninety , nine thousand nine hundred ninety one , nine thousand nine hundred ninety two , nine thousand nine hundred ninety three , nine thousand nine hundred ninety four , nine thousand nine hundred ninety five , nine thousand nine hundred ninety six , nine thousand nine hundred ninety seven , nine thousand nine hundred ninety eight , nine thousand nine hundred ninety nine xxbos eight thousand one , eight'
data.show_batch(ds_type=DatasetType.Valid)
idx | text |
---|---|
0 | thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand fifty nine , eight thousand sixty , eight thousand sixty |
1 | eight , eight thousand eighty nine , eight thousand ninety , eight thousand ninety one , eight thousand ninety two , eight thousand ninety three , eight thousand ninety four , eight thousand ninety five , eight thousand ninety six , eight thousand ninety seven , eight thousand ninety eight , eight thousand ninety nine , eight thousand one hundred , eight thousand one hundred one , eight thousand one |
2 | thousand one hundred twenty four , eight thousand one hundred twenty five , eight thousand one hundred twenty six , eight thousand one hundred twenty seven , eight thousand one hundred twenty eight , eight thousand one hundred twenty nine , eight thousand one hundred thirty , eight thousand one hundred thirty one , eight thousand one hundred thirty two , eight thousand one hundred thirty three , eight thousand |
3 | three , eight thousand one hundred fifty four , eight thousand one hundred fifty five , eight thousand one hundred fifty six , eight thousand one hundred fifty seven , eight thousand one hundred fifty eight , eight thousand one hundred fifty nine , eight thousand one hundred sixty , eight thousand one hundred sixty one , eight thousand one hundred sixty two , eight thousand one hundred sixty three |
4 | thousand one hundred eighty three , eight thousand one hundred eighty four , eight thousand one hundred eighty five , eight thousand one hundred eighty six , eight thousand one hundred eighty seven , eight thousand one hundred eighty eight , eight thousand one hundred eighty nine , eight thousand one hundred ninety , eight thousand one hundred ninety one , eight thousand one hundred ninety two , eight thousand |
We will iteratively consider a few different models, building up to a more traditional RNN.
data = src.databunch(bs=bs, bptt=3)
x,y = data.one_batch()
x.shape,y.shape
(torch.Size([64, 3]), torch.Size([64, 3]))
nv = len(v.itos); nv
39
nh=64
def loss4(input,target): return F.cross_entropy(input, target[:,-1])
def acc4 (input,target): return accuracy(input, target[:,-1])
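loss4 and acc4 grade the model only on the final token of each target sequence — that is what the target[:,-1] indexing does. A hypothetical numpy sketch of the same idea (F.cross_entropy is log-softmax followed by negative log-likelihood):

```python
import numpy as np

rng = np.random.default_rng(0)
bs, seq_len, vocab = 4, 3, 10

logits = rng.standard_normal((bs, vocab))        # one prediction per sequence
targets = rng.integers(0, vocab, (bs, seq_len))  # full target sequences

def last_token_loss(logits, targets):
    """Cross-entropy against only the *last* target token, like loss4."""
    last = targets[:, -1]
    # log-softmax over the vocabulary dimension
    logp = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -logp[np.arange(len(last)), last].mean()

loss = last_token_loss(logits, targets)
print(loss)
```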
x[:,0]
tensor([13, 13, 10, 9, 18, 9, 11, 11, 13, 19, 16, 23, 24, 9, 12, 9, 13, 14, 15, 11, 10, 22, 15, 9, 10, 14, 11, 16, 10, 28, 11, 9, 20, 9, 15, 15, 11, 18, 10, 28, 23, 24, 9, 16, 10, 16, 19, 20, 12, 10, 22, 16, 17, 17, 17, 11, 24, 10, 9, 15, 16, 9, 18, 11])
Layer names:

- i_h: input to hidden
- h_h: hidden to hidden
- h_o: hidden to output
- bn: batchnorm

class Model0(nn.Module):
def __init__(self):
super().__init__()
self.i_h = nn.Embedding(nv,nh) # green arrow
self.h_h = nn.Linear(nh,nh) # brown arrow
self.h_o = nn.Linear(nh,nv) # blue arrow
self.bn = nn.BatchNorm1d(nh)
def forward(self, x):
h = self.bn(F.relu(self.i_h(x[:,0])))
if x.shape[1]>1:
h = h + self.i_h(x[:,1])
h = self.bn(F.relu(self.h_h(h)))
if x.shape[1]>2:
h = h + self.i_h(x[:,2])
h = self.bn(F.relu(self.h_h(h)))
return self.h_o(h)
learn = Learner(data, Model0(), loss_func=loss4, metrics=acc4)
learn.fit_one_cycle(6, 1e-4)
epoch | train_loss | valid_loss | acc4 | time |
---|---|---|---|---|
0 | 3.693647 | 3.621712 | 0.054458 | 00:01 |
1 | 3.087305 | 3.181658 | 0.397518 | 00:01 |
2 | 2.475566 | 2.665241 | 0.442096 | 00:01 |
3 | 2.147892 | 2.392788 | 0.456572 | 00:01 |
4 | 2.026868 | 2.300193 | 0.457031 | 00:01 |
5 | 2.002222 | 2.287001 | 0.456572 | 00:01 |
Let's refactor this to use a for-loop. This does the same thing as before:
class Model1(nn.Module):
def __init__(self):
super().__init__()
self.i_h = nn.Embedding(nv,nh) # green arrow
self.h_h = nn.Linear(nh,nh) # brown arrow
self.h_o = nn.Linear(nh,nv) # blue arrow
self.bn = nn.BatchNorm1d(nh)
def forward(self, x):
h = torch.zeros(x.shape[0], nh).to(device=x.device)
for i in range(x.shape[1]):
h = h + self.i_h(x[:,i])
h = self.bn(F.relu(self.h_h(h)))
return self.h_o(h)
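The rolled form applies one identical step per token, so unrolling the loop by hand gives exactly the same hidden state. A hypothetical numpy sketch of that equivalence (batchnorm omitted for simplicity; weights are random stand-ins):

```python
import numpy as np

rng = np.random.default_rng(0)
nh = 4
W_hh = rng.standard_normal((nh, nh))   # hidden-to-hidden weights
emb = rng.standard_normal((3, nh))     # embeddings of 3 input tokens
relu = lambda z: np.maximum(z, 0)

# Unrolled: three explicit copies of the same step
h = np.zeros(nh)
h = relu((h + emb[0]) @ W_hh)
h = relu((h + emb[1]) @ W_hh)
h_unrolled = relu((h + emb[2]) @ W_hh)

# Rolled: the identical step inside a for-loop
h = np.zeros(nh)
for i in range(3):
    h = relu((h + emb[i]) @ W_hh)

print(np.allclose(h, h_unrolled))  # True
```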
This is the difference between unrolled (what we had before) and rolled (what we have now) RNN diagrams:
learn = Learner(data, Model1(), loss_func=loss4, metrics=acc4)
learn.fit_one_cycle(6, 1e-4)
epoch | train_loss | valid_loss | acc4 | time |
---|---|---|---|---|
0 | 3.555616 | 3.537218 | 0.111673 | 00:01 |
1 | 2.946805 | 2.965743 | 0.403722 | 00:01 |
2 | 2.402015 | 2.497509 | 0.448300 | 00:01 |
3 | 2.089903 | 2.259959 | 0.454733 | 00:01 |
4 | 1.967923 | 2.184094 | 0.455653 | 00:01 |
5 | 1.942891 | 2.173949 | 0.449678 | 00:01 |
Our accuracy is about the same, since we are doing the same thing as before.
Before, we were just predicting the last word in a line of text. Given 70 tokens, what is token 71? That approach was throwing away a lot of data. Why not predict token 2 from token 1, then predict token 3, then predict token 4, and so on? We will modify our model to do this.
data = src.databunch(bs=bs, bptt=20)
x,y = data.one_batch()
x.shape,y.shape
(torch.Size([64, 20]), torch.Size([64, 20]))
class Model2(nn.Module):
def __init__(self):
super().__init__()
self.i_h = nn.Embedding(nv,nh)
self.h_h = nn.Linear(nh,nh)
self.h_o = nn.Linear(nh,nv)
self.bn = nn.BatchNorm1d(nh)
def forward(self, x):
h = torch.zeros(x.shape[0], nh).to(device=x.device)
res = []
for i in range(x.shape[1]):
h = h + self.i_h(x[:,i])
h = F.relu(self.h_h(h))
res.append(self.h_o(self.bn(h)))
return torch.stack(res, dim=1)
learn = Learner(data, Model2(), metrics=accuracy)
learn.fit_one_cycle(10, 1e-4, pct_start=0.1)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 3.672596 | 3.615433 | 0.044815 | 00:01 |
1 | 3.577137 | 3.518008 | 0.149077 | 00:00 |
2 | 3.456851 | 3.419511 | 0.220810 | 00:00 |
3 | 3.332458 | 3.332359 | 0.247301 | 00:00 |
4 | 3.219220 | 3.258093 | 0.289986 | 00:00 |
5 | 3.125784 | 3.203508 | 0.311435 | 00:00 |
6 | 3.054881 | 3.167783 | 0.330895 | 00:00 |
7 | 3.005914 | 3.148063 | 0.333878 | 00:00 |
8 | 2.975965 | 3.140450 | 0.335440 | 00:00 |
9 | 2.960360 | 3.139326 | 0.335369 | 00:00 |
Note that our accuracy is worse now, because we are doing a harder task. When we predict word k (k<70), we have less history to help us than when we were only predicting word 71.
To address this issue, let's keep the hidden state from the previous line of text, so we are not starting over again on each new line of text.
class Model3(nn.Module):
def __init__(self):
super().__init__()
self.i_h = nn.Embedding(nv,nh)
self.h_h = nn.Linear(nh,nh)
self.h_o = nn.Linear(nh,nv)
self.bn = nn.BatchNorm1d(nh)
self.h = torch.zeros(bs, nh).cuda()
def forward(self, x):
res = []
h = self.h
for i in range(x.shape[1]):
h = h + self.i_h(x[:,i])
h = F.relu(self.h_h(h))
res.append(self.bn(h))
self.h = h.detach()
res = torch.stack(res, dim=1)
res = self.h_o(res)
return res
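The self.h = h.detach() line is what makes this truncated back-propagation through time: the hidden state's values carry over into the next batch, but the autograd graph is cut, so backprop never reaches into earlier batches. A hypothetical torch sketch of the effect:

```python
import torch

w = torch.randn(4, 4, requires_grad=True)
h = torch.zeros(2, 4)

# One batch of steps: h now depends on w through the autograd graph.
h = torch.relu(h @ w + 1.0)
print(h.grad_fn is not None)  # True: backprop would reach w from here

# Detach before the next batch: identical values, but the graph is cut.
h = h.detach()
print(h.grad_fn)              # None: backprop stops at the batch boundary
```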
learn = Learner(data, Model3(), metrics=accuracy)
learn.fit_one_cycle(20, 3e-3)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 3.578149 | 3.480786 | 0.125142 | 00:00 |
1 | 3.213412 | 2.848152 | 0.413991 | 00:00 |
2 | 2.549868 | 1.958330 | 0.464773 | 00:00 |
3 | 1.978741 | 1.906839 | 0.355540 | 00:00 |
4 | 1.669468 | 1.802589 | 0.461719 | 00:00 |
5 | 1.469962 | 1.814068 | 0.474361 | 00:00 |
6 | 1.292922 | 1.659862 | 0.490483 | 00:00 |
7 | 1.105973 | 1.608031 | 0.509375 | 00:00 |
8 | 0.931592 | 1.382533 | 0.543679 | 00:00 |
9 | 0.778177 | 1.407267 | 0.553196 | 00:00 |
10 | 0.666087 | 1.443596 | 0.572301 | 00:00 |
11 | 0.598070 | 1.355380 | 0.588565 | 00:00 |
12 | 0.526167 | 1.382941 | 0.586861 | 00:00 |
13 | 0.469985 | 1.385877 | 0.591903 | 00:00 |
14 | 0.426540 | 1.443933 | 0.577699 | 00:00 |
15 | 0.394259 | 1.449048 | 0.578409 | 00:00 |
16 | 0.372871 | 1.348665 | 0.604190 | 00:00 |
17 | 0.356916 | 1.423057 | 0.587855 | 00:00 |
18 | 0.346191 | 1.421857 | 0.585938 | 00:00 |
19 | 0.340383 | 1.424547 | 0.587145 | 00:00 |
Now we are getting greater accuracy than before!
Let's refactor the above to use PyTorch's RNN. This is what you would use in practice, but now you know the inside details!
class Model4(nn.Module):
def __init__(self):
super().__init__()
self.i_h = nn.Embedding(nv,nh)
self.rnn = nn.RNN(nh,nh, batch_first=True)
self.h_o = nn.Linear(nh,nv)
self.bn = BatchNorm1dFlat(nh)
self.h = torch.zeros(1, bs, nh).cuda()
def forward(self, x):
res,h = self.rnn(self.i_h(x), self.h)
self.h = h.detach()
return self.h_o(self.bn(res))
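nn.RNN with batch_first=True consumes tensors of shape [batch, seq, features] and returns both the per-step outputs and the final hidden state; Model4 feeds that final state back in on the next batch. A minimal, hypothetical sketch of the shapes (sizes are made up):

```python
import torch
import torch.nn as nn

bs, seq_len, nh = 5, 7, 8
rnn = nn.RNN(nh, nh, batch_first=True)   # input size nh, hidden size nh

inp = torch.randn(bs, seq_len, nh)       # e.g. a batch of embedded tokens
h0 = torch.zeros(1, bs, nh)              # [num_layers, batch, hidden]

out, h_n = rnn(inp, h0)
print(out.shape)   # torch.Size([5, 7, 8]) -- an output at every time step
print(h_n.shape)   # torch.Size([1, 5, 8]) -- final hidden state only
```

For a single-layer RNN, out[:, -1] equals h_n[0]: the per-step outputs are just the hidden states themselves.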
learn = Learner(data, Model4(), metrics=accuracy)
learn.fit_one_cycle(20, 3e-3)
epoch | train_loss | valid_loss | accuracy |
---|---|---|---|
1 | 3.451432 | 3.268344 | 0.224148 |
2 | 2.974938 | 2.456569 | 0.466051 |
3 | 2.316732 | 1.946969 | 0.465625 |
4 | 1.866151 | 1.991952 | 0.314702 |
5 | 1.618516 | 1.802403 | 0.437216 |
6 | 1.411517 | 1.731107 | 0.436293 |
7 | 1.171916 | 1.655979 | 0.504048 |
8 | 0.965887 | 1.579963 | 0.522088 |
9 | 0.797046 | 1.479819 | 0.565057 |
10 | 0.659378 | 1.487831 | 0.579048 |
11 | 0.553282 | 1.441922 | 0.597798 |
12 | 0.475167 | 1.498148 | 0.600781 |
13 | 0.416131 | 1.546984 | 0.606463 |
14 | 0.372395 | 1.594261 | 0.607386 |
15 | 0.337093 | 1.578321 | 0.613352 |
16 | 0.311385 | 1.580973 | 0.623366 |
17 | 0.292869 | 1.625745 | 0.618253 |
18 | 0.279486 | 1.623960 | 0.626065 |
19 | 0.270054 | 1.682090 | 0.611719 |
20 | 0.263857 | 1.675676 | 0.614702 |
When you have long time scales and deeper networks, these become impossible to train. One way to address this is to add mini neural nets that decide how much of the green arrow and how much of the orange arrow to keep. These mini-NNs can be GRUs or LSTMs. We will cover more details of this in a later lesson.
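A hypothetical numpy sketch of a single GRU step, following PyTorch's gate convention (the weights here are random stand-ins; a real GRU learns them): the update gate z interpolates between the old hidden state and a candidate, and the reset gate r controls how much history feeds that candidate.

```python
import numpy as np

rng = np.random.default_rng(0)
nh = 4
sigmoid = lambda a: 1 / (1 + np.exp(-a))

# Random illustrative weights; a trained GRU learns all six matrices.
Wz, Wr, Wn = (rng.standard_normal((nh, nh)) for _ in range(3))
Uz, Ur, Un = (rng.standard_normal((nh, nh)) for _ in range(3))

def gru_step(x, h):
    z = sigmoid(x @ Wz + h @ Uz)        # update gate: keep old vs take new
    r = sigmoid(x @ Wr + h @ Ur)        # reset gate: how much history to use
    n = np.tanh(x @ Wn + (r * h) @ Un)  # candidate hidden state
    return (1 - z) * n + z * h          # z near 1 keeps the old state

x = rng.standard_normal(nh)
h = rng.standard_normal(nh)
h_new = gru_step(x, h)
print(h_new.shape)  # (4,)
```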
class Model5(nn.Module):
def __init__(self):
super().__init__()
self.i_h = nn.Embedding(nv,nh)
self.rnn = nn.GRU(nh, nh, 2, batch_first=True)
self.h_o = nn.Linear(nh,nv)
self.bn = BatchNorm1dFlat(nh)
self.h = torch.zeros(2, bs, nh).cuda()
def forward(self, x):
res,h = self.rnn(self.i_h(x), self.h)
self.h = h.detach()
return self.h_o(self.bn(res))
learn = Learner(data, Model5(), metrics=accuracy)
learn.fit_one_cycle(10, 1e-2)
epoch | train_loss | valid_loss | accuracy |
---|---|---|---|
1 | 2.864854 | 2.314943 | 0.454545 |
2 | 1.798988 | 1.357116 | 0.629688 |
3 | 0.932729 | 1.307463 | 0.796733 |
4 | 0.451969 | 1.329699 | 0.788636 |
5 | 0.225787 | 1.293570 | 0.800142 |
6 | 0.118085 | 1.265926 | 0.803338 |
7 | 0.065306 | 1.207096 | 0.806960 |
8 | 0.038098 | 1.205361 | 0.813920 |
9 | 0.024069 | 1.239411 | 0.807813 |
10 | 0.017078 | 1.253409 | 0.807102 |
In the previous lesson, we were essentially swapping out self.h_o with a classifier in order to do classification on text.
An RNN is just a refactored, fully connected neural network.
You can use the same approach for any sequence labeling task (part-of-speech tagging, classifying whether material is sensitive, ...).