#!/usr/bin/env python
# coding: utf-8

# In[184]:

from lda2vec import preprocess, Corpus
import matplotlib.pyplot as plt
import numpy as np
get_ipython().run_line_magic('matplotlib', 'inline')
try:
    import seaborn
except ImportError:
    pass


# You must be using a very recent version of pyLDAvis to use the lda2vec outputs.
# As of this writing, anything newer than the Jan 6, 2016 commit `14e7b5f60d8360eb84969ff08a1b77b365a5878e` should work.
# You can install it directly from master like so:

# In[ ]:

# pip install git+https://github.com/bmabey/pyLDAvis.git@master#egg=pyLDAvis


# In[11]:

import pyLDAvis
pyLDAvis.enable_notebook()


# ### Reading in the saved model topics

# After running the `lda2vec_run.py` script in the `examples/twenty_newsgroups/lda2vec` directory, a `topics.pyldavis.npz` file will be created that contains the topic-to-word probabilities and frequencies. What's left is to visualize each topic and label it from its most prevalent words.

# In[157]:

# Load the saved arrays; np.load needs the file opened in binary mode
npz = np.load(open('topics.pyldavis.npz', 'rb'))
dat = {k: v for (k, v) in npz.items()}
dat['vocab'] = dat['vocab'].tolist()

# dat['term_frequency'] = dat['term_frequency'] * 1.0 / dat['term_frequency'].sum()


# In[189]:

# For each topic, print the ten words with the highest topic-to-word probability
top_n = 10
topic_to_topwords = {}
for j, topic_to_word in enumerate(dat['topic_term_dists']):
    top = np.argsort(topic_to_word)[::-1][:top_n]
    msg = 'Topic %i ' % j
    top_words = [dat['vocab'][i].strip()[:35] for i in top]
    msg += ' '.join(top_words)
    print(msg)
    topic_to_topwords[j] = top_words


# ### Visualize topics

# In[187]:

import warnings
warnings.filterwarnings('ignore')
prepared_data = pyLDAvis.prepare(dat['topic_term_dists'], dat['doc_topic_dists'],
                                 dat['doc_lengths'] * 1.0, dat['vocab'],
                                 dat['term_frequency'] * 1.0, mds='tsne')


# In[188]:

pyLDAvis.display(prepared_data)


# ### 'True' topics

# The 20 newsgroups dataset is interesting because users effectively classify the topics by posting to a particular newsgroup. This lets us qualitatively check our unsupervised topics against the 'true' labels. For example, the four topics we highlighted above are intuitively close to comp.graphics, sci.med, talk.politics.misc, and sci.space.
#
# comp.graphics
# comp.os.ms-windows.misc
# comp.sys.ibm.pc.hardware
# comp.sys.mac.hardware
# comp.windows.x
# rec.autos
# rec.motorcycles
# rec.sport.baseball
# rec.sport.hockey
# sci.crypt
# sci.electronics
# sci.med
# sci.space
# misc.forsale
# talk.politics.misc
# talk.politics.guns
# talk.politics.mideast
# talk.religion.misc
# alt.atheism
# soc.religion.christian

# ### Individual document topics

# In[248]:

from sklearn.datasets import fetch_20newsgroups
remove = ('headers', 'footers', 'quotes')
texts = fetch_20newsgroups(subset='train', remove=remove).data


# #### First Example

# In[249]:

print(texts[1])


# In[250]:

# Print every topic this document loads onto with more than 1% weight
msg = "{weight:02d}% in topic {topic_id:02d} which has top words {text:s}"
for topic_id, weight in enumerate(dat['doc_topic_dists'][1]):
    if weight > 0.01:
        text = ', '.join(topic_to_topwords[topic_id])
        print(msg.format(topic_id=topic_id, weight=int(weight * 100.0), text=text))


# In[251]:

plt.bar(np.arange(20), dat['doc_topic_dists'][1])


# #### Second Example

# In[255]:

print(texts[51])


# In[259]:

msg = "{weight:02d}% in topic {topic_id:02d} which has top words {text:s}"
for topic_id, weight in enumerate(dat['doc_topic_dists'][51]):
    if weight > 0.01:
        text = ', '.join(topic_to_topwords[topic_id])
        print(msg.format(topic_id=topic_id, weight=int(weight * 100.0), text=text))


# In[260]:

plt.bar(np.arange(20), dat['doc_topic_dists'][51])
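

# As a rough end-to-end check on the 'true' topics idea above, we can also cross-tabulate
# each document's dominant unsupervised topic against its true newsgroup label. This is a
# sketch, not part of the original pipeline: it assumes the rows of `dat['doc_topic_dists']`
# are aligned with the order of documents returned by `fetch_20newsgroups(subset='train')`,
# which may not hold if the lda2vec run filtered or reordered the corpus.

# In[ ]:

train = fetch_20newsgroups(subset='train', remove=remove)
n_topics = dat['doc_topic_dists'].shape[1]
n_docs = min(len(train.target), len(dat['doc_topic_dists']))

# counts[i, j] = number of documents whose dominant topic is i and whose true newsgroup is j
counts = np.zeros((n_topics, len(train.target_names)), dtype=int)
dominant = np.argmax(dat['doc_topic_dists'][:n_docs], axis=1)
for topic_id, label in zip(dominant, train.target[:n_docs]):
    counts[topic_id, label] += 1

# For each topic, report the newsgroup it most often dominates
for topic_id in range(n_topics):
    best = np.argmax(counts[topic_id])
    print('Topic %02d most often dominates %-24s (%i docs)'
          % (topic_id, train.target_names[best], counts[topic_id, best]))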