#!/usr/bin/env python
# coding: utf-8

# # Gensim Doc2vec Tutorial on the IMDB Sentiment Dataset

# ## Introduction
#
# In this tutorial, we will learn how to apply Doc2vec using gensim by recreating the results of Le and Mikolov 2014.
#
# ### Bag-of-words Model
#
# Previous state-of-the-art document representations were based on the bag-of-words model, which represents each input document as a fixed-length vector. For example, borrowing from the Wikipedia article, the two documents
#
# (1) `John likes to watch movies. Mary likes movies too.`
# (2) `John also likes to watch football games.`
#
# are used to construct a length-10 list of words
#
# `["John", "likes", "to", "watch", "movies", "Mary", "too", "also", "football", "games"]`
#
# so that we can represent each document as a fixed-length vector whose elements are the frequencies of the corresponding words in our list:
#
# (1) `[1, 2, 1, 1, 2, 1, 1, 0, 0, 0]`
# (2) `[1, 1, 1, 1, 0, 0, 0, 1, 1, 1]`
#
# Bag-of-words models are surprisingly effective, but they lose all information about word order. Bag-of-n-grams models, which represent a document by counts of word phrases of length n, capture some local word order but suffer from data sparsity and high dimensionality.
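# To make this concrete, here is a minimal plain-Python sketch (an illustrative addition, not part of Le and Mikolov's pipeline) that reproduces the two count vectors above:

# In[ ]:

# Illustrative bag-of-words sketch: count each vocabulary word in each example document.
from collections import Counter

vocab = ["John", "likes", "to", "watch", "movies", "Mary", "too", "also", "football", "games"]
example_docs = ["John likes to watch movies. Mary likes movies too.",
                "John also likes to watch football games."]

for doc in example_docs:
    counts = Counter(doc.replace('.', '').split())
    print([counts[w] for w in vocab])
# Expected: [1, 2, 1, 1, 2, 1, 1, 0, 0, 0] and [1, 1, 1, 1, 0, 0, 0, 1, 1, 1]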
# ### Word2vec Model
#
# Word2vec is a more recent model that embeds words in a dense vector space using a shallow neural network. The result is a set of word vectors where vectors close together in vector space have similar meanings based on context, and word vectors distant from each other have differing meanings. For example, `strong` and `powerful` would be close together, while `strong` and `Paris` would be relatively far apart. There are two versions of this model, based on skip-grams and continuous bag of words.
#
# #### Word2vec - Skip-gram Model
#
# The skip-gram Word2vec model takes in pairs (word1, word2) generated by moving a window across text data, and trains a 1-hidden-layer neural network on the fake task of predicting, for a given input word, a probability distribution over the words near it. The words are one-hot encoded, and the learned hidden-layer weights (one row per vocabulary word) give us the word embeddings: if the hidden layer has 300 neurons, this network gives us 300-dimensional word embeddings.
#
# #### Word2vec - Continuous-bag-of-words Model
#
# Continuous-bag-of-words Word2vec is very similar to the skip-gram model. It is also a 1-hidden-layer neural network, but the fake task is reversed: given the context words in a window around a center word, predict the center word. Again, the words are one-hot encoded and the hidden-layer weights give us the word embeddings.
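# In gensim, both variants are available through the `Word2Vec` class via its `sg` parameter. A minimal sketch on a toy corpus (an illustrative addition, not part of the IMDB experiment below):

# In[ ]:

# Illustrative only – a toy Word2vec example, separate from the IMDB pipeline below.
# sg=1 selects the skip-gram model; sg=0 (the default) selects continuous bag of words.
# (This notebook's gensim version uses `size`; in gensim 4+ the parameter is named `vector_size`.)
from gensim.models import Word2Vec

toy_sentences = [["john", "likes", "to", "watch", "movies"],
                 ["mary", "likes", "movies", "too"],
                 ["john", "also", "likes", "football", "games"]]

skipgram_model = Word2Vec(toy_sentences, size=50, window=2, min_count=1, sg=1)
cbow_model = Word2Vec(toy_sentences, size=50, window=2, min_count=1, sg=0)

print(skipgram_model.wv['movies'].shape)  # each word is now a 50-dimensional dense vector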
# ### Paragraph Vector
#
# Le and Mikolov 2014 introduce the Paragraph Vector, which outperforms more naïve document representations such as averaging a document's Word2vec word vectors. The idea is straightforward: we act as if a paragraph (or document) is just another vector, like a word vector, called the paragraph vector, and we learn its embedding in the same way as for words. The paragraph vector model considers local word order, like bag-of-n-grams, but gives us a dense representation in vector space rather than a sparse, high-dimensional one.
#
# #### Paragraph Vector - Distributed Memory (PV-DM)
#
# This is the Paragraph Vector model analogous to continuous-bag-of-words Word2vec. The paragraph vectors are obtained by training a neural network on the fake task of inferring a center word from the context words and the context paragraph: the paragraph acts as an additional context shared by every such prediction made within it.
#
# #### Paragraph Vector - Distributed Bag of Words (PV-DBOW)
#
# This is the Paragraph Vector model analogous to skip-gram Word2vec. The paragraph vectors are obtained by training a neural network on the fake task of predicting a probability distribution over words randomly sampled from the paragraph, given only the paragraph vector.
#
# ### Requirements
#
# The following python modules are dependencies for this tutorial:
# * testfixtures ( `pip install testfixtures` )
# * statsmodels ( `pip install statsmodels` )

# ## Load corpus
#
# Let's download the IMDB archive if it is not already downloaded (84 MB). This will be our text data for this tutorial.
# The data can be found here: http://ai.stanford.edu/~amaas/data/sentiment/

# In[1]:

import locale
import glob
import os.path
import requests
import tarfile
import sys
import codecs
import smart_open

dirname = 'aclImdb'
filename = 'aclImdb_v1.tar.gz'
locale.setlocale(locale.LC_ALL, 'C')

if sys.version > '3':
    control_chars = [chr(0x85)]
else:
    control_chars = [unichr(0x85)]

# Convert text to lower-case and strip punctuation/symbols from words
def normalize_text(text):
    norm_text = text.lower()
    # Replace breaks with spaces
    norm_text = norm_text.replace('<br />', ' ')
    # Pad punctuation with spaces on both sides
    for char in ['.', '"', ',', '(', ')', '!', '?', ';', ':']:
        norm_text = norm_text.replace(char, ' ' + char + ' ')
    return norm_text

import time
start = time.clock()

if not os.path.isfile('aclImdb/alldata-id.txt'):
    if not os.path.isdir(dirname):
        if not os.path.isfile(filename):
            # Download IMDB archive
            print("Downloading IMDB archive...")
            url = u'http://ai.stanford.edu/~amaas/data/sentiment/' + filename
            r = requests.get(url)
            with open(filename, 'wb') as f:
                f.write(r.content)
        tar = tarfile.open(filename, mode='r')
        tar.extractall()
        tar.close()

    # Concatenate and normalize test/train data
    print("Cleaning up dataset...")
    folders = ['train/pos', 'train/neg', 'test/pos', 'test/neg', 'train/unsup']
    alldata = u''
    for fol in folders:
        temp = u''
        output = fol.replace('/', '-') + '.txt'
        # Is there a better pattern to use?
        txt_files = glob.glob(os.path.join(dirname, fol, '*.txt'))
        for txt in txt_files:
            with smart_open.smart_open(txt, "rb") as t:
                t_clean = t.read().decode("utf-8")
                for c in control_chars:
                    t_clean = t_clean.replace(c, ' ')
                temp += t_clean
            temp += "\n"
        temp_norm = normalize_text(temp)
        with smart_open.smart_open(os.path.join(dirname, output), "wb") as n:
            n.write(temp_norm.encode("utf-8"))
        alldata += temp_norm

    with smart_open.smart_open(os.path.join(dirname, 'alldata-id.txt'), 'wb') as f:
        for idx, line in enumerate(alldata.splitlines()):
            num_line = u"_*{0} {1}\n".format(idx, line)
            f.write(num_line.encode("utf-8"))

end = time.clock()
print("Total running time: ", end - start)

# In[2]:

import os.path
assert os.path.isfile("aclImdb/alldata-id.txt"), "alldata-id.txt unavailable"

# The text data is small enough to be read into memory.

# In[3]:

import gensim
from gensim.models.doc2vec import TaggedDocument
from collections import namedtuple

SentimentDocument = namedtuple('SentimentDocument', 'words tags split sentiment')

alldocs = []  # Will hold all docs in original order
with open('aclImdb/alldata-id.txt', encoding='utf-8') as alldata:
    for line_no, line in enumerate(alldata):
        tokens = gensim.utils.to_unicode(line).split()
        words = tokens[1:]
        tags = [line_no]  # 'tags = [tokens[0]]' would also work at extra memory cost
        split = ['train', 'test', 'extra', 'extra'][line_no//25000]  # 25k train, 25k test, 25k extra
        sentiment = [1.0, 0.0, 1.0, 0.0, None, None, None, None][line_no//12500]  # [12.5K pos, 12.5K neg]*2 then unknown
        alldocs.append(SentimentDocument(words, tags, split, sentiment))

train_docs = [doc for doc in alldocs if doc.split == 'train']
test_docs = [doc for doc in alldocs if doc.split == 'test']
doc_list = alldocs[:]  # For reshuffling per pass

print('%d docs: %d train-sentiment, %d test-sentiment' % (len(doc_list), len(train_docs), len(test_docs)))
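# As a quick sanity check (an illustrative addition, not in the original notebook), we can peek at one of the loaded documents:

# In[ ]:

# Inspect a single loaded document: its split, sentiment label, tag and first few tokens.
sample_doc = alldocs[0]
print(sample_doc.split, sample_doc.sentiment, sample_doc.tags)
print(' '.join(sample_doc.words[:20]), '...')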
# ## Set-up Doc2Vec Training & Evaluation Models
#
# We approximate the experiment of Le & Mikolov ["Distributed Representations of Sentences and Documents"](http://cs.stanford.edu/~quocle/paragraph_vector.pdf) with guidance from Mikolov's [example go.sh](https://groups.google.com/d/msg/word2vec-toolkit/Q49FIrNOQRo/J6KG8mUj45sJ):
#
# `./word2vec -train ../alldata-id.txt -output vectors.txt -cbow 0 -size 100 -window 10 -negative 5 -hs 0 -sample 1e-4 -threads 40 -binary 0 -iter 20 -min-count 1 -sentence-vectors 1`
#
# We vary the following parameter choices:
# * 100-dimensional vectors, as the 400-d vectors of the paper don't seem to offer much benefit on this task
# * Similarly, frequent word subsampling seems to decrease sentiment-prediction accuracy, so it's left out
# * `cbow=0` means skip-gram, which is equivalent to the paper's 'PV-DBOW' mode, matched in gensim with `dm=0`
# * Added to that DBOW model are two DM models, one which averages context vectors (`dm_mean`) and one which concatenates them (`dm_concat`, resulting in a much larger, slower, more data-hungry model)
# * A `min_count=2` saves quite a bit of model memory, discarding only words that appear in a single doc (and are thus no more expressive than the unique-to-each-doc vectors themselves)

# In[4]:

from gensim.models import Doc2Vec
import gensim.models.doc2vec
from collections import OrderedDict
import multiprocessing

cores = multiprocessing.cpu_count()
assert gensim.models.doc2vec.FAST_VERSION > -1, "This will be painfully slow otherwise"

simple_models = [
    # PV-DM w/ concatenation - window=5 (both sides) approximates paper's 10-word total window size
    Doc2Vec(dm=1, dm_concat=1, size=100, window=5, negative=5, hs=0, min_count=2, workers=cores),
    # PV-DBOW
    Doc2Vec(dm=0, size=100, negative=5, hs=0, min_count=2, workers=cores),
    # PV-DM w/ average
    Doc2Vec(dm=1, dm_mean=1, size=100, window=10, negative=5, hs=0, min_count=2, workers=cores),
]

# Speed up setup by sharing results of the 1st model's vocabulary scan
simple_models[0].build_vocab(alldocs)  # PV-DM w/ concat requires one special NULL word, so it serves as the template
print(simple_models[0])
for model in simple_models[1:]:
    model.reset_from(simple_models[0])
    print(model)

models_by_name = OrderedDict((str(model), model) for model in simple_models)

# Le and Mikolov note that combining a paragraph vector from Distributed Bag of Words (DBOW) and Distributed Memory (DM) improves performance. We will follow their lead, pairing the models together for evaluation. Here, we concatenate the paragraph vectors obtained from each model.

# In[5]:

from gensim.test.test_doc2vec import ConcatenatedDoc2Vec
models_by_name['dbow+dmm'] = ConcatenatedDoc2Vec([simple_models[1], simple_models[2]])
models_by_name['dbow+dmc'] = ConcatenatedDoc2Vec([simple_models[1], simple_models[0]])
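# What does this pairing do? Conceptually, each paired model simply reports, for every document tag, the two underlying models' vectors stitched end-to-end. A minimal sketch of that idea (an illustrative addition, not part of the original notebook):

# In[ ]:

# Concatenating the (still untrained) DBOW and DM vectors for document tag 0 by hand:
# two 100-dimensional vectors become one 200-dimensional vector, which is what the
# ConcatenatedDoc2Vec wrappers above present for every document.
import numpy as np

paired_vec = np.hstack([simple_models[1].docvecs[0], simple_models[2].docvecs[0]])
print(paired_vec.shape)  # (200,)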
# ## Predictive Evaluation Methods
#
# Let's define some helper methods for evaluating the performance of our Doc2vec models. We will classify document sentiments using a logistic regression model fed with the paragraph embeddings, and compare the error rates obtained with the paragraph vectors from our various Doc2vec models.

# In[6]:

import numpy as np
import statsmodels.api as sm
from random import sample

# For timing
from contextlib import contextmanager
from timeit import default_timer
import time

@contextmanager
def elapsed_timer():
    start = default_timer()
    elapser = lambda: default_timer() - start
    yield lambda: elapser()
    end = default_timer()
    elapser = lambda: end - start

def logistic_predictor_from_data(train_targets, train_regressors):
    logit = sm.Logit(train_targets, train_regressors)
    predictor = logit.fit(disp=0)
    # print(predictor.summary())
    return predictor

def error_rate_for_model(test_model, train_set, test_set, infer=False, infer_steps=3, infer_alpha=0.1, infer_subsample=0.1):
    """Report error rate on test_set sentiments, using the supplied model and train_set"""

    train_targets, train_regressors = zip(*[(doc.sentiment, test_model.docvecs[doc.tags[0]]) for doc in train_set])
    train_regressors = sm.add_constant(train_regressors)
    predictor = logistic_predictor_from_data(train_targets, train_regressors)

    test_data = test_set
    if infer:
        if infer_subsample < 1.0:
            test_data = sample(test_data, int(infer_subsample * len(test_data)))
        test_regressors = [test_model.infer_vector(doc.words, steps=infer_steps, alpha=infer_alpha) for doc in test_data]
    else:
        test_regressors = [test_model.docvecs[doc.tags[0]] for doc in test_data]
    test_regressors = sm.add_constant(test_regressors)

    # Predict & evaluate
    test_predictions = predictor.predict(test_regressors)
    corrects = sum(np.rint(test_predictions) == [doc.sentiment for doc in test_data])
    errors = len(test_predictions) - corrects
    error_rate = float(errors) / len(test_predictions)
    return (error_rate, errors, len(test_predictions), predictor)
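# For reference, a hypothetical call of the helper above (an illustrative addition, not from the original notebook). It only makes sense after a model has been trained, so it is left commented out here; with `infer=True` it re-infers vectors for a subsample of the test reviews instead of reusing the stored ones.

# In[ ]:

# err, err_count, test_count, predictor = error_rate_for_model(simple_models[1], train_docs, test_docs)
# infer_err, _, _, _ = error_rate_for_model(simple_models[1], train_docs, test_docs, infer=True)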
# ## Bulk Training
#
# We use an explicit multiple-pass, alpha-reduction approach as sketched in this [gensim doc2vec blog post](http://radimrehurek.com/2014/12/doc2vec-tutorial/), with added shuffling of the corpus on each pass.
#
# Note that vector training is occurring on *all* documents of the dataset, which includes all TRAIN/TEST/EXTRA docs.
#
# We evaluate each model's sentiment predictive power based on error rate, and the evaluation is repeated after each pass so we can see the rates of relative improvement. The base numbers reuse the TRAIN and TEST vectors stored in the models for the logistic regression, while the _inferred_ results use newly inferred TEST vectors.
#
# (On a 4-core 2.6GHz Intel Core i7, these 20 passes of training and evaluating the 3 main models take about an hour.)

# In[7]:

from collections import defaultdict
best_error = defaultdict(lambda: 1.0)  # To selectively print only best errors achieved

# In[8]:

from random import shuffle
import datetime

alpha, min_alpha, passes = (0.025, 0.001, 20)
alpha_delta = (alpha - min_alpha) / passes

print("START %s" % datetime.datetime.now())

for epoch in range(passes):
    shuffle(doc_list)  # Shuffling gets best results

    for name, train_model in models_by_name.items():
        # Train
        duration = 'na'
        train_model.alpha, train_model.min_alpha = alpha, alpha
        with elapsed_timer() as elapsed:
            train_model.train(doc_list, total_examples=len(doc_list), epochs=1)
            duration = '%.1f' % elapsed()

        # Evaluate
        eval_duration = ''
        with elapsed_timer() as eval_elapsed:
            err, err_count, test_count, predictor = error_rate_for_model(train_model, train_docs, test_docs)
        eval_duration = '%.1f' % eval_elapsed()
        best_indicator = ' '
        if err <= best_error[name]:
            best_error[name] = err
            best_indicator = '*'
        print("%s%f : %i passes : %s %ss %ss" % (best_indicator, err, epoch + 1, name, duration, eval_duration))

        if ((epoch + 1) % 5) == 0 or epoch == 0:
            eval_duration = ''
            with elapsed_timer() as eval_elapsed:
                infer_err, err_count, test_count, predictor = error_rate_for_model(train_model, train_docs, test_docs, infer=True)
            eval_duration = '%.1f' % eval_elapsed()
            best_indicator = ' '
            if infer_err < best_error[name + '_inferred']:
                best_error[name + '_inferred'] = infer_err
                best_indicator = '*'
            print("%s%f : %i passes : %s %ss %ss" % (best_indicator, infer_err, epoch + 1, name + '_inferred', duration, eval_duration))

    print('Completed pass %i at alpha %f' % (epoch + 1, alpha))
    alpha -= alpha_delta

print("END %s" % str(datetime.datetime.now()))

# ## Achieved Sentiment-Prediction Accuracy

# In[9]:

# Print best error rates achieved
print("Err rate Model")
for rate, name in sorted((rate, name) for name, rate in best_error.items()):
    print("%f %s" % (rate, name))

# In our testing, contrary to the results of the paper, PV-DBOW performs best. Concatenating vectors from different models only offers a small predictive improvement over averaging vectors. The best results reproduced here are just under a 10% error rate, still a long way from the paper's reported 7.42% error rate.

# ## Examining Results

# ### Are inferred vectors close to the precalculated ones?

# In[10]:

doc_id = np.random.randint(simple_models[0].docvecs.count)  # Pick random doc; re-run cell for more examples
print('for doc %d...' % doc_id)
for model in simple_models:
    inferred_docvec = model.infer_vector(alldocs[doc_id].words)
    print('%s:\n %s' % (model, model.docvecs.most_similar([inferred_docvec], topn=3)))

# (Yes, here the stored vector from 20 epochs of training is usually one of the closest to a freshly inferred vector for the same words. Note that the defaults for inference are very abbreviated – just 3 steps starting at a high alpha – and likely need tuning for other applications.)

# ### Do close documents seem more related than distant ones?
# In[11]:

import random

doc_id = np.random.randint(simple_models[0].docvecs.count)  # pick random doc, re-run cell for more examples
model = random.choice(simple_models)  # and a random model
sims = model.docvecs.most_similar(doc_id, topn=model.docvecs.count)  # get *all* similar documents
print(u'TARGET (%d): «%s»\n' % (doc_id, ' '.join(alldocs[doc_id].words)))
print(u'SIMILAR/DISSIMILAR DOCS PER MODEL %s:\n' % model)
for label, index in [('MOST', 0), ('MEDIAN', len(sims)//2), ('LEAST', len(sims) - 1)]:
    print(u'%s %s: «%s»\n' % (label, sims[index], ' '.join(alldocs[sims[index][0]].words)))

# (Somewhat, in terms of reviewer tone, movie genre, etc. – the MOST cosine-similar docs usually seem more like the TARGET than the MEDIAN or LEAST.)

# ### Do the word vectors show useful similarities?

# In[12]:

word_models = simple_models[:]
# In[13]:

import random
from IPython.display import HTML

# pick a random word with a suitable number of occurrences
while True:
    word = random.choice(word_models[0].wv.index2word)
    if word_models[0].wv.vocab[word].count > 10:
        break
# or uncomment the below line, to just pick a word from the relevant domain:
# word = 'comedy/drama'
similars_per_model = [str(model.most_similar(word, topn=20)).replace('), ', '),<br>\n') for model in word_models]
similar_table = ("<table><tr><th>" +
                 "</th><th>".join([str(model) for model in word_models]) +
                 "</th></tr><tr><td>" +
                 "</td><td>".join(similars_per_model) +
                 "</td></tr></table>")
print("most similar words for '%s' (%d occurrences)" % (word, simple_models[0].wv.vocab[word].count))
HTML(similar_table)

# Do the DBOW words look meaningless? That's because the gensim DBOW model doesn't train word vectors – they remain at their randomly initialized values – unless you ask for it with the `dbow_words=1` initialization parameter. Concurrent word-training slows DBOW mode significantly, and offers little improvement (and sometimes a little worsening) of the error rate on this IMDB sentiment-prediction task.
#
# Words from DM models tend to show meaningfully similar words when there are many examples in the training data (as with 'plot' or 'actor'). (All DM modes inherently involve word-vector training concurrent with doc-vector training.)

# ### Are the word vectors from this dataset any good at analogies?

# In[14]:

# Download this file: https://github.com/nicholas-leonard/word2vec/blob/master/questions-words.txt
# and place it in the local directory
# Note: this takes many minutes
if os.path.isfile('questions-words.txt'):
    for model in word_models:
        sections = model.accuracy('questions-words.txt')
        correct, incorrect = len(sections[-1]['correct']), len(sections[-1]['incorrect'])
        print('%s: %0.2f%% correct (%d of %d)' % (model, float(correct*100)/(correct+incorrect), correct, correct+incorrect))

# Even though this is a tiny, domain-specific dataset, it shows some meager capability on the general word analogies – at least for the DM/concat and DM/mean models, which actually train word vectors. (The untrained, randomly initialized words of the DBOW model of course fail miserably.)

# ## Slop

# In[15]:

This cell left intentionally erroneous.

# To mix the Google dataset (if locally available) into the word tests...

# In[ ]:

from gensim.models import KeyedVectors
w2v_g100b = KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin.gz', binary=True)
w2v_g100b.compact_name = 'w2v_g100b'
word_models.append(w2v_g100b)

# To get copious logging output from above steps...

# In[ ]:

import logging
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
rootLogger = logging.getLogger()
rootLogger.setLevel(logging.INFO)

# To auto-reload python code while developing...

# In[ ]:

get_ipython().run_line_magic('load_ext', 'autoreload')
get_ipython().run_line_magic('autoreload', '2')