Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
Demo of a simple RNN for sentiment classification (here: a binary classification problem with two labels, positive and negative). Note that a simple RNN usually doesn't work very well due to vanishing and exploding gradient problems. Also, this implementation uses padding to deal with variable-length inputs. Hence, the shorter the sentence, the more <pad> placeholders are added to match the length of the longest sentence in a batch. However, in this example notebook, nn.utils.rnn.pack_padded_sequence is used so that no actual computation is carried out once a sentence ends (i.e., the padding tokens are ignored).
Note that this RNN trains about 4 times faster than the equivalent implementation without packed sequences, ./rnn_simple_imdb.ipynb.
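To illustrate what packing does, here is a minimal toy sketch (not part of the original notebook; the values and dimensions are made up for illustration): a padded, index-encoded mini-batch in [sentence len, batch size] layout is embedded, packed together with its true lengths, and fed through an RNN, so the final hidden state of each sequence corresponds to its last real token rather than a padding position.
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

# Toy batch of 3 index-encoded sentences, padded with 0 to length 4,
# in [sentence len, batch size] layout (the default expected by nn.RNN)
padded = torch.tensor([[1, 4, 7],
                       [2, 5, 0],
                       [3, 6, 0],
                       [9, 0, 0]])
lengths = torch.tensor([4, 3, 1])  # true lengths, sorted in decreasing order

embedding = nn.Embedding(10, 5)
rnn = nn.RNN(5, 8)

packed = pack_padded_sequence(embedding(padded), lengths)
output, hidden = rnn(packed)
# hidden[:, i, :] is the state after the last *real* token of sequence i;
# the padding positions are never fed through the RNN.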
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch
import torch
import torch.nn.functional as F
from torchtext import data
from torchtext import datasets
import time
import random
torch.backends.cudnn.deterministic = True
Sebastian Raschka

CPython 3.7.1
IPython 7.4.0

torch 1.0.1.post2
RANDOM_SEED = 123
torch.manual_seed(RANDOM_SEED)
VOCABULARY_SIZE = 20000
LEARNING_RATE = 1e-4
BATCH_SIZE = 128
NUM_EPOCHS = 15
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
EMBEDDING_DIM = 128
HIDDEN_DIM = 256
OUTPUT_DIM = 1
Load the IMDB Movie Review dataset:
TEXT = data.Field(tokenize='spacy',
                  include_lengths=True)  # necessary for pack_padded_sequence
LABEL = data.LabelField(dtype=torch.float)

train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

train_data, valid_data = train_data.split(random_state=random.seed(RANDOM_SEED),
                                          split_ratio=0.8)
print(f'Num Train: {len(train_data)}')
print(f'Num Valid: {len(valid_data)}')
print(f'Num Test: {len(test_data)}')
Num Train: 20000
Num Valid: 5000
Num Test: 25000
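Optionally (this check is not in the original notebook), a single training example can be inspected to see the tokenized text and its label:
print(vars(train_data.examples[0]))  # dict with 'text' (token list) and 'label'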
Build the vocabulary based on the top "VOCABULARY_SIZE" words:
TEXT.build_vocab(train_data, max_size=VOCABULARY_SIZE)
LABEL.build_vocab(train_data)
print(f'Vocabulary size: {len(TEXT.vocab)}')
print(f'Number of classes: {len(LABEL.vocab)}')
Vocabulary size: 20002
Number of classes: 2
The TEXT.vocab dictionary will contain the word counts and indices. The reason why the number of words is VOCABULARY_SIZE + 2 is that it contains two special tokens for unknown words and padding: <unk> and <pad>.
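As a quick optional check (not in the original notebook), the vocabulary objects can be inspected directly; torchtext exposes the frequency counts via .freqs, the index-to-string list via .itos, and the string-to-index mapping via .stoi:
print(TEXT.vocab.freqs.most_common(10))  # 10 most frequent training words
print(TEXT.vocab.itos[:5])               # starts with the specials ['<unk>', '<pad>', ...]
print(TEXT.vocab.stoi['<pad>'])          # index used for the padding token
print(LABEL.vocab.stoi)                  # label-to-index mapping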
Make dataset iterators:
train_loader, valid_loader, test_loader = data.BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    sort_within_batch=True,  # necessary for pack_padded_sequence
    device=DEVICE)
Testing the iterators (note that the number of rows depends on the longest document in the respective batch):
print('Train')
for batch in train_loader:
    print(f'Text matrix size: {batch.text[0].size()}')
    print(f'Target vector size: {batch.label.size()}')
    break

print('\nValid:')
for batch in valid_loader:
    print(f'Text matrix size: {batch.text[0].size()}')
    print(f'Target vector size: {batch.label.size()}')
    break

print('\nTest:')
for batch in test_loader:
    print(f'Text matrix size: {batch.text[0].size()}')
    print(f'Target vector size: {batch.label.size()}')
    break
Train
Text matrix size: torch.Size([132, 128])
Target vector size: torch.Size([128])

Valid:
Text matrix size: torch.Size([61, 128])
Target vector size: torch.Size([128])

Test:
Text matrix size: torch.Size([42, 128])
Target vector size: torch.Size([128])
import torch.nn as nn
class RNN(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, text, text_length):
        # [sentence len, batch size] => [sentence len, batch size, embedding size]
        embedded = self.embedding(text)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, text_length)

        # [sentence len, batch size, embedding size] =>
        #   output: [sentence len, batch size, hidden size]
        #   hidden: [1, batch size, hidden size]
        output, hidden = self.rnn(packed)

        return self.fc(hidden.squeeze(0)).view(-1)
INPUT_DIM = len(TEXT.vocab)
torch.manual_seed(RANDOM_SEED)
model = RNN(INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM)
model = model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
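As an optional sanity check (not part of the original notebook), one batch can be pushed through the untrained model to confirm that the logits come out as a 1-D vector with one entry per example:
with torch.no_grad():
    batch = next(iter(train_loader))
    text, text_lengths = batch.text
    print(model(text, text_lengths).size())  # expected: torch.Size([BATCH_SIZE])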
def compute_binary_accuracy(model, data_loader, device):
    model.eval()
    correct_pred, num_examples = 0, 0
    with torch.no_grad():
        for batch_idx, batch_data in enumerate(data_loader):
            text, text_lengths = batch_data.text
            logits = model(text, text_lengths)
            predicted_labels = (torch.sigmoid(logits) > 0.5).long()
            num_examples += batch_data.label.size(0)
            correct_pred += (predicted_labels == batch_data.label.long()).sum()
        return correct_pred.float()/num_examples * 100
start_time = time.time()
for epoch in range(NUM_EPOCHS):
    model.train()
    for batch_idx, batch_data in enumerate(train_loader):

        text, text_lengths = batch_data.text

        ### FORWARD AND BACK PROP
        logits = model(text, text_lengths)
        cost = F.binary_cross_entropy_with_logits(logits, batch_data.label)
        optimizer.zero_grad()
        cost.backward()

        ### UPDATE MODEL PARAMETERS
        optimizer.step()

        ### LOGGING
        if not batch_idx % 50:
            print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '
                  f'Batch {batch_idx:03d}/{len(train_loader):03d} | '
                  f'Cost: {cost:.4f}')

    with torch.set_grad_enabled(False):
        print(f'training accuracy: '
              f'{compute_binary_accuracy(model, train_loader, DEVICE):.2f}%'
              f'\nvalid accuracy: '
              f'{compute_binary_accuracy(model, valid_loader, DEVICE):.2f}%')

    print(f'Time elapsed: {(time.time() - start_time)/60:.2f} min')

print(f'Total Training Time: {(time.time() - start_time)/60:.2f} min')
print(f'Test accuracy: {compute_binary_accuracy(model, test_loader, DEVICE):.2f}%')
Epoch: 001/015 | Batch 000/157 | Cost: 0.6821
Epoch: 001/015 | Batch 050/157 | Cost: 0.6892
Epoch: 001/015 | Batch 100/157 | Cost: 0.6327
Epoch: 001/015 | Batch 150/157 | Cost: 0.6762
training accuracy: 57.08%
valid accuracy: 55.60%
Time elapsed: 0.13 min
Epoch: 002/015 | Batch 000/157 | Cost: 0.6738
Epoch: 002/015 | Batch 050/157 | Cost: 0.6578
Epoch: 002/015 | Batch 100/157 | Cost: 0.6949
Epoch: 002/015 | Batch 150/157 | Cost: 0.6315
training accuracy: 65.47%
valid accuracy: 64.64%
Time elapsed: 0.28 min
Epoch: 003/015 | Batch 000/157 | Cost: 0.6614
Epoch: 003/015 | Batch 050/157 | Cost: 0.5436
Epoch: 003/015 | Batch 100/157 | Cost: 0.6362
Epoch: 003/015 | Batch 150/157 | Cost: 0.5960
training accuracy: 68.87%
valid accuracy: 68.08%
Time elapsed: 0.41 min
Epoch: 004/015 | Batch 000/157 | Cost: 0.6148
Epoch: 004/015 | Batch 050/157 | Cost: 0.5484
Epoch: 004/015 | Batch 100/157 | Cost: 0.5179
Epoch: 004/015 | Batch 150/157 | Cost: 0.6458
training accuracy: 69.85%
valid accuracy: 66.82%
Time elapsed: 0.54 min
Epoch: 005/015 | Batch 000/157 | Cost: 0.5394
Epoch: 005/015 | Batch 050/157 | Cost: 0.6463
Epoch: 005/015 | Batch 100/157 | Cost: 0.5456
Epoch: 005/015 | Batch 150/157 | Cost: 0.5760
training accuracy: 70.96%
valid accuracy: 67.28%
Time elapsed: 0.67 min
Epoch: 006/015 | Batch 000/157 | Cost: 0.5609
Epoch: 006/015 | Batch 050/157 | Cost: 0.5449
Epoch: 006/015 | Batch 100/157 | Cost: 0.5924
Epoch: 006/015 | Batch 150/157 | Cost: 0.5842
training accuracy: 73.84%
valid accuracy: 70.90%
Time elapsed: 0.81 min
Epoch: 007/015 | Batch 000/157 | Cost: 0.5566
Epoch: 007/015 | Batch 050/157 | Cost: 0.5019
Epoch: 007/015 | Batch 100/157 | Cost: 0.4826
Epoch: 007/015 | Batch 150/157 | Cost: 0.5885
training accuracy: 68.89%
valid accuracy: 64.76%
Time elapsed: 0.94 min
Epoch: 008/015 | Batch 000/157 | Cost: 0.5797
Epoch: 008/015 | Batch 050/157 | Cost: 0.5433
Epoch: 008/015 | Batch 100/157 | Cost: 0.4908
Epoch: 008/015 | Batch 150/157 | Cost: 0.5703
training accuracy: 75.42%
valid accuracy: 71.44%
Time elapsed: 1.07 min
Epoch: 009/015 | Batch 000/157 | Cost: 0.5631
Epoch: 009/015 | Batch 050/157 | Cost: 0.4570
Epoch: 009/015 | Batch 100/157 | Cost: 0.6094
Epoch: 009/015 | Batch 150/157 | Cost: 0.6365
training accuracy: 72.83%
valid accuracy: 68.32%
Time elapsed: 1.20 min
Epoch: 010/015 | Batch 000/157 | Cost: 0.5310
Epoch: 010/015 | Batch 050/157 | Cost: 0.4470
Epoch: 010/015 | Batch 100/157 | Cost: 0.5479
Epoch: 010/015 | Batch 150/157 | Cost: 0.5513
training accuracy: 75.52%
valid accuracy: 70.84%
Time elapsed: 1.33 min
Epoch: 011/015 | Batch 000/157 | Cost: 0.4262
Epoch: 011/015 | Batch 050/157 | Cost: 0.6005
Epoch: 011/015 | Batch 100/157 | Cost: 0.5208
Epoch: 011/015 | Batch 150/157 | Cost: 0.5247
training accuracy: 75.98%
valid accuracy: 70.90%
Time elapsed: 1.46 min
Epoch: 012/015 | Batch 000/157 | Cost: 0.5223
Epoch: 012/015 | Batch 050/157 | Cost: 0.5503
Epoch: 012/015 | Batch 100/157 | Cost: 0.5315
Epoch: 012/015 | Batch 150/157 | Cost: 0.4270
training accuracy: 77.91%
valid accuracy: 72.88%
Time elapsed: 1.61 min
Epoch: 013/015 | Batch 000/157 | Cost: 0.5056
Epoch: 013/015 | Batch 050/157 | Cost: 0.5154
Epoch: 013/015 | Batch 100/157 | Cost: 0.4632
Epoch: 013/015 | Batch 150/157 | Cost: 0.4700
training accuracy: 78.33%
valid accuracy: 73.00%
Time elapsed: 1.74 min
Epoch: 014/015 | Batch 000/157 | Cost: 0.4585
Epoch: 014/015 | Batch 050/157 | Cost: 0.5244
Epoch: 014/015 | Batch 100/157 | Cost: 0.4338
Epoch: 014/015 | Batch 150/157 | Cost: 0.4698
training accuracy: 77.28%
valid accuracy: 72.38%
Time elapsed: 1.88 min
Epoch: 015/015 | Batch 000/157 | Cost: 0.5293
Epoch: 015/015 | Batch 050/157 | Cost: 0.4619
Epoch: 015/015 | Batch 100/157 | Cost: 0.4165
Epoch: 015/015 | Batch 150/157 | Cost: 0.4715
training accuracy: 79.31%
valid accuracy: 73.72%
Time elapsed: 2.01 min
Total Training Time: 2.01 min
Test accuracy: 73.94%
import spacy
nlp = spacy.load('en')
def predict_sentiment(model, sentence):
    # based on:
    # https://github.com/bentrevett/pytorch-sentiment-analysis/blob/
    # master/2%20-%20Upgraded%20Sentiment%20Analysis.ipynb
    model.eval()
    tokenized = [tok.text for tok in nlp.tokenizer(sentence)]
    indexed = [TEXT.vocab.stoi[t] for t in tokenized]
    length = [len(indexed)]
    tensor = torch.LongTensor(indexed).to(DEVICE)
    tensor = tensor.unsqueeze(1)
    length_tensor = torch.LongTensor(length)
    prediction = torch.sigmoid(model(tensor, length_tensor))
    return prediction.item()
print('Probability positive:')
predict_sentiment(model, "I really love this movie. This movie is so great!")
Probability positive:
0.7535440325737
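Analogously (this example is not in the original notebook, so no output value is shown), a clearly negative review should yield a probability well below 0.5:
print('Probability positive:')
predict_sentiment(model, "This movie was boring and a complete waste of time.")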