from lda2vec import preprocess, Corpus
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
# seaborn is purely cosmetic (nicer default matplotlib styling); proceed
# without it if it isn't installed. Catch only ImportError — a bare
# `except:` would also hide real errors raised inside seaborn at import.
try:
    import seaborn
except ImportError:
    pass
You must be using a very recent version of pyLDAvis to use the lda2vec outputs.
As of this writing, any version released after Jan 6, 2016 (commit 14e7b5f60d8360eb84969ff08a1b77b365a5878e or later)
should work.
You can do this quickly by installing it directly from master like so:
# pip install git+https://github.com/bmabey/[email protected]#egg=pyLDAvis
import pyLDAvis
# Register pyLDAvis' HTML/JS renderer with the notebook so that
# pyLDAvis.display() below shows the interactive topic visualization inline.
pyLDAvis.enable_notebook()
After running the lda2vec_run.py
script in the examples/twenty_newsgroups/lda2vec
directory, a topics.pyldavis.npz file
will be created that contains the topic-to-word probabilities and frequencies. What's left is to visualize and label each topic from its most prevalent words.
# Load the topic model arrays saved by lda2vec_run.py. An .npz archive is
# binary (a zip of .npy files), so the file must be opened in 'rb' mode —
# text mode 'r' corrupts the stream on Windows and fails on Python 3.
npz = np.load(open('topics.pyldavis.npz', 'rb'))
dat = {k: v for (k, v) in npz.iteritems()}
# pyLDAvis expects the vocabulary as a plain Python list, not an ndarray.
dat['vocab'] = dat['vocab'].tolist()
top_n = 10
topic_to_topwords = {}
for j, topic_to_word in enumerate(dat['topic_term_dists']):
top = np.argsort(topic_to_word)[::-1][:top_n]
msg = 'Topic %i ' % j
top_words = [dat['vocab'][i].strip()[:35] for i in top]
msg += ' '.join(top_words)
print msg
topic_to_topwords[j] = top_words
import warnings
# pyLDAvis emits deprecation/future warnings from its internals; silence
# them so the notebook output stays readable.
warnings.filterwarnings('ignore')
# prepare() expects float arrays for document lengths and term frequencies,
# hence the `* 1.0` casts; mds='tsne' lays the topics out with t-SNE
# instead of the default multidimensional scaling.
prepared_data = pyLDAvis.prepare(dat['topic_term_dists'], dat['doc_topic_dists'],
dat['doc_lengths'] * 1.0, dat['vocab'], dat['term_frequency'] * 1.0, mds='tsne')
pyLDAvis.display(prepared_data)
The 20 newsgroups dataset is interesting because users effectively classify the topics by posting to a particular newsgroup. This lets us qualitatively check our unsupervised topics against the 'true' labels. For example, the four topics we highlighted above are intuitively close to comp.graphics, sci.med, talk.politics.misc, and sci.space.
comp.graphics
comp.os.ms-windows.misc
comp.sys.ibm.pc.hardware
comp.sys.mac.hardware
comp.windows.x
rec.autos
rec.motorcycles
rec.sport.baseball
rec.sport.hockey
sci.crypt
sci.electronics
sci.med
sci.space
misc.forsale
talk.politics.misc
talk.politics.guns
talk.politics.mideast
talk.religion.misc
alt.atheism
soc.religion.christian
from sklearn.datasets import fetch_20newsgroups
# Strip headers, signature footers, and quoted replies so the comparison
# below uses only the message body text.
remove=('headers', 'footers', 'quotes')
texts = fetch_20newsgroups(subset='train', remove=remove).data
print texts[1]
msg = "{weight:02d}% in topic {topic_id:02d} which has top words {text:s}"
for topic_id, weight in enumerate(dat['doc_topic_dists'][1]):
if weight > 0.01:
text = ', '.join(topic_to_topwords[topic_id])
print msg.format(topic_id=topic_id, weight=int(weight * 100.0), text=text)
plt.bar(np.arange(20), dat['doc_topic_dists'][1])
print texts[51]
msg = "{weight:02d}% in topic {topic_id:02d} which has top words {text:s}"
for topic_id, weight in enumerate(dat['doc_topic_dists'][51]):
if weight > 0.01:
text = ', '.join(topic_to_topwords[topic_id])
print msg.format(topic_id=topic_id, weight=int(weight * 100.0), text=text)
plt.bar(np.arange(20), dat['doc_topic_dists'][51])