Sequence-to-Sequence Model with Attention

Seq2seq models are used for applications such as machine translation and image caption generation.

We will build a seq2seq model with attention in PyTorch for translating English to French.

In [0]:
import itertools
from collections import Counter
from functools import partial
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from google_drive_downloader import GoogleDriveDownloader as gdd
from nltk import wordpunct_tokenize
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from tqdm import tqdm_notebook, tqdm

In order to perform deep learning on a GPU (so that everything runs super quick!), CUDA has to be installed and configured. Fortunately, Google Colab already has this set up, but if you want to try this on your own GPU, you can install CUDA from here. Make sure you also install cuDNN for optimized performance.

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
Out[2]:
device(type='cuda')

Download the data

We will download a dataset of English-to-French translations from a public Google Drive folder.

In [0]:
def tokenize(text):
    """Turn text into discrete tokens.

    Remove tokens that are not words.
    """
    text = text.lower()
    tokens = wordpunct_tokenize(text)

    # Only keep words
    tokens = [token for token in tokens
              if all(char.isalpha() for char in token)]

    return tokens


class EnglishFrenchTranslations(Dataset):
    def __init__(self, path, max_vocab):
        self.max_vocab = max_vocab
        
        # Extra tokens to add
        self.padding_token = '<PAD>'
        self.start_of_sequence_token = '<SOS>'
        self.end_of_sequence_token = '<EOS>'
        self.unknown_word_token = '<UNK>'
        
        # Helper function
        self.flatten = lambda x: [sublst for lst in x for sublst in lst]
        
        # Load the data into a DataFrame
        df = pd.read_csv(path, names=['english', 'french'], sep='\t')
        
        # Tokenize inputs (English) and targets (French)
        self.tokenize_df(df)

        # To reduce computational complexity, replace rare words with <UNK>
        self.replace_rare_tokens(df)
        
        # Prepare variables with mappings of tokens to indices
        self.create_token2idx(df)
        
        # Remove sequences with mostly <UNK>
        df = self.remove_mostly_unk(df)
        
        # Every sequence (input and target) should start with <SOS>
        # and end with <EOS>
        self.add_start_and_end_to_tokens(df)
        
        # Convert tokens to indices
        self.tokens_to_indices(df)
        
    def __getitem__(self, idx):
        """Return example at index idx."""
        return self.indices_pairs[idx][0], self.indices_pairs[idx][1]
    
    def tokenize_df(self, df):
        """Turn inputs and targets into tokens."""
        df['tokens_inputs'] = df.english.apply(tokenize)
        df['tokens_targets'] = df.french.apply(tokenize)
        
    def replace_rare_tokens(self, df):
        """Replace rare tokens with <UNK>."""
        common_tokens_inputs = self.get_most_common_tokens(
            df.tokens_inputs.tolist(),
        )
        common_tokens_targets = self.get_most_common_tokens(
            df.tokens_targets.tolist(),
        )
        
        df.loc[:, 'tokens_inputs'] = df.tokens_inputs.apply(
            lambda tokens: [token if token in common_tokens_inputs 
                            else self.unknown_word_token for token in tokens]
        )
        df.loc[:, 'tokens_targets'] = df.tokens_targets.apply(
            lambda tokens: [token if token in common_tokens_targets
                            else self.unknown_word_token for token in tokens]
        )

    def get_most_common_tokens(self, tokens_series):
        """Return the max_vocab most common tokens."""
        all_tokens = self.flatten(tokens_series)
        # Substract 4 for <PAD>, <SOS>, <EOS>, and <UNK>
        common_tokens = set(list(zip(*Counter(all_tokens).most_common(
            self.max_vocab - 4)))[0])
        return common_tokens

    def remove_mostly_unk(self, df, threshold=0.99):
        """Remove sequences with mostly <UNK>."""
        calculate_ratio = (
            lambda tokens: sum(1 for token in tokens if token != '<UNK>')
            / len(tokens) > threshold
        )
        df = df[df.tokens_inputs.apply(calculate_ratio)]
        df = df[df.tokens_targets.apply(calculate_ratio)]
        return df
        
    def create_token2idx(self, df):
        """Create variables with mappings from tokens to indices."""
        unique_tokens_inputs = set(self.flatten(df.tokens_inputs))
        unique_tokens_targets = set(self.flatten(df.tokens_targets))
        
        for token in reversed([
            self.padding_token,
            self.start_of_sequence_token,
            self.end_of_sequence_token,
            self.unknown_word_token,
        ]):
            if token in unique_tokens_inputs:
                unique_tokens_inputs.remove(token)
            if token in unique_tokens_targets:
                unique_tokens_targets.remove(token)
                
        unique_tokens_inputs = sorted(list(unique_tokens_inputs))
        unique_tokens_targets = sorted(list(unique_tokens_targets))

        # Add <PAD>, <SOS>, <EOS>, and <UNK> tokens
        for token in reversed([
            self.padding_token,
            self.start_of_sequence_token,
            self.end_of_sequence_token,
            self.unknown_word_token,
        ]):
            
            unique_tokens_inputs = [token] + unique_tokens_inputs
            unique_tokens_targets = [token] + unique_tokens_targets
            
        self.token2idx_inputs = {token: idx for idx, token
                                 in enumerate(unique_tokens_inputs)}
        self.idx2token_inputs = {idx: token for token, idx
                                 in self.token2idx_inputs.items()}
        
        self.token2idx_targets = {token: idx for idx, token
                                  in enumerate(unique_tokens_targets)}
        self.idx2token_targets = {idx: token for token, idx
                                  in self.token2idx_targets.items()}
        
    def add_start_and_end_to_tokens(self, df):
        """Add <SOS> and <EOS> tokens to the end of every input and output."""
        df.loc[:, 'tokens_inputs'] = (
            [self.start_of_sequence_token]
            + df.tokens_inputs
            + [self.end_of_sequence_token]
        )
        df.loc[:, 'tokens_targets'] = (
            [self.start_of_sequence_token]
            + df.tokens_targets
            + [self.end_of_sequence_token]
        )
        
    def tokens_to_indices(self, df):
        """Convert tokens to indices."""
        df['indices_inputs'] = df.tokens_inputs.apply(
            lambda tokens: [self.token2idx_inputs[token] for token in tokens])
        df['indices_targets'] = df.tokens_targets.apply(
            lambda tokens: [self.token2idx_targets[token] for token in tokens])
             
        self.indices_pairs = list(zip(df.indices_inputs, df.indices_targets))
        
    def __len__(self):
        return len(self.indices_pairs)
In [4]:
DATA_PATH = 'data/english_to_french.txt'
if not Path(DATA_PATH).is_file():
    gdd.download_file_from_google_drive(
        file_id='1Jf7QoW2NK6_ayEXZji6DAXDSIRMvapm3',
        dest_path=DATA_PATH,
    )
Downloading 1Jf7QoW2NK6_ayEXZji6DAXDSIRMvapm3 into data/english_to_french.txt... Done.
In [5]:
dataset = EnglishFrenchTranslations(DATA_PATH, max_vocab=1000)
len(dataset)
Out[5]:
40288
In [0]:
train_size = int(0.999 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

Create data generators using DataLoader

In [0]:
def collate(batch):
    inputs = [torch.LongTensor(item[0]) for item in batch]
    targets = [torch.LongTensor(item[1]) for item in batch]
    
    # Pad sequencse so that they are all the same length (within one minibatch)
    padded_inputs = pad_sequence(inputs, padding_value=dataset.token2idx_targets[dataset.padding_token], batch_first=True)
    padded_targets = pad_sequence(targets, padding_value=dataset.token2idx_targets[dataset.padding_token], batch_first=True)
    
    # Sort by length for CUDA optimizations
    lengths = torch.LongTensor([len(x) for x in inputs])
    lengths, permutation = lengths.sort(dim=0, descending=True)

    return padded_inputs[permutation].to(device), padded_targets[permutation].to(device), lengths.to(device)


batch_size = 512
train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate)
test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=collate)

Seq2Seq with Attention

source

Define the Encoder

In [0]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, batch_size):
        super(Encoder, self).__init__()
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim)
        self.gru = nn.GRU(
            self.embedding_dim,
            self.hidden_size,
            batch_first=True,
        )
        
    def forward(self, inputs, lengths):
        self.batch_size = inputs.size(0)
        
        # Turn input indices into distributed embeddings
        x = self.embedding(inputs)

        # Remove padding for more efficient RNN application
        x = pack_padded_sequence(x, lengths, batch_first=True)
    
        # Apply RNN to get hidden state at all timesteps (output)
        # and hidden state of last output (self.hidden)
        output, self.hidden = self.gru(x, self.init_hidden())
        
        # Pad the sequences like they were before
        output, _ = pad_packed_sequence(output)
        
        return output, self.hidden

    def init_hidden(self):
        # Randomly initialize the weights of the RNN
        return torch.randn(1, self.batch_size, self.hidden_size).to(device)

Define the Decoder

In [0]:
class Decoder(nn.Module):
    def __init__(
        self, 
        vocab_size,
        embedding_dim, 
        decoder_hidden_size,
        encoder_hidden_size, 
        batch_size,
    ):
        super(Decoder, self).__init__()
        self.batch_size = batch_size
        self.encoder_hidden_size = encoder_hidden_size
        self.decoder_hidden_size = decoder_hidden_size
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim)
        self.gru = nn.GRU(
            self.embedding_dim + self.encoder_hidden_size, 
            self.decoder_hidden_size,
            batch_first=True,
        )
        self.fc = nn.Linear(self.encoder_hidden_size, self.vocab_size)
        
        # Attention weights
        self.W1 = nn.Linear(self.encoder_hidden_size, self.decoder_hidden_size)
        self.W2 = nn.Linear(self.encoder_hidden_size, self.decoder_hidden_size)
        self.V = nn.Linear(self.encoder_hidden_size, 1)
    
    def forward(self, targets, hidden, encoder_output):
        self.batch_size = inputs.size(0)
        
        # Switch the dimensions of sequence_length and batch_size
        encoder_output = encoder_output.permute(1, 0, 2)

        # Add an extra axis for a time dimension
        hidden_with_time_axis = hidden.permute(1, 0, 2)
        
        # Attention score (Bahdanaus)
        score = torch.tanh(self.W1(encoder_output) + self.W2(hidden_with_time_axis))

        # Attention weights
        attention_weights = torch.softmax(self.V(score), dim=1)
        
        # Find the context vectors
        context_vector = attention_weights * encoder_output
        context_vector = torch.sum(context_vector, dim=1)
        
        # Turn target indices into distributed embeddings
        x = self.embedding(targets)
        
        # Add the context representation to the target embeddings
        x = torch.cat((context_vector.unsqueeze(1), x), -1)
        
        # Apply the RNN
        output, state = self.gru(x, self.init_hidden())
        
        # Reshape the hidden states (output)
        output = output.view(-1, output.size(2))
        
        # Apply a linear layer
        x = self.fc(output)
        
        return x, state, attention_weights
    
    def init_hidden(self):
        # Randomly initialize the weights of the RNN
        return torch.randn(1, self.batch_size, self.decoder_hidden_size).to(device)

Define a model that has both an Encoder and Decoder

In [0]:
criterion = nn.CrossEntropyLoss()

def loss_function(real, pred):
    """Calculate how wrong the model is."""
    # Use mask to only consider non-zero inputs in the loss
    mask = real.ge(1).float().to(device)
    
    loss_ = criterion(pred, real) * mask 
    return torch.mean(loss_)


class EncoderDecoder(nn.Module):
    def __init__(self, inputs_vocab_size, targets_vocab_size, hidden_size,
                 embedding_dim, batch_size, targets_start_idx, targets_stop_idx):
        super(EncoderDecoder, self).__init__()
        self.batch_size = batch_size
        self.targets_start_idx = targets_start_idx
        self.targets_stop_idx = targets_stop_idx
        
        self.encoder = Encoder(inputs_vocab_size, embedding_dim,
                               hidden_size, batch_size).to(device)
        
        self.decoder = Decoder(targets_vocab_size, embedding_dim,
                               hidden_size, hidden_size, batch_size).to(device)
        
    def predict(self, inputs, lengths):
        self.batch_size = inputs.size(0)
        
        encoder_output, encoder_hidden = self.encoder(
            inputs.to(device),
            lengths,
        )
        decoder_hidden = encoder_hidden
    
        # Initialize the input of the decoder to be <SOS>
        decoder_input = torch.LongTensor(
            [[self.targets_start_idx]] * self.batch_size,
        )
        
        # Output predictions instead of loss
        output = []
        for _ in range(20):
            predictions, decoder_hidden, _ = self.decoder(
                decoder_input.to(device), 
                decoder_hidden.to(device),
                encoder_output.to(device),
            )
            prediction = torch.multinomial(F.softmax(predictions, dim=1), 1)
            decoder_input = prediction

            prediction = prediction.item()
            output.append(prediction)

            if prediction == self.targets_stop_idx:
                return output

        return output

    def forward(self, inputs, targets, lengths):
        self.batch_size = inputs.size(0)
        
        encoder_output, encoder_hidden = self.encoder(
            inputs.to(device),
            lengths,
        )
        decoder_hidden = encoder_hidden
        
        # Initialize the input of the decoder to be <SOS>
        decoder_input = torch.LongTensor(
            [[self.targets_start_idx]] * self.batch_size,
        )
                
        # Use teacher forcing to train the model. Instead of feeding the model's
        # own predictions to itself, feed the target token at every timestep.
        # This leads to faster convergence
        loss = 0
        for timestep in range(1, targets.size(1)):
            predictions, decoder_hidden, _ = self.decoder(
                decoder_input.to(device), 
                decoder_hidden.to(device),
                encoder_output.to(device),
            )
            decoder_input = targets[:, timestep].unsqueeze(1)
            
            loss += loss_function(targets[:, timestep], predictions)
            
        return loss / targets.size(1)
In [13]:
model = EncoderDecoder(
    inputs_vocab_size=len(dataset.token2idx_inputs),
    targets_vocab_size=len(dataset.token2idx_targets),
    hidden_size=256,
    embedding_dim=100, 
    batch_size=batch_size, 
    targets_start_idx=dataset.token2idx_targets[dataset.start_of_sequence_token],
    targets_stop_idx=dataset.token2idx_targets[dataset.end_of_sequence_token],
).to(device)

optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=0.001)

# Training loop
model.train()
for epoch in range(10):
    total_loss = total = 0
    progress_bar = tqdm_notebook(train_loader, desc='Training', leave=False)
    for inputs, targets, lengths in progress_bar:
        # Clean old gradients
        optimizer.zero_grad()

        # Forwards pass
        loss = model(inputs, targets, lengths)

        # Perform gradient descent, backwards pass
        loss.backward()

        # Take a step in the right direction
        optimizer.step()

        # Record metrics
        total_loss += loss.item()
        total += targets.size(1)

    train_loss = total_loss / total
    
    tqdm.write(f'epoch #{epoch + 1:3d}\ttrain_loss: {train_loss:.2e}\n')
epoch #  1	train_loss: 8.02e-02

epoch #  2	train_loss: 5.60e-02

epoch #  3	train_loss: 4.39e-02

epoch #  4	train_loss: 3.56e-02

epoch #  5	train_loss: 2.92e-02

epoch #  6	train_loss: 2.45e-02

epoch #  7	train_loss: 2.16e-02

epoch #  8	train_loss: 1.93e-02

epoch #  9	train_loss: 1.76e-02

epoch # 10	train_loss: 1.63e-02

In [14]:
model.eval()
total_loss = total = 0
with torch.no_grad():
    for inputs, _, lengths in test_loader:
        print('>', ' '.join([
            dataset.idx2token_inputs[idx]
            for idx in inputs.cpu()[0].numpy()[1:-1]
        ]))

        # Forwards pass
        outputs = model.predict(inputs, lengths)
        print(' '.join([
            dataset.idx2token_targets[idx]
            for idx in outputs[:-1]
        ]))
        
        print()
> i was really proud of that
j étais vraiment eu à cela

> why don t you trust me
pourquoi ne me fais pas quoi confiance

> we don t know the answer yet
nous ne sais pas les choses il y

> you like it don t you
vous l as pas

> tom s room was very clean
tom est cette pièce était très chaud

> what a night
c est beaucoup

> i know how much you love tom
je sais combien tu aimes tom

> this book s new
quel livre est mort

> he did not know where to go
il ne connaît pas en aller

> can you tell me anything about what s going to happen here today
te te me dire à ce soit se passe

> it s really very good
c est très bon

> shut off the water
la guerre

> she needs you
elle avait besoin de toi

> the mistakes are mine
le faisons tous deux fois

> i remember that speech
je me souviens ça peut amis

> i wonder if he s really sick
je me demande s il vraiment malade

> i m sure things will work out
je suis sûre de choses peuvent faire

> how far is it from here
ça fonctionne t il ici

> i knew this would be hard for you
je savais ça serait difficile à toi

> we had a little problem
nous l boulot un petit problème

> my sister s going to kill me
ma sœur se passer travailler tu

> i want to talk to you about last night
je veux tu parler de nuit dernière

> you should not sleep
vous ne peux pas dormir

> what have you done
que tu as fait

> i need to know
je dois connais

> remember that we are all in the same boat
rappelle de nous sommes tous chat n en chose

> don t get any ideas
ne t en nourriture à dix idées

> nobody does that
personne n est ce là

> what are you doing today
que tu es aujourd hui

> can i share something important with you
puis je mettre de dire quelque chose d y penser à toi d accord

> you ve got to wait
il faut qu attendre

> she s still young
elle a toujours jeune

> i m sure tom doesn t think that
je suis sûr de ne pense pas

> i m going to study french next year
je vais mettre le prochain dernier prochain

> i want that more than anything
je veux cela

> this is kind of expensive
c était l a mangé

> now leave us
voulez nous

> i wish i could remember his name
c serais le pouvais mieux à son nom

> i know some of these girls
je sais de ces livres

> she promised not to go out alone
elle a homme de se seule pas seul

> what can i tell you
que puis je te dire

In [0]: