%matplotlib inline
import os
import re
import json
import glob
import time
import random
import datetime
from collections import Counter
import torch
from torch.utils.data import DataLoader, Dataset
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm
from sklearn.model_selection import train_test_split
from runtimestamp.runtimestamp import runtimestamp
runtimestamp()
Updated 2018-05-10 10:49:34.862764 By yvan Using Python 3.6.5 On Linux-4.13.0-36-generic-x86_64-with-debian-stretch-sid
torch.__version__
'0.4.0'
# Text-related global variables
max_seq_len = 30
min_word_freq = 20
# GPU variables
use_gpu = torch.cuda.is_available()
device_num = 1
device = torch.device(f"cuda:{device_num}" if use_gpu else "cpu")
# File-writing variables
today = datetime.datetime.now().strftime('%Y-%m-%d')
train_file = '/mnt/hdd2/leon_data/books/3body/en/all_three_train.csv'
valid_file = '/mnt/hdd2/leon_data/books/3body/en/all_three_valid.csv'
test_file = '/mnt/hdd2/leon_data/books/3body/en/all_three_test.csv'
model_dir = '/mnt/hdd2/leon_data/books/models/{}/'.format(today)
file_model = os.path.join(model_dir, '3body_LM__{}.json')
file_wv = os.path.join(model_dir, '3body_LM__wv__{}.txt')
os.makedirs(model_dir, exist_ok=True)
class IndexVectorizer:
"""
Transforms a Corpus into lists of word indices.
"""
def __init__(self, max_words=None, min_frequency=None, start_end_tokens=False, maxlen=None):
self.vocabulary = None
self.vocabulary_size = 0
self.word2idx = dict()
self.idx2word = dict()
self.max_words = max_words
self.min_frequency = min_frequency
self.start_end_tokens = start_end_tokens
self.maxlen = maxlen
def _find_max_document_length(self, corpus):
self.maxlen = max(len(document) for document in corpus)
if self.start_end_tokens:
self.maxlen += 2
def _build_vocabulary(self, corpus):
vocabulary = Counter(word for document in corpus for word in document)
if self.max_words:
vocabulary = {word: freq for word,
freq in vocabulary.most_common(self.max_words)}
if self.min_frequency:
vocabulary = {word: freq for word, freq in vocabulary.items()
if freq >= self.min_frequency}
self.vocabulary = vocabulary
self.vocabulary_size = len(vocabulary) + 2 # padding and unk tokens
if self.start_end_tokens:
self.vocabulary_size += 2
def _build_word_index(self):
self.word2idx['<PAD>'] = 0
self.word2idx['<UNK>'] = 1
if self.start_end_tokens:
self.word2idx['<START>'] = 2
self.word2idx['<END>'] = 3
offset = len(self.word2idx)
for idx, word in enumerate(self.vocabulary):
self.word2idx[word] = idx + offset
self.idx2word = {idx: word for word, idx in self.word2idx.items()}
def fit(self, corpus):
if not self.maxlen:
self._find_max_document_length(corpus)
self._build_vocabulary(corpus)
self._build_word_index()
def pad_document_vector(self, vector):
padding = self.maxlen - len(vector)
vector.extend([self.word2idx['<PAD>']] * padding)
return vector
def add_start_end(self, vector):
vector.append(self.word2idx['<END>'])
return [self.word2idx['<START>']] + vector
def transform_document(self, document, offset=0):
"""
Vectorize a single document
"""
vector = [self.word2idx.get(word, self.word2idx['<UNK>'])
for word in document]
if len(vector) > self.maxlen:
vector = vector[:self.maxlen]
if self.start_end_tokens:
vector = self.add_start_end(vector)
vector = vector[offset:self.maxlen]
return self.pad_document_vector(vector)
def transform(self, corpus):
"""
Vectorizes a corpus in the form of a list of lists.
A corpus is a list of documents and a document is a list of words.
"""
return [self.transform_document(document) for document in corpus]
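To make the indexing concrete, here's a quick check on a toy corpus (illustrative only, not part of the pipeline; the special tokens take indices 0-3, and corpus words follow in insertion order):
toy_corpus = [['the', 'cat', 'sat'], ['the', 'dog', 'sat', 'down']]
toy_vectorizer = IndexVectorizer(start_end_tokens=True, maxlen=6)
toy_vectorizer.fit(toy_corpus)
# <PAD>=0, <UNK>=1, <START>=2, <END>=3, then 'the'=4, 'cat'=5, ...
toy_vectorizer.transform(toy_corpus)
# e.g. [[2, 4, 5, 6, 3, 0], [2, 4, 7, 6, 8, 3]]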
class ThreeBodyDataset(Dataset):
def __init__(self, path, vectorizer, tokenizer=None, stopwords=None):
self.corpus = pd.read_csv(path)
self.tokenizer = tokenizer
self.vectorizer = vectorizer
self.stopwords = stopwords
self._tokenize_corpus()
if self.stopwords: self._remove_stopwords()
self._vectorize_corpus()
def _remove_stopwords(self):
stopfilter = lambda doc: [word for word in doc if word not in self.stopwords]
self.corpus['tokens'] = self.corpus['tokens'].apply(stopfilter)
def _tokenize_corpus(self):
if self.tokenizer:
self.corpus['tokens'] = self.corpus['sentences'].apply(self.tokenizer)
else:
self.corpus['tokens'] = self.corpus['sentences'].apply(lambda x: x.lower().split())
def _vectorize_corpus(self):
if not self.vectorizer.vocabulary:
self.vectorizer.fit(self.corpus['tokens'])
self.corpus['vectors'] = self.corpus['tokens'].apply(self.vectorizer.transform_document)
self.corpus['target'] = self.corpus['tokens'].apply(self.vectorizer.transform_document, offset=1)
def __getitem__(self, index):
sentence = self.corpus['vectors'].iloc[index]
target = self.corpus['target'].iloc[index]
return torch.LongTensor(sentence), torch.LongTensor(target)
def __len__(self):
return len(self.corpus)
def simple_tokenizer(text):
return text.lower().split()
vectorizer = IndexVectorizer(max_words=None, min_frequency=min_word_freq,
start_end_tokens=True, maxlen=max_seq_len)
training_set = ThreeBodyDataset(train_file, vectorizer, simple_tokenizer)
test_set = ThreeBodyDataset(test_file, vectorizer, simple_tokenizer)
validation_set = ThreeBodyDataset(valid_file, vectorizer, simple_tokenizer)
len(training_set), len(validation_set), len(test_set)
(10387, 1732, 1731)
print("Vocab size: {}".format(vectorizer.vocabulary_size))
Vocab size: 1194
training_set.corpus.iloc[0]
sentences    Purely in terms of command , humanity might ne...
tokens       [purely, in, terms, of, command, ,, humanity, ...
vectors      [2, 1, 4, 1, 5, 6, 7, 8, 9, 10, 11, 12, 1, 13,...
target       [1, 4, 1, 5, 6, 7, 8, 9, 10, 11, 12, 1, 13, 1,...
Name: 0, dtype: object
Our data comes pre-split into training, validation, and test sets. Next we'll define some helpers for logging and for tracking loss and accuracy during training; batching itself is handled by a DataLoader inside the epoch runner below.
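For a quick look at what a batch holds (illustrative), each draw from a DataLoader over our Dataset is a pair of LongTensors of shape (batch_size, max_seq_len):
sample_loader = DataLoader(training_set, batch_size=4, shuffle=True)
X_batch, y_batch = next(iter(sample_loader))
X_batch.shape, y_batch.shape  # (torch.Size([4, 30]), torch.Size([4, 30]))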
def log(msg):
print(msg)
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
class History(object):
"""Records Loss and Validation Loss"""
def __init__(self):
self.reset()
def reset(self):
self.loss = dict()
self.val_loss = dict()
        self.min_loss = float('inf')  # sentinel; any real loss will be lower
def update_min_loss(self, min_loss):
self.min_loss = min_loss
def update_loss(self, loss):
epoch = len(self.loss.keys())
self.loss[epoch] = loss
def update_val_loss(self, val_loss):
epoch = len(self.val_loss.keys())
self.val_loss[epoch] = val_loss
def plot(self):
loss = sorted(self.loss.items())
x, y = zip(*loss)
if self.val_loss:
val_loss = sorted(self.val_loss.items())
x1, y1 = zip(*val_loss)
plt.plot(x, y, 'C0', label='Loss')
plt.plot(x1, y1, 'C2', label='Validation Loss')
plt.legend();
else:
plt.plot(x, y, 'C0');
def categorical_accuracy(y_true, y_pred):
    '''Fraction of argmax predictions that match the integer targets.'''
    y_true = y_true.float()
    _, y_pred = torch.max(y_pred.squeeze(), dim=-1)
    return (y_pred.float() == y_true).float().mean()
def softmax_trick(x):
    '''Numerically stable softmax: subtract the max before exponentiating.'''
    logits_exp = torch.exp(x - torch.max(x))
    weights = torch.div(logits_exp, logits_exp.sum())
    return weights
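A quick check (illustrative) that softmax_trick agrees with a plain softmax; subtracting the max only guards against overflow:
softmax_trick(torch.tensor([1., 2., 3.]))            # tensor([0.0900, 0.2447, 0.6652])
F.softmax(torch.tensor([1., 2., 3.]), dim=0)         # same values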
def save_state_dict(model, filepath):
'''Saves the model weights as a dictionary'''
model_dict = model.state_dict()
torch.save(model_dict, filepath)
return model_dict
Here we'll define a recurrent language model.
class RNNLM(nn.Module):
def __init__(self, vocab_size, seq_len, embedding_size,
hidden_size, batch_size,
dropout=.5, num_layers=1, tie_weights=False,
bidirectional=False, word2idx={}, log_softmax=False):
super(RNNLM, self).__init__()
self.embedding_size = embedding_size
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.batch_size = batch_size
self.seq_len = seq_len
self.tie_weights = tie_weights
self.num_layers = num_layers
self.num_directions = 1 if not bidirectional else 2
self.word2idx = word2idx
self.idx2word = {v:k for k,v in word2idx.items()}
# Model Pieces
self.dropout = nn.Dropout(p = dropout)
self.log_softmax = nn.LogSoftmax(dim = 1) if log_softmax else None
# Model Layers
        self.encoder = nn.Embedding(vocab_size, embedding_size,
                                    padding_idx = word2idx.get('<PAD>', 0))
self.lstm1 = nn.LSTM(embedding_size, hidden_size,
num_layers = 1,
bidirectional = bidirectional,
batch_first = True)
self.lstm2 = nn.LSTM(hidden_size * self.num_directions, hidden_size,
num_layers = 1,
bidirectional = bidirectional,
batch_first = True)
self.decoder = nn.Linear(hidden_size * self.num_directions, vocab_size)
# tie enc/dec weights
if self.tie_weights:
            if hidden_size != embedding_size:
                raise ValueError('When using the `tied` flag, hidden_size '
                                 'must be equal to embedding_size')
self.decoder.weight = self.encoder.weight
self.init_weights()
def init_hidden(self, bsz=None):
'''
For the nn.LSTM.
        Defaults to the batch size stored in the class params, but can take an
        argument in the case of sampling.
        '''
        if bsz is None:
            bsz = self.batch_size
h0 = torch.zeros(self.num_layers * self.num_directions,
bsz, self.hidden_size).to(device)
c0 = torch.zeros(self.num_layers * self.num_directions,
bsz, self.hidden_size ).to(device)
return (h0, c0)
def init_weights(self):
initrange = 0.1
em_layer = [self.encoder]
lin_layers = [self.decoder]
for layer in lin_layers + em_layer:
layer.weight.data.uniform_(-initrange, initrange)
if layer in lin_layers:
layer.bias.data.fill_(0)
def sample(self, x_start):
'''
Generates a sequence of text given a starting word ix and hidden state.
'''
with torch.no_grad():
indices = [x_start]
for i in range(self.seq_len):
# create inputs
x_input = torch.LongTensor(indices).to(device)
x_embs = self.encoder(x_input.view(1, -1))
# send input through the rnn
output, hidden = self.lstm1(x_embs)
output, hidden = self.lstm2(output, hidden)
# format the last word of the rnn so we can softmax it.
one_dim_last_word = output.squeeze()[-1] if i > 0 else output.squeeze()
fwd = one_dim_last_word[ : self.hidden_size ]
bck = one_dim_last_word[ self.hidden_size : ]
                # pick a word from the distribution; only the forward
                # direction's activations are used for sampling
word_weights = softmax_trick(fwd)
word_idx = torch.multinomial(word_weights, num_samples=1).squeeze().item()
indices.append(word_idx)
return indices
def forward(self, x, hidden, log_softmax=False):
'''
Iterates through the input, encodes it.
Each embedding is sent through the step function.
Dropout the last hidden layer and decode to a logit
        x.size()      # (bsz, seq_len)
        logit.size()  # (bsz, seq_len, vocab_size), reshaped on return
                      # to (bsz * seq_len, vocab_size)
'''
x_emb = self.encoder(x)
output, hidden = self.lstm1(x_emb, hidden)
output, hidden = self.lstm2(self.dropout(output), hidden)
logit = self.decoder(self.dropout(output))
if self.log_softmax:
logit = self.log_softmax(logit)
logit = logit.view(logit.size(0) * self.seq_len, self.vocab_size)
return logit, hidden
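Before training, a minimal shape check on a throwaway model (all sizes here are illustrative, not the real hyperparameters):
tiny = RNNLM(vocab_size=50, seq_len=30, embedding_size=8, hidden_size=8,
             batch_size=4, word2idx={'<PAD>': 0}).to(device)
x = torch.randint(0, 50, (4, 30)).to(device)
logit, hidden = tiny(x, tiny.init_hidden(4))
logit.shape  # torch.Size([120, 50]) == (batch_size * seq_len, vocab_size)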
def run_epoch(model, dataset, criterion, optim, batch_size,
train=False, shuffle=True):
'''A wrapper for a training, validation or test run.'''
model.train() if train else model.eval()
loss = AverageMeter()
accuracy = AverageMeter()
loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
for X, y in loader:
model.zero_grad()
X = X.squeeze().to(device)
y = y.squeeze().view(-1).to(device)
        # get a prediction
hidden = model.init_hidden(X.size(0))
y_, hidden = model(X, hidden)
# calculate loss and accuracy
        lossy = criterion(y_.squeeze(), y.squeeze())
        accy = categorical_accuracy(y.squeeze().data, y_.squeeze().data)
        loss.update(lossy.data.item())
        accuracy.update(accy.item())
# backprop
if train:
lossy.backward()
optim.step()
return loss.avg, accuracy.avg
def training_epoch(*args, **kwargs):
'''Training Epoch'''
    return run_epoch(*args, train=True, **kwargs)
def validation_epoch(*args, **kwargs):
'''Validation Epoch'''
return run_epoch(*args, **kwargs)
def sample_lm(model):
'''Samples a language model and returns generated words'''
start_idx = model.word2idx['<START>']
indices = model.sample(start_idx)
words = [model.idx2word[index] for index in indices]
return words
def training_loop(batch_size, num_epochs, display_freq, model, criterion,
optim, training_set, validation_set=None,
best_model_path='model', history=None):
'''Training iteration.'''
if not history:
history = History()
try:
for epoch in tqdm(range(num_epochs)):
loss, accuracy = training_epoch(model, training_set, criterion, optim, batch_size)
history.update_loss(loss)
if validation_set:
val_loss, val_accuracy = validation_epoch(model, validation_set,
criterion, optim, batch_size)
history.update_val_loss(val_loss)
if val_loss < history.min_loss:
save_state_dict(model, best_model_path)
history.update_min_loss(val_loss)
else:
if loss < history.min_loss:
save_state_dict(model, best_model_path)
history.update_min_loss(loss)
if epoch % display_freq == 0:
# display stats
if validation_set:
log("Epoch: {:04d}; Loss: {:.4f}; Val-Loss {:.4f}; "
"Perplexity {:.4f}; Val-Perplexity {:.4f}".format(
epoch, loss, val_loss, np.exp(loss), np.exp(val_loss)))
else:
log("Epoch: {:04d}; Loss: {:.4f}; Perplexity {:.4f};".format(
epoch, loss, np.exp(loss)))
# sample from the language model
words = sample_lm(model)
log("Sample: {}".format(' '.join(words)))
time.sleep(1)
log('-' * 89)
log("Training complete")
log("Lowest loss: {:.4f}".format(history.min_loss))
return history
except KeyboardInterrupt:
log('-' * 89)
log('Exiting from training early')
log("Lowest loss: {:.4f}".format(history.min_loss))
return history
We'll declare hyperparameters here, instantiate our model, create a training set data batcher, and train our model.
# Set Seed
if use_gpu: torch.cuda.manual_seed(303)
else: torch.manual_seed(303)
# set up Files to save stuff in
runtime = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
file_model = file_model.format(runtime)
file_wv = file_wv.format(runtime)
# Model Hyper Parameters
hidden_dim = 100
embedding_dim = 200
batch_size = 512
dropout = 0.2
lstm_layers = 1 # unused for now: both LSTMs above are hard-coded to a single layer
lstm_bidirection = True
# Training
learning_rate = 1e-4
num_epochs = 300
display_epoch_freq = 10
# Build and initialize the model
lm = RNNLM(vectorizer.vocabulary_size, max_seq_len, embedding_dim, hidden_dim, batch_size,
dropout = dropout,
tie_weights = False,
num_layers = lstm_layers,
bidirectional = lstm_bidirection,
word2idx = vectorizer.word2idx,
log_softmax = True)
if use_gpu:
lm = lm.to(device)
lm.init_weights()
# Loss and Optimizer
loss = nn.NLLLoss()
optimizer = torch.optim.Adam(lm.parameters(), lr=learning_rate)
# Train the model
history = training_loop(batch_size, num_epochs, display_epoch_freq,
lm, loss, optimizer, training_set, validation_set,
best_model_path=file_model)
Epoch: 0000; Loss: 3.3905; Val-Loss 3.3774; Perplexity 29.6801; Val-Perplexity 29.2938
Sample: <START> day about away they quickly morning lit away these image use use day in three quickly - . command <START> by see but transmission this her this <PAD> something transmission
Epoch: 0010; Loss: 2.7412; Val-Loss 2.7036; Perplexity 15.5052; Val-Perplexity 14.9334
Sample: <START> soon cigarette star his ji luo an about history had <PAD> got might but command in this children walking was children an when and quickly but with centuries star glanced
Epoch: 0020; Loss: 2.6513; Val-Loss 2.6180; Perplexity 14.1722; Val-Perplexity 13.7078
Sample: <START> space , . her see transmission time walking several <PAD> his it edge speed seemed not walking was might out display it , something lit and can morning up soon
Epoch: 0030; Loss: 2.3643; Val-Loss 2.3149; Perplexity 10.6371; Val-Perplexity 10.1237
Sample: <START> <END> was ! they when turned five universe this these <START> a quickly seemed display lit of but several never wang he held its want about by use something red
Epoch: 0040; Loss: 2.0659; Val-Loss 2.0162; Perplexity 7.8925; Val-Perplexity 7.5095
Sample: <START> as wang humanity already see time look in walking can by five up <END> red with turned space away you transmission with lit his these got a three in want
Epoch: 0050; Loss: 1.8465; Val-Loss 1.7997; Perplexity 6.3377; Val-Perplexity 6.0475
Sample: <START> might star morning it as about noticed morning on had history started might command by see with three speed command never never out in image , was at was turned
Epoch: 0060; Loss: 1.6731; Val-Loss 1.6281; Perplexity 5.3288; Val-Perplexity 5.0940
Sample: <START> that never her direction soon was ! already but only past universe in something in universe glanced space space space his in about seemed five in in you a two
Epoch: 0070; Loss: 1.5413; Val-Loss 1.5004; Perplexity 4.6705; Val-Perplexity 4.4834
Sample: <START> see to it direction it you never lit might red hands these luo not started transmission : however its up children saw saw held this the lit transmission ! like
Epoch: 0080; Loss: 1.4411; Val-Loss 1.4104; Perplexity 4.2254; Val-Perplexity 4.0975
Sample: <START> out want the ji two but two five ! by got wang morning out might out have look battle . day use away were battle lit were her on had
Epoch: 0090; Loss: 1.3598; Val-Loss 1.3254; Perplexity 3.8956; Val-Perplexity 3.7637
Sample: <START> not hands only transmission was - look however the was lit got like her transmission were half hands of they have , as the look they look however seemed centuries
Epoch: 0100; Loss: 1.2966; Val-Loss 1.2670; Perplexity 3.6568; Val-Perplexity 3.5501
Sample: <START> something hands in hands image hands might past saw it you started star lit glanced about transmission by was lit had got lit as at not image red something trisolaris
Epoch: 0110; Loss: 1.2488; Val-Loss 1.2281; Perplexity 3.4861; Val-Perplexity 3.4148
Sample: <START> , you light five - by noticed lit and <PAD> want three but walking got have transmission these turned with wang out her cigarette seemed two <START> an half transmission
Epoch: 0120; Loss: 1.2097; Val-Loss 1.1952; Perplexity 3.3524; Val-Perplexity 3.3042
Sample: <START> centuries her battle were out speed soon half held cigarette on seemed might glanced morning when transmission two as noticed never five see in turned be that use to you
Epoch: 0130; Loss: 1.1852; Val-Loss 1.1774; Perplexity 3.2712; Val-Perplexity 3.2460
Sample: <START> as can with her these seemed its history noticed <START> direction to past something out display trisolaris light red were was - day got might noticed in something use universe
Epoch: 0140; Loss: 1.1666; Val-Loss 1.1600; Perplexity 3.2112; Val-Perplexity 3.1898
Sample: <START> to got its star seemed cigarette look had something . have had star with not two that <END> trisolaris and light her luo : light direction like an be transmission
Epoch: 0150; Loss: 1.1460; Val-Loss 1.1398; Perplexity 3.1456; Val-Perplexity 3.1260
Sample: <START> look about lit it glanced wang time these like direction <START> they about at seemed his image day out with he of not when ji its cigarette humanity day not
Epoch: 0160; Loss: 1.1350; Val-Loss 1.1276; Perplexity 3.1112; Val-Perplexity 3.0882
Sample: <START> something <END> display started time want display away it held humanity saw luo command at she not saw started morning be several children hands cigarette these past day luo however
Epoch: 0170; Loss: 1.1251; Val-Loss 1.1225; Perplexity 3.0805; Val-Perplexity 3.0726
Sample: <START> was time to in on not lit her - look want his wang only got past however like several however star he seemed had held the and glanced when star
Epoch: 0180; Loss: 1.1154; Val-Loss 1.1086; Perplexity 3.0507; Val-Perplexity 3.0301
Sample: <START> on about not have on several <END> might image glanced hands that turned might universe with held had ! want saw these noticed cigarette cigarette <UNK> they children not these
Epoch: 0190; Loss: 1.1087; Val-Loss 1.1054; Perplexity 3.0303; Val-Perplexity 3.0203
Sample: <START> her not humanity to however speed want a in cigarette in these you these seemed had out its with speed lit - was several were only its ! image soon
Epoch: 0200; Loss: 1.1046; Val-Loss 1.1010; Perplexity 3.0180; Val-Perplexity 3.0073
Sample: <START> her but however direction red : soon five it have morning glanced the <START> a seemed direction this want noticed on away might his <START> its glanced be three speed
Epoch: 0210; Loss: 1.0969; Val-Loss 1.0958; Perplexity 2.9948; Val-Perplexity 2.9915
Sample: <START> direction transmission past saw started however held the see of of a seemed day only as morning space about you transmission look like a have was cigarette you star lit
Epoch: 0220; Loss: 1.0922; Val-Loss 1.0979; Perplexity 2.9809; Val-Perplexity 2.9980
Sample: <START> luo never can luo - turned to have morning something the when cigarette want wang see you an the direction an transmission on its its got already you its edge
Epoch: 0230; Loss: 1.0894; Val-Loss 1.0973; Perplexity 2.9725; Val-Perplexity 2.9961
Sample: <START> three light you this see speed started transmission might he something glanced use with light might an history star as past speed noticed you battle something an by her .
Epoch: 0240; Loss: 1.0848; Val-Loss 1.0963; Perplexity 2.9588; Val-Perplexity 2.9930
Sample: <START> wang see morning seemed centuries by these <PAD> a humanity like seemed morning and were be past started morning ! <PAD> of three cigarette these ! about only held when
Epoch: 0250; Loss: 1.0828; Val-Loss 1.0886; Perplexity 2.9529; Val-Perplexity 2.9702
Sample: <START> never never however . its an never with held on want got when with had an started two her command speed three held noticed its <UNK> see saw speed about
Epoch: 0260; Loss: 1.0798; Val-Loss 1.0921; Perplexity 2.9441; Val-Perplexity 2.9805
Sample: <START> wang at up morning red its these cigarette only <END> held not command five noticed seemed at got his seemed in as battle - something edge and quickly day by
Epoch: 0270; Loss: 1.0771; Val-Loss 1.0810; Perplexity 2.9362; Val-Perplexity 2.9475
Sample: <START> <UNK> only a transmission like cigarette her glanced history have cigarette you had had see walking you her glanced two soon its trisolaris but the three red noticed want not
Epoch: 0280; Loss: 1.0761; Val-Loss 1.0786; Perplexity 2.9333; Val-Perplexity 2.9406
Sample: <START> her her in seemed to transmission had - got image saw see was was might but as speed centuries soon had like had never her never had <PAD> they day
Epoch: 0290; Loss: 1.0753; Val-Loss 1.0849; Perplexity 2.9308; Val-Perplexity 2.9591
Sample: <START> had not you centuries red cigarette history with red hands its red when was trisolaris had held saw the by have command battle had hands speed soon were time five
-----------------------------------------------------------------------------------------
Training complete
Lowest loss: 1.0744
history.plot()
def test_loop(batch_size, model, criterion, optim, test_set):
'''Data iterator for the test set'''
model.eval()
try:
test_loss, test_accuracy = validation_epoch(model, test_set,
criterion, optim, batch_size)
log('Evaluation Complete')
log('Test set Loss: {}'.format(test_loss))
log('Test set Perplexity: {}'.format(np.exp(test_loss)))
log('Test set Accuracy: {}'.format(test_accuracy))
except KeyboardInterrupt:
log('-' * 89)
log('Exiting from testing early')
lm.load_state_dict(torch.load(file_model))
test_loop(256, lm, loss, optimizer, test_set)
Evaluation Complete
Test set Loss: 1.0961862802505493
Test set Perplexity: 2.992730795887753
Test set Accuracy: 0.0
def get_elmo_vectors(model):
model.eval()
word2vec = {}
for w, ix in tqdm(model.word2idx.items()):
ix_var = torch.LongTensor([ix]).to(device)
emb = model.encoder(ix_var.view(1,-1))
out1, hid_ = model.lstm1(emb)
out2, hid_ = model.lstm2(out1, hid_)
emb = emb.view(1,-1).data.cpu().numpy()[0]
out1 = out1.view(1,-1).data.cpu().numpy()[0]
out2 = out2.view(1,-1).data.cpu().numpy()[0]
word2vec[w] = dict(
embedding = emb,
hid1 = out1,
hid2 = out2,
)
return word2vec
word2vec = get_elmo_vectors(lm)
with open(file_wv, 'w+') as f:
for w, vec in tqdm(word2vec.items()):
row = np.concatenate([[w], vec['embedding'], vec['hid1'], vec['hid2']])
f.write(' '.join(row) + '\n')
' '.join(row)
'doomsday 0.019144375 -0.011262503 -0.054264653 0.07552316 -0.04677312 -0.043715246 ... 0.10574535 0.15957592 0.15734378'
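As a minimal sketch (assuming the whitespace-separated format written above; the helper name is ours, not part of the pipeline), the vectors can be read back like so:
def load_word_vectors(path):
    '''Read a word-per-line vector file back into a dict of numpy arrays.'''
    vectors = {}
    with open(path) as f:
        for line in f:
            parts = line.rstrip().split(' ')
            vectors[parts[0]] = np.array(parts[1:], dtype=np.float32)
    return vectors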