In [2]:
%matplotlib inline
import os
import re
import json
import glob
import time
import random
import datetime
from collections import Counter

import torch
from torch.utils.data import DataLoader, Dataset
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm
from sklearn.model_selection import train_test_split
from runtimestamp.runtimestamp import runtimestamp

runtimestamp()
Updated 2018-05-10 10:49:34.862764
By yvan
Using Python 3.6.5
On Linux-4.13.0-36-generic-x86_64-with-debian-stretch-sid
In [3]:
torch.__version__
Out[3]:
'0.4.0'

Data Parsing and Preprocessing

In [40]:
# Text-related global variables
max_seq_len = 30
min_word_freq = 20

# GPU variables
use_gpu = torch.cuda.is_available()
device_num = 1
device = torch.device(f"cuda:{device_num}" if use_gpu else "cpu")

# File-writing variables
today = datetime.datetime.now().strftime('%Y-%m-%d')
train_file = '/mnt/hdd2/leon_data/books/3body/en/all_three_train.csv'
valid_file = '/mnt/hdd2/leon_data/books/3body/en/all_three_valid.csv'
test_file = '/mnt/hdd2/leon_data/books/3body/en/all_three_test.csv'
model_dir = '/mnt/hdd2/leon_data/books/models/{}/'.format(today)
file_model = os.path.join(model_dir, '3body_LM__{}.json')
file_wv = os.path.join(model_dir, '3body_LM__wv__{}.txt')
os.makedirs(model_dir, exist_ok=True)
In [41]:
class IndexVectorizer:
    """
    Transforms a Corpus into lists of word indices.
    """
    def __init__(self, max_words=None, min_frequency=None, start_end_tokens=False, maxlen=None):
        self.vocabulary = None
        self.vocabulary_size = 0
        self.word2idx = dict()
        self.idx2word = dict()
        self.max_words = max_words
        self.min_frequency = min_frequency
        self.start_end_tokens = start_end_tokens
        self.maxlen = maxlen

    def _find_max_document_length(self, corpus):
        self.maxlen = max(len(document) for document in corpus)
        if self.start_end_tokens:
            self.maxlen += 2

    def _build_vocabulary(self, corpus):
        vocabulary = Counter(word for document in corpus for word in document)
        if self.max_words:
            vocabulary = {word: freq for word,
                          freq in vocabulary.most_common(self.max_words)}
        if self.min_frequency:
            vocabulary = {word: freq for word, freq in vocabulary.items()
                          if freq >= self.min_frequency}
        self.vocabulary = vocabulary
        self.vocabulary_size = len(vocabulary) + 2  # padding and unk tokens
        if self.start_end_tokens:
            self.vocabulary_size += 2

    def _build_word_index(self):
        self.word2idx['<PAD>'] = 0
        self.word2idx['<UNK>'] = 1

        if self.start_end_tokens:
            self.word2idx['<START>'] = 2
            self.word2idx['<END>'] = 3

        offset = len(self.word2idx)
        for idx, word in enumerate(self.vocabulary):
            self.word2idx[word] = idx + offset
        self.idx2word = {idx: word for word, idx in self.word2idx.items()}

    def fit(self, corpus):
        if not self.maxlen:
            self._find_max_document_length(corpus)
        self._build_vocabulary(corpus)
        self._build_word_index()

    def pad_document_vector(self, vector):
        padding = self.maxlen - len(vector)
        vector.extend([self.word2idx['<PAD>']] * padding)
        return vector

    def add_start_end(self, vector):
        vector.append(self.word2idx['<END>'])
        return [self.word2idx['<START>']] + vector

    def transform_document(self, document, offset=0):
        """
        Vectorize a single document
        """
        vector = [self.word2idx.get(word, self.word2idx['<UNK>']) 
                  for word in document]
        if len(vector) > self.maxlen:
            vector = vector[:self.maxlen]
        if self.start_end_tokens:
            vector = self.add_start_end(vector)
        vector = vector[offset:self.maxlen]
        
        return self.pad_document_vector(vector)

    def transform(self, corpus):
        """
        Vectorizes a corpus in the form of a list of lists.
        A corpus is a list of documents and a document is a list of words.
        """
        return [self.transform_document(document) for document in corpus]
    
    
class ThreeBodyDataset(Dataset):
    def __init__(self, path, vectorizer, tokenizer=None, stopwords=None):
        self.corpus = pd.read_csv(path)
        self.tokenizer = tokenizer
        self.vectorizer = vectorizer
        self.stopwords = stopwords
        self._tokenize_corpus()
        if self.stopwords: self._remove_stopwords() 
        self._vectorize_corpus()

    def _remove_stopwords(self):
        stopfilter = lambda doc: [word for word in doc if word not in self.stopwords]
        self.corpus['tokens'] = self.corpus['tokens'].apply(stopfilter)

    def _tokenize_corpus(self):
        if self.tokenizer:
            self.corpus['tokens'] = self.corpus['sentences'].apply(self.tokenizer)
        else:
            self.corpus['tokens'] = self.corpus['sentences'].apply(lambda x: x.lower().split())

    def _vectorize_corpus(self):
        if not self.vectorizer.vocabulary:
            self.vectorizer.fit(self.corpus['tokens'])
        self.corpus['vectors'] = self.corpus['tokens'].apply(self.vectorizer.transform_document)
        self.corpus['target'] = self.corpus['tokens'].apply(self.vectorizer.transform_document, offset=1)

    def __getitem__(self, index):
        sentence = self.corpus['vectors'].iloc[index]
        target = self.corpus['target'].iloc[index]
        return torch.LongTensor(sentence), torch.LongTensor(target)

    def __len__(self):
        return len(self.corpus)
    
def simple_tokenizer(text):
    return text.lower().split()
In [42]:
vectorizer = IndexVectorizer(max_words=None, min_frequency=min_word_freq, 
                             start_end_tokens=True, maxlen=max_seq_len)

training_set = ThreeBodyDataset(train_file, vectorizer, simple_tokenizer)
test_set = ThreeBodyDataset(test_file, vectorizer, simple_tokenizer)
validation_set = ThreeBodyDataset(valid_file, vectorizer, simple_tokenizer)
In [43]:
len(training_set), len(validation_set), len(test_set)
Out[43]:
(10387, 1732, 1731)
In [44]:
print("Vocab size: {}".format(vectorizer.vocabulary_size))
Vocab size: 1194
In [45]:
training_set.corpus.iloc[0]
Out[45]:
sentences    Purely in terms of command , humanity might ne...
tokens       [purely, in, terms, of, command, ,, humanity, ...
vectors      [2, 1, 4, 1, 5, 6, 7, 8, 9, 10, 11, 12, 1, 13,...
target       [1, 4, 1, 5, 6, 7, 8, 9, 10, 11, 12, 1, 13, 1,...
Name: 0, dtype: object

We'll divide our dataset into a training set and a test set. We'll also create some functions for loading the data into batches.

In [46]:
def log(msg):
    print(msg)
    
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

class History(object):
    """Records Loss and Validation Loss"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.loss = dict()
        self.val_loss = dict()
        self.min_loss = 100
    
    def update_min_loss(self, min_loss):
        self.min_loss = min_loss
        
    def update_loss(self, loss):
        epoch = len(self.loss.keys())
        self.loss[epoch] = loss
    
    def update_val_loss(self, val_loss):
        epoch = len(self.val_loss.keys())
        self.val_loss[epoch] = val_loss
        
    def plot(self):
        loss = sorted(self.loss.items())
        x, y = zip(*loss)
        
        if self.val_loss:
            val_loss = sorted(self.val_loss.items())
            x1, y1 = zip(*val_loss)
            plt.plot(x, y, 'C0', label='Loss')
            plt.plot(x1, y1, 'C2', label='Validation Loss')
            plt.legend();
        else:
            plt.plot(x, y, 'C0');
    
def categorical_accuracy(y_true, y_pred):
    y_true = y_true.float()
    _, y_pred = torch.max(y_pred.squeeze(), dim=-1)
    return (y_pred.float() == y_true).float().mean()

def softmax_trick(x):
    logits_exp = torch.exp(x - torch.max(x))
    weights = torch.div(logits_exp, logits_exp.sum())
    return weights

def save_state_dict(model, filepath):
    '''Saves the model weights as a dictionary'''
    model_dict = model.state_dict()
    torch.save(model_dict, filepath)
    return model_dict

LSTM Language Model

Here we'll define a recurrent language model.

In [47]:
class RNNLM(nn.Module):
    def __init__(self, vocab_size, seq_len, embedding_size, 
                 hidden_size, batch_size, 
                 dropout=.5, num_layers=1, tie_weights=False, 
                 bidirectional=False, word2idx={}, log_softmax=False):
       
        super(RNNLM, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.tie_weights = tie_weights
        self.num_layers = num_layers
        self.num_directions = 1 if not bidirectional else 2
        self.word2idx = word2idx
        self.idx2word = {v:k for k,v in word2idx.items()}
        
        # Model Pieces
        self.dropout = nn.Dropout(p = dropout)
        self.log_softmax = nn.LogSoftmax(dim = 1) if log_softmax else None
        
        # Model Layers
        self.encoder = nn.Embedding(vocab_size, embedding_size, 
                                    padding_idx = word2idx.get('<PAD>', 1))
        
        self.lstm1 = nn.LSTM(embedding_size, hidden_size, 
                             num_layers = 1, 
                             bidirectional = bidirectional,
                             batch_first = True)
        
        self.lstm2 = nn.LSTM(hidden_size * self.num_directions, hidden_size, 
                             num_layers = 1, 
                             bidirectional = bidirectional,
                             batch_first = True)
        
        self.decoder = nn.Linear(hidden_size * self.num_directions, vocab_size)

        # tie enc/dec weights
        if self.tie_weights:
            if hidden_size != embedding_size:
                raise ValueError('When using the `tied` flag, hidden_size'
                                 'must be equal to embedding_dim')
            self.decoder.weight = self.encoder.weight
            
        self.init_weights()

        
    def init_hidden(self, bsz=None):
        '''
        For the nn.LSTM.
        Defaults to the batchsize stored in the class params, but can take in an argument
        in the case of sampling.
        '''
        if bsz == None: 
            bsz = self.batch_size
        h0 = torch.zeros(self.num_layers * self.num_directions, 
                         bsz, self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers * self.num_directions, 
                         bsz, self.hidden_size ).to(device)
        return (h0, c0)
    
    
    def init_weights(self):
        initrange = 0.1
        em_layer = [self.encoder]
        lin_layers = [self.decoder]
        for layer in lin_layers + em_layer:
            layer.weight.data.uniform_(-initrange, initrange)
            if layer in lin_layers:
                layer.bias.data.fill_(0)
    
    
    def sample(self, x_start):
        '''
        Generates a sequence of text given a starting word ix and hidden state.
        '''
        with torch.no_grad():
            indices = [x_start]
            for i in range(self.seq_len):
                # create inputs
                x_input = torch.LongTensor(indices).to(device)
                x_embs = self.encoder(x_input.view(1, -1))

                # send input through the rnn
                output, hidden = self.lstm1(x_embs)
                output, hidden = self.lstm2(output, hidden)

                # format the last word of the rnn so we can softmax it.
                one_dim_last_word = output.squeeze()[-1] if i > 0 else output.squeeze()
                fwd = one_dim_last_word[ : self.hidden_size ]
                bck = one_dim_last_word[ self.hidden_size : ]

                # pick a word from the disto
                word_weights = softmax_trick(fwd)
                word_idx = torch.multinomial(word_weights, num_samples=1).squeeze().item()
                indices.append(word_idx)

        return indices
    
    
    def forward(self, x, hidden, log_softmax=False):
        '''
        Iterates through the input, encodes it.
        Each embedding is sent through the step function.
        Dropout the last hidden layer and decode to a logit
        x.size() #(bsz, seq_len)
        
        logit.size # (bsz, seq_len, vocab_size)
        equivalent to (output.size(0), output.size(1), logit.size(1)
        '''
        x_emb = self.encoder(x)
        
        output, hidden = self.lstm1(x_emb, hidden)
        output, hidden = self.lstm2(self.dropout(output), hidden)
        
        logit = self.decoder(self.dropout(output))
        if self.log_softmax:
            logit = self.log_softmax(logit)
        logit = logit.view(logit.size(0) * self.seq_len, self.vocab_size)
        
        return logit, hidden
In [48]:
def run_epoch(model, dataset, criterion, optim, batch_size, 
              train=False, shuffle=True):
    '''A wrapper for a training, validation or test run.'''
    model.train() if train else model.eval()
    loss = AverageMeter()
    accuracy = AverageMeter()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    for X, y in loader:
        model.zero_grad() 
        X = X.squeeze().to(device)
        y = y.squeeze().view(-1).to(device)

        # get a predition    
        hidden = model.init_hidden(X.size(0))
        y_, hidden = model(X, hidden)
        
        # calculate loss and accuracy
        lossy = criterion(y_.squeeze(), y.squeeze())
        accy = categorical_accuracy(y_.squeeze().data, y.squeeze().data)
        
        loss.update(lossy.data.item())
        accuracy.update(accy)
        
        # backprop
        if train:
            lossy.backward()
            optim.step()
    
    return loss.avg, accuracy.avg


def training_epoch(*args, **kwargs):
    '''Training Epoch'''
    return run_epoch(train=True, *args, **kwargs)
    
    
def validation_epoch(*args, **kwargs):
    '''Validation Epoch'''
    return run_epoch(*args, **kwargs)
 
    
def sample_lm(model):
    '''Samples a language model and returns generated words'''
    start_idx = model.word2idx['<START>']
    indices = model.sample(start_idx)
    words = [model.idx2word[index] for index in indices]
    
    return words


def training_loop(batch_size, num_epochs, display_freq, model, criterion, 
                  optim, training_set, validation_set=None, 
                  best_model_path='model', history=None):
    '''Training iteration.'''
    if not history:
        history = History()
    
    try: 
        for epoch in tqdm(range(num_epochs)):
            loss, accuracy = training_epoch(model, training_set, criterion, optim, batch_size)
            history.update_loss(loss)
            
            if validation_set:
                val_loss, val_accuracy = validation_epoch(model, validation_set, 
                                                          criterion, optim, batch_size)  
                history.update_val_loss(val_loss)
                if val_loss < history.min_loss:
                    save_state_dict(model, best_model_path)
                    history.update_min_loss(val_loss)
            else:
                if loss < history.min_loss:
                    save_state_dict(model, best_model_path)
                    history.update_min_loss(loss)
                
            if epoch % display_freq == 0:
                # display stats
                if validation_set:
                    log("Epoch: {:04d}; Loss: {:.4f}; Val-Loss {:.4f}; "
                        "Perplexity {:.4f}; Val-Perplexity {:.4f}".format(
                            epoch, loss, val_loss, np.exp(loss), np.exp(val_loss)))
                else:
                    log("Epoch: {:04d}; Loss: {:.4f}; Perplexity {:.4f};".format(
                            epoch, loss, np.exp(loss)))
                
                # sample from the language model
                words = sample_lm(model)
                log("Sample: {}".format(' '.join(words)))
                time.sleep(1)
        
        log('-' * 89)
        log("Training complete")
        log("Lowest loss: {:.4f}".format(history.min_loss))
        
        return  history
        
    except KeyboardInterrupt:
        log('-' * 89)
        log('Exiting from training early')
        log("Lowest loss: {:.4f}".format(history.min_loss))

        return history

Training the Model

We'll declare hyperparameters here, instantiate our model, create a training set data batcher, and train our model.

In [52]:
# Set Seed
if use_gpu: torch.cuda.manual_seed(303)
else: torch.manual_seed(303)

# set up Files to save stuff in
runtime = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
file_model = file_model.format(runtime)
file_wv = file_wv.format(runtime)
    
# Model Hyper Parameters 
hidden_dim = 100
embedding_dim = 200
batch_size = 512
dropout = 0.2
lstm_layers = 1 # this is useless atm
lstm_bidirection = True

# Training
learning_rate = 1e-4
num_epochs = 300
display_epoch_freq = 10

# Build and initialize the model
lm = RNNLM(vectorizer.vocabulary_size, max_seq_len, embedding_dim, hidden_dim, batch_size, 
           dropout = dropout, 
           tie_weights = False, 
           num_layers = lstm_layers, 
           bidirectional = lstm_bidirection, 
           word2idx = vectorizer.word2idx,
           log_softmax = True)

if use_gpu:
    lm = lm.to(device)
lm.init_weights()

# Loss and Optimizer
loss = nn.NLLLoss()
optimizer = torch.optim.Adam(lm.parameters(), lr=learning_rate)

# Train the model
history = training_loop(batch_size, num_epochs, display_epoch_freq, 
                        lm, loss, optimizer, training_set, validation_set, 
                        best_model_path=file_model)
Epoch: 0000; Loss: 3.3905; Val-Loss 3.3774; Perplexity 29.6801; Val-Perplexity 29.2938
Sample: <START> day about away they quickly morning lit away these image use use day in three quickly - . command <START> by see but transmission this her this <PAD> something transmission

Exception in thread Thread-7:
Traceback (most recent call last):
  File "/home/yvan/anaconda3/envs/leon_pytorch/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/home/yvan/anaconda3/envs/leon_pytorch/lib/python3.6/site-packages/tqdm/_monitor.py", line 62, in run
    for instance in self.tqdm_cls._instances:
  File "/home/yvan/anaconda3/envs/leon_pytorch/lib/python3.6/_weakrefset.py", line 60, in __iter__
    for itemref in self.data:
RuntimeError: Set changed size during iteration

Epoch: 0010; Loss: 2.7412; Val-Loss 2.7036; Perplexity 15.5052; Val-Perplexity 14.9334
Sample: <START> soon cigarette star his ji luo an about history had <PAD> got might but command in this children walking was children an when and quickly but with centuries star glanced
Epoch: 0020; Loss: 2.6513; Val-Loss 2.6180; Perplexity 14.1722; Val-Perplexity 13.7078
Sample: <START> space , . her see transmission time walking several <PAD> his it edge speed seemed not walking was might out display it , something lit and can morning up soon
Epoch: 0030; Loss: 2.3643; Val-Loss 2.3149; Perplexity 10.6371; Val-Perplexity 10.1237
Sample: <START> <END> was ! they when turned five universe this these <START> a quickly seemed display lit of but several never wang he held its want about by use something red
Epoch: 0040; Loss: 2.0659; Val-Loss 2.0162; Perplexity 7.8925; Val-Perplexity 7.5095
Sample: <START> as wang humanity already see time look in walking can by five up <END> red with turned space away you transmission with lit his these got a three in want
Epoch: 0050; Loss: 1.8465; Val-Loss 1.7997; Perplexity 6.3377; Val-Perplexity 6.0475
Sample: <START> might star morning it as about noticed morning on had history started might command by see with three speed command never never out in image , was at was turned
Epoch: 0060; Loss: 1.6731; Val-Loss 1.6281; Perplexity 5.3288; Val-Perplexity 5.0940
Sample: <START> that never her direction soon was ! already but only past universe in something in universe glanced space space space his in about seemed five in in you a two
Epoch: 0070; Loss: 1.5413; Val-Loss 1.5004; Perplexity 4.6705; Val-Perplexity 4.4834
Sample: <START> see to it direction it you never lit might red hands these luo not started transmission : however its up children saw saw held this the lit transmission ! like
Epoch: 0080; Loss: 1.4411; Val-Loss 1.4104; Perplexity 4.2254; Val-Perplexity 4.0975
Sample: <START> out want the ji two but two five ! by got wang morning out might out have look battle . day use away were battle lit were her on had
Epoch: 0090; Loss: 1.3598; Val-Loss 1.3254; Perplexity 3.8956; Val-Perplexity 3.7637
Sample: <START> not hands only transmission was - look however the was lit got like her transmission were half hands of they have , as the look they look however seemed centuries
Epoch: 0100; Loss: 1.2966; Val-Loss 1.2670; Perplexity 3.6568; Val-Perplexity 3.5501
Sample: <START> something hands in hands image hands might past saw it you started star lit glanced about transmission by was lit had got lit as at not image red something trisolaris
Epoch: 0110; Loss: 1.2488; Val-Loss 1.2281; Perplexity 3.4861; Val-Perplexity 3.4148
Sample: <START> , you light five - by noticed lit and <PAD> want three but walking got have transmission these turned with wang out her cigarette seemed two <START> an half transmission
Epoch: 0120; Loss: 1.2097; Val-Loss 1.1952; Perplexity 3.3524; Val-Perplexity 3.3042
Sample: <START> centuries her battle were out speed soon half held cigarette on seemed might glanced morning when transmission two as noticed never five see in turned be that use to you
Epoch: 0130; Loss: 1.1852; Val-Loss 1.1774; Perplexity 3.2712; Val-Perplexity 3.2460
Sample: <START> as can with her these seemed its history noticed <START> direction to past something out display trisolaris light red were was - day got might noticed in something use universe
Epoch: 0140; Loss: 1.1666; Val-Loss 1.1600; Perplexity 3.2112; Val-Perplexity 3.1898
Sample: <START> to got its star seemed cigarette look had something . have had star with not two that <END> trisolaris and light her luo : light direction like an be transmission
Epoch: 0150; Loss: 1.1460; Val-Loss 1.1398; Perplexity 3.1456; Val-Perplexity 3.1260
Sample: <START> look about lit it glanced wang time these like direction <START> they about at seemed his image day out with he of not when ji its cigarette humanity day not
Epoch: 0160; Loss: 1.1350; Val-Loss 1.1276; Perplexity 3.1112; Val-Perplexity 3.0882
Sample: <START> something <END> display started time want display away it held humanity saw luo command at she not saw started morning be several children hands cigarette these past day luo however
Epoch: 0170; Loss: 1.1251; Val-Loss 1.1225; Perplexity 3.0805; Val-Perplexity 3.0726
Sample: <START> was time to in on not lit her - look want his wang only got past however like several however star he seemed had held the and glanced when star
Epoch: 0180; Loss: 1.1154; Val-Loss 1.1086; Perplexity 3.0507; Val-Perplexity 3.0301
Sample: <START> on about not have on several <END> might image glanced hands that turned might universe with held had ! want saw these noticed cigarette cigarette <UNK> they children not these
Epoch: 0190; Loss: 1.1087; Val-Loss 1.1054; Perplexity 3.0303; Val-Perplexity 3.0203
Sample: <START> her not humanity to however speed want a in cigarette in these you these seemed had out its with speed lit - was several were only its ! image soon
Epoch: 0200; Loss: 1.1046; Val-Loss 1.1010; Perplexity 3.0180; Val-Perplexity 3.0073
Sample: <START> her but however direction red : soon five it have morning glanced the <START> a seemed direction this want noticed on away might his <START> its glanced be three speed
Epoch: 0210; Loss: 1.0969; Val-Loss 1.0958; Perplexity 2.9948; Val-Perplexity 2.9915
Sample: <START> direction transmission past saw started however held the see of of a seemed day only as morning space about you transmission look like a have was cigarette you star lit
Epoch: 0220; Loss: 1.0922; Val-Loss 1.0979; Perplexity 2.9809; Val-Perplexity 2.9980
Sample: <START> luo never can luo - turned to have morning something the when cigarette want wang see you an the direction an transmission on its its got already you its edge
Epoch: 0230; Loss: 1.0894; Val-Loss 1.0973; Perplexity 2.9725; Val-Perplexity 2.9961
Sample: <START> three light you this see speed started transmission might he something glanced use with light might an history star as past speed noticed you battle something an by her .
Epoch: 0240; Loss: 1.0848; Val-Loss 1.0963; Perplexity 2.9588; Val-Perplexity 2.9930
Sample: <START> wang see morning seemed centuries by these <PAD> a humanity like seemed morning and were be past started morning ! <PAD> of three cigarette these ! about only held when
Epoch: 0250; Loss: 1.0828; Val-Loss 1.0886; Perplexity 2.9529; Val-Perplexity 2.9702
Sample: <START> never never however . its an never with held on want got when with had an started two her command speed three held noticed its <UNK> see saw speed about
Epoch: 0260; Loss: 1.0798; Val-Loss 1.0921; Perplexity 2.9441; Val-Perplexity 2.9805
Sample: <START> wang at up morning red its these cigarette only <END> held not command five noticed seemed at got his seemed in as battle - something edge and quickly day by
Epoch: 0270; Loss: 1.0771; Val-Loss 1.0810; Perplexity 2.9362; Val-Perplexity 2.9475
Sample: <START> <UNK> only a transmission like cigarette her glanced history have cigarette you had had see walking you her glanced two soon its trisolaris but the three red noticed want not
Epoch: 0280; Loss: 1.0761; Val-Loss 1.0786; Perplexity 2.9333; Val-Perplexity 2.9406
Sample: <START> her her in seemed to transmission had - got image saw see was was might but as speed centuries soon had like had never her never had <PAD> they day
Epoch: 0290; Loss: 1.0753; Val-Loss 1.0849; Perplexity 2.9308; Val-Perplexity 2.9591
Sample: <START> had not you centuries red cigarette history with red hands its red when was trisolaris had held saw the by have command battle had hands speed soon were time five

-----------------------------------------------------------------------------------------
Training complete
Lowest loss: 1.0744
In [53]:
history.plot()
In [54]:
def test_loop(batch_size, model, criterion, optim, test_set,):
    '''Data iterator for the test set'''
    model.eval()
    
    try:
        test_loss, test_accuracy = validation_epoch(model, test_set, 
                                                    criterion, optim, batch_size)
        log('Evaluation Complete')
        log('Test set Loss: {}'.format(test_loss))
        log('Test set Perplexity: {}'.format(np.exp(test_loss)))
        log('Test set Accuracy: {}'.format(test_accuracy))
    
    except KeyboardInterrupt:
        log('-' * 89)
        log('Exiting from testing early')
In [55]:
lm.load_state_dict(torch.load(file_model))
test_loop(256, lm, loss, optimizer, test_set)
Evaluation Complete
Test set Loss: 1.0961862802505493
Test set Perplexity: 2.992730795887753
Test set Accuracy: 0.0

Saving Embeddings Weights

  1. Extrcting vectors into a dictionary (word2vec)
  2. Saving the vectors (word2vec).
In [64]:
def get_elmo_vectors(model):
    model.eval()
    word2vec = {}
    for w, ix in tqdm(model.word2idx.items()):
        ix_var = torch.LongTensor([ix]).to(device)

        emb = model.encoder(ix_var.view(1,-1))
        out1, hid_ = model.lstm1(emb)
        out2, hid_ = model.lstm2(out1, hid_)
        
        emb = emb.view(1,-1).data.cpu().numpy()[0]
        out1 = out1.view(1,-1).data.cpu().numpy()[0]
        out2 = out2.view(1,-1).data.cpu().numpy()[0]

        word2vec[w] = dict(
            embedding = emb,
            hid1 = out1,
            hid2 = out2,
        )
        
    return word2vec
In [65]:
word2vec = get_elmo_vectors(lm)

In [58]:
with open(file_wv, 'w+') as f:
    for w, vec in tqdm(word2vec.items()):
        row = np.concatenate([[w], vec['embedding'], vec['hid1'], vec['hid2']])
        f.write(' '.join(row) + '\n')

In [59]:
' '.join(row)
Out[59]:
'doomsday 0.019144375 -0.011262503 -0.054264653 0.07552316 -0.04677312 -0.043715246 0.029800449 -0.098677255 -0.087699786 0.111737646 0.08108406 0.08127617 0.14914002 -0.17581704 0.010354878 0.12268086 -0.08209066 -0.03954192 -0.009683989 0.060234316 0.079789676 -0.11663191 0.124724835 -0.037095707 0.04702442 -0.029397946 -0.12079265 -0.10139638 0.0095835645 0.113344245 0.08293821 0.100370824 -0.008190885 0.008966734 0.021829829 -0.08489138 0.10120566 -0.010230141 -0.04472376 -0.10047493 -0.0625175 -0.055764224 0.1547271 -0.1139233 -0.057975918 -0.11362845 -0.047995884 -0.028644199 0.031366825 0.09102691 0.0064217034 0.0591397 -0.034809817 -0.030157201 -0.055808205 0.017484395 0.057290915 0.056935456 0.04440339 0.03640528 0.061410222 -0.06576907 -0.066208206 0.009575155 0.09143286 -0.10088797 -0.08211959 0.059557542 -0.035395663 0.024542004 -0.035210326 0.014628755 0.12772882 -0.057329144 -0.0744061 -0.02468062 0.009410101 -0.04309125 -0.084312454 0.0762925 -0.020742945 0.04448962 0.12309611 0.016032873 0.15162481 -0.1314242 0.12193378 0.08225759 0.029561715 -0.0614505 0.006002207 -0.0927667 -0.065466516 -0.06480242 0.0077024833 0.020869521 -0.0758709 0.090688854 0.03706296 0.0017705106 0.043391243 -0.14248165 0.05658515 0.07355731 0.07907174 0.019258726 -0.08251391 -0.061569046 -0.07022654 0.038820278 -0.0878327 -0.041754447 0.014912813 -0.015131595 0.0001365648 -0.00984699 -0.11710193 0.027462937 -0.07352145 0.033495378 0.020122392 0.095105104 0.0015523405 -0.05174806 -0.055938885 0.08180001 -0.115246415 0.039961703 -0.047477294 0.021761438 -0.01935918 0.037725147 -0.045951545 0.059311435 0.003186048 -0.09801708 0.100893565 0.050895005 0.09025452 0.054864142 0.109954745 0.16359329 -0.036519982 0.10660204 -0.070585296 -0.03825811 0.008745374 -0.006380331 0.066654995 0.0025552027 -0.053807035 0.038850956 -0.031916663 0.08099706 -0.010193364 -0.12841539 -0.019125786 -0.034416553 0.056395948 -0.0029892267 0.034501243 0.058722064 -0.043787777 -0.14189883 -0.06835651 -0.011339926 0.071666695 -0.053338896 0.074751765 0.07303199 -0.080664754 -0.0674444 0.021041414 -0.11712903 0.015653785 -0.10301745 0.13191797 0.03336045 0.073667005 0.02220278 0.018015405 -0.04430134 -0.050868444 0.13137431 -0.12607768 -0.04385574 0.10287586 -0.02199097 -0.046265874 -0.046365585 0.03192538 -0.05965421 -0.008514234 0.06323657 0.11607088 0.043476973 0.006071745 -0.03002573 0.079065256 0.10162905 0.123988666 -0.11351591 0.012707224 0.033863965 -0.038223825 0.04719314 -0.04606858 -0.011751718 0.04507148 -0.0048624175 0.042390265 -0.13216378 -0.076352164 0.0075225094 0.061840795 0.2257554 -0.06922411 -0.089413665 -0.022556232 -0.07062854 -0.11332289 -0.007943083 0.10174688 0.016032528 0.076806374 0.016981956 0.071279734 0.00037128598 -0.0088539515 -0.0195572 -0.1138637 -0.037353866 -0.043936227 0.16313241 -0.05448641 0.14080141 0.015675344 -0.09355434 0.033032957 0.13163991 -0.04940434 0.027297204 0.011381093 0.07453263 -0.027582524 0.047562845 -0.17956817 -0.15048222 -0.075844266 0.164489 0.0888622 -0.054217402 -0.11419919 -0.046849623 0.0032962284 0.12490391 0.02095681 0.07558962 -0.0038533907 -0.07554676 0.07734367 -0.08135916 -0.013699127 0.048568472 -0.0627778 0.10186819 -0.08268394 -0.09364053 -0.11943724 -0.047317725 0.02201621 0.045001548 0.09462936 0.048074357 -0.05612918 0.0825079 -0.088830166 -0.03815688 -0.09791671 0.00797545 0.026351817 -0.091244705 -0.011592205 -0.05384624 0.0487457 -0.11188113 -0.05003087 -0.009712308 0.02694161 -0.009335604 0.024024239 -0.1270577 -0.14116967 -0.027769063 0.06540114 0.07800534 -0.08499379 0.076454766 0.09255872 -0.12363913 -0.101437315 0.007894608 0.17152315 0.00042451595 0.024041355 -0.041437272 0.050027844 -0.02760984 0.021333607 -0.049038976 0.03679542 -0.14440495 -0.029170137 -0.0027979224 -0.004392188 0.015992261 -0.021676196 -0.1323027 -0.0015445176 0.21564086 -0.013431105 -0.0061001927 -0.032708876 -0.04209167 0.054932624 -0.01859128 -0.11102446 -0.09928766 0.043740373 0.09106052 0.064958274 -0.06324171 0.1484803 0.21339355 -0.15271601 0.055790193 -0.07014506 0.020499673 0.0046761455 -0.009467079 0.06724552 0.17583278 -0.20661941 0.021186551 0.13845748 -0.019551588 0.1073641 -0.032113653 0.12176157 0.04607929 0.07857493 0.04625448 -0.0724659 -0.13722306 -0.10848629 -0.015127993 -0.052834637 -0.080647975 -0.18753123 0.13316339 -0.18373246 0.044140592 0.030956859 0.082465254 -0.09925594 0.029463645 0.048026074 0.06725748 -0.022498338 0.30080742 -0.04067377 0.18708871 -0.05891599 -0.034532316 -0.02673782 -0.0774335 -0.03408435 -0.07535319 -0.16249557 0.15918423 -0.045513082 -0.026489092 0.063562505 0.0035445625 0.04143385 0.054864183 0.11467244 -0.006415189 -0.017656647 0.02820999 -0.12635364 -0.017223196 -0.020308511 0.09518707 -0.20856085 -0.10820239 -0.118585385 0.051283766 -0.01944336 0.010883282 -0.15552144 -0.026531573 -0.0022704673 -0.07406549 0.15039048 -0.09156772 0.19901429 -0.11889205 -0.027131049 0.13789564 0.21246442 0.085902646 0.030998515 -0.18561736 0.13175236 0.10126359 0.06693361 -0.14071806 0.09460603 -0.20043278 0.08795798 0.1279674 -0.14521612 -0.08201606 -0.07542526 0.045425393 0.25055277 -0.2020613 -0.09616534 -0.12523662 0.059590247 -0.08095756 0.13505952 0.014808842 -0.039542053 0.37339324 -0.042027168 -0.032162536 0.11853116 0.1960998 -0.17258637 -0.044939574 0.05499272 -0.0019374555 0.0815436 -0.06386393 0.15047418 -0.2258666 -0.062035587 -0.03271413 0.046317052 0.15558659 -0.15382297 -0.13178296 -0.14106806 0.054497693 -0.10640273 0.2416711 -0.08302012 -0.14271438 -0.0866469 0.1034047 -0.018578123 -0.14366029 0.08044761 -0.015542477 0.209219 -0.2218885 -0.04830695 0.1617367 0.100623764 0.12184961 -0.062163666 0.19608174 -0.07673878 -0.04012129 -0.22895329 -0.13467328 0.012120137 0.040390603 0.10485775 0.15833047 0.19763793 0.0932549 0.07177312 -0.05330936 -0.025881646 0.01982425 -0.050485928 0.17094877 0.061567124 -0.19522922 -0.11896508 0.067274734 -0.04687178 0.025645025 -0.1969956 -0.07138407 0.00557158 -0.27678615 -0.22383045 -0.27760637 0.13022874 0.10331432 -0.08814418 0.23716603 0.023350634 -0.28934518 0.15164366 -0.08098609 0.13166778 -0.1763806 -0.16157448 0.028121786 0.14172442 0.06670422 0.07928535 0.008279952 -0.0639333 0.17381434 -0.08439902 0.2515398 -0.04096333 -0.073651336 0.21865274 0.18785118 0.22600447 0.16057311 0.11348682 -0.10273971 0.04218584 0.008524662 -0.10813271 0.066350296 -0.09815816 0.12081578 -0.030189963 0.17794281 -0.17870511 0.25373653 -0.037956443 -0.0125516895 0.03886509 -0.08524119 0.28474876 0.09443538 0.28689948 0.25865027 0.16293843 -0.1992543 -0.04438634 -0.15152445 0.10258499 -0.08918362 0.13862842 0.19983067 0.1311691 0.0499539 0.12701514 -0.09394877 -0.19378224 -0.060368024 -0.31211683 -0.18064882 0.09889215 -0.1825807 0.21072602 -0.008927809 -0.093034856 -0.10520393 0.14345038 0.016104385 0.13663946 -0.09954506 -0.086325355 0.087362535 -0.33888233 -0.20934108 0.14313582 0.17845821 -0.25092235 0.10741558 -0.036945153 0.13091739 -0.29757756 -0.091621876 -0.06711745 -0.16859876 0.22402763 -0.21097897 0.13412702 0.12645882 0.13199389 0.07841743 -0.03782335 -0.24159867 -0.23621875 0.10574535 0.15957592 0.15734378'