In this notebook, we will learn how to visualize topic clusters using a dendrogram. A dendrogram is a tree-structured graph that can be used to visualize the result of a hierarchical clustering calculation. Hierarchical clustering puts individual data points into similarity groups without prior knowledge of the groups. We can use it to explore topic models and see how the topics are connected to each other through the sequence of successive fusions or divisions that occur in the clustering process.
from gensim.models.ldamodel import LdaModel
from gensim.corpora import Dictionary
from gensim.parsing.preprocessing import remove_stopwords, strip_punctuation
import numpy as np
import pandas as pd
import re
import plotly.offline as py
import plotly.graph_objs as go
py.init_notebook_mode()
Using TensorFlow backend.
We'll use the fake news dataset from Kaggle for this notebook. The first step is to preprocess the data and train our topic model using LDA. You can also refer to this notebook for tips and suggestions on pre-processing the text data and on training an LDA model that gives good results.
df_fake = pd.read_csv('fake.csv')
df_fake[['title', 'text', 'language']].head()
# Keep only English articles that actually have a body text.
has_text = df_fake.text.notnull()
is_english = df_fake.language == 'english'
df_fake = df_fake.loc[has_text & is_english]
def preprocess(row):
    """Lower-case one article, remove stopwords, then strip punctuation.

    Stopwords are removed before punctuation is stripped, matching the
    original nesting of the two gensim helpers.
    """
    lowered = row.lower()
    without_stopwords = remove_stopwords(lowered)
    return strip_punctuation(without_stopwords)


# Clean every article body in place.
df_fake['text'] = df_fake['text'].apply(preprocess)
# Convert data to the input format required by LDA: one list of word tokens
# per document.
# re.LOCALE may only be combined with *bytes* patterns; using it with a str
# pattern raises ValueError on Python 3.6+, so only re.UNICODE is used here.
word_pattern = re.compile(r'\w+', flags=re.UNICODE)
texts = [word_pattern.findall(line.lower()) for line in df_fake.text]
# Number of LDA topics; the saved model file name is derived from it.
num_topics = 35
# Create a dictionary representation of the documents.
dictionary = Dictionary(texts)
# Filter out words that occur in fewer than 2 documents, or in more than 40%
# of the documents (no_above is a fraction of the corpus).
dictionary.filter_extremes(no_below=2, no_above=0.4)
# Bag-of-words representation of the documents.
corpus_fake = [dictionary.doc2bow(text) for text in texts]
lda_fake = LdaModel(corpus=corpus_fake, id2word=dictionary, num_topics=num_topics,
                    passes=30, chunksize=1500, iterations=200, alpha='auto')
model_path = 'lda_{}'.format(num_topics)
lda_fake.save(model_path)
# Reload the persisted model; in a later session only this line is needed.
lda_fake = LdaModel.load(model_path)
First, a distance matrix is calculated to store the distance between every pair of topics. These distances are then used, in ascending order, to successively merge the closest topics (or clusters of topics) together; the dendrogram depicts this merging process.
# This input cell contains the modified code from Plotly[1].
# It can be removed after PR (https://github.com/plotly/plotly.py/pull/807) gets merged.
# [1] https://github.com/plotly/plotly.py/blob/master/plotly/figure_factory/_dendrogram.py
# NOTE(review): the visible modification appears to be the `annotation`
# argument (per-link hover text) threaded through create_dendrogram and
# _Dendrogram — confirm against the PR before deleting this cell.
from collections import OrderedDict
from plotly import exceptions, optional_imports
from plotly.graph_objs import graph_objs
# Optional imports, may be None for users that only use our core functionality.
# create_dendrogram checks scp/scs/sch for None before doing any work.
# This re-binds `np`; presumably the same numpy module imported at the top of
# the notebook — verify if numpy could be missing.
np = optional_imports.get_module('numpy')
scp = optional_imports.get_module('scipy')
sch = optional_imports.get_module('scipy.cluster.hierarchy')
scs = optional_imports.get_module('scipy.spatial')
def create_dendrogram(X, orientation="bottom", labels=None,
                      colorscale=None, distfun=None,
                      linkagefun=lambda x: sch.linkage(x, 'single'),
                      annotation=None):
    """
    BETA function that returns a dendrogram Plotly figure object.

    :param (ndarray) X: Matrix of observations as array of arrays
    :param (str) orientation: 'top', 'right', 'bottom', or 'left'
    :param (list) labels: List of axis category labels(observation labels)
    :param (list) colorscale: Optional colorscale for dendrogram tree
    :param (function) distfun: Function to compute the pairwise distance from
                               the observations; defaults to
                               scipy.spatial.distance.pdist
    :param (function) linkagefun: Function to compute the linkage matrix from
                                  the pairwise distances
    :param (list) annotation: Optional hover text, one entry per dendrogram
                              link in scipy's ``icoord`` order; entry i is used
                              as the ``text`` of the i-th link trace
    :rtype (dict): Figure dictionary with 'layout' and 'data' keys
    :raises ImportError: if scipy, scipy.spatial or scipy.cluster.hierarchy
                         is unavailable
    :raises PlotlyError: if X is not 2-dimensional

    Example 1: Simple bottom oriented dendrogram
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_dendrogram
    import numpy as np
    X = np.random.rand(10,10)
    dendro = create_dendrogram(X)
    plot_url = py.plot(dendro, filename='simple-dendrogram')
    ```
    Example 2: Dendrogram to put on the left of the heatmap
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_dendrogram
    import numpy as np
    X = np.random.rand(5,5)
    names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark']
    dendro = create_dendrogram(X, orientation='right', labels=names)
    dendro['layout'].update({'width':700, 'height':500})
    py.iplot(dendro, filename='vertical-dendrogram')
    ```
    Example 3: Dendrogram with Pandas
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_dendrogram
    import numpy as np
    import pandas as pd
    Index= ['A','B','C','D','E','F','G','H','I','J']
    df = pd.DataFrame(abs(np.random.randn(10, 10)), index=Index)
    fig = create_dendrogram(df, labels=Index)
    url = py.plot(fig, filename='pandas-dendrogram')
    ```
    """
    if not scp or not scs or not sch:
        raise ImportError("FigureFactory.create_dendrogram requires scipy, "
                          "scipy.spatial and scipy.hierarchy")

    # The original only constructed this exception and never raised it, so
    # invalid input slipped through to scipy with a less helpful error.
    if len(X.shape) != 2:
        raise exceptions.PlotlyError("X should be 2-dimensional array.")

    if distfun is None:
        distfun = scs.distance.pdist

    dendrogram = _Dendrogram(X, orientation, labels, colorscale,
                             distfun=distfun, linkagefun=linkagefun,
                             annotation=annotation)

    return {'layout': dendrogram.layout,
            'data': dendrogram.data}
class _Dendrogram(object):
"""Refer to FigureFactory.create_dendrogram() for docstring."""
def __init__(self, X, orientation='bottom', labels=None, colorscale=None,
width="100%", height="100%", xaxis='xaxis', yaxis='yaxis',
distfun=None,
linkagefun=lambda x: sch.linkage(x, 'single'),
annotation=None):
self.orientation = orientation
self.labels = labels
self.xaxis = xaxis
self.yaxis = yaxis
self.data = []
self.leaves = []
self.sign = {self.xaxis: 1, self.yaxis: 1}
self.layout = {self.xaxis: {}, self.yaxis: {}}
if self.orientation in ['left', 'bottom']:
self.sign[self.xaxis] = 1
else:
self.sign[self.xaxis] = -1
if self.orientation in ['right', 'bottom']:
self.sign[self.yaxis] = 1
else:
self.sign[self.yaxis] = -1
if distfun is None:
distfun = scs.distance.pdist
(dd_traces, xvals, yvals,
ordered_labels, leaves) = self.get_dendrogram_traces(X, colorscale, distfun, linkagefun, annotation)
self.labels = ordered_labels
self.leaves = leaves
yvals_flat = yvals.flatten()
xvals_flat = xvals.flatten()
self.zero_vals = []
for i in range(len(yvals_flat)):
if yvals_flat[i] == 0.0 and xvals_flat[i] not in self.zero_vals:
self.zero_vals.append(xvals_flat[i])
self.zero_vals.sort()
self.layout = self.set_figure_layout(width, height)
self.data = graph_objs.Data(dd_traces)
def get_color_dict(self, colorscale):
"""
Returns colorscale used for dendrogram tree clusters.
:param (list) colorscale: Colors to use for the plot in rgb format.
:rtype (dict): A dict of default colors mapped to the user colorscale.
"""
# These are the color codes returned for dendrograms
# We're replacing them with nicer colors
d = {'r': 'red',
'g': 'green',
'b': 'blue',
'c': 'cyan',
'm': 'magenta',
'y': 'yellow',
'k': 'black',
'w': 'white'}
default_colors = OrderedDict(sorted(d.items(), key=lambda t: t[0]))
if colorscale is None:
colorscale = [
'rgb(0,116,217)', # blue
'rgb(35,205,205)', # cyan
'rgb(61,153,112)', # green
'rgb(40,35,35)', # black
'rgb(133,20,75)', # magenta
'rgb(255,65,54)', # red
'rgb(255,255,255)', # white
'rgb(255,220,0)'] # yellow
for i in range(len(default_colors.keys())):
k = list(default_colors.keys())[i] # PY3 won't index keys
if i < len(colorscale):
default_colors[k] = colorscale[i]
return default_colors
def set_axis_layout(self, axis_key):
"""
Sets and returns default axis object for dendrogram figure.
:param (str) axis_key: E.g., 'xaxis', 'xaxis1', 'yaxis', yaxis1', etc.
:rtype (dict): An axis_key dictionary with set parameters.
"""
axis_defaults = {
'type': 'linear',
'ticks': 'outside',
'mirror': 'allticks',
'rangemode': 'tozero',
'showticklabels': True,
'zeroline': False,
'showgrid': False,
'showline': True,
}
if len(self.labels) != 0:
axis_key_labels = self.xaxis
if self.orientation in ['left', 'right']:
axis_key_labels = self.yaxis
if axis_key_labels not in self.layout:
self.layout[axis_key_labels] = {}
self.layout[axis_key_labels]['tickvals'] = \
[zv*self.sign[axis_key] for zv in self.zero_vals]
self.layout[axis_key_labels]['ticktext'] = self.labels
self.layout[axis_key_labels]['tickmode'] = 'array'
self.layout[axis_key].update(axis_defaults)
return self.layout[axis_key]
def set_figure_layout(self, width, height):
"""
Sets and returns default layout object for dendrogram figure.
"""
self.layout.update({
'showlegend': False,
'autosize': False,
'hovermode': 'closest',
'width': width,
'height': height
})
self.set_axis_layout(self.xaxis)
self.set_axis_layout(self.yaxis)
return self.layout
def get_dendrogram_traces(self, X, colorscale, distfun, linkagefun, annotation):
"""
Calculates all the elements needed for plotting a dendrogram.
:param (ndarray) X: Matrix of observations as array of arrays
:param (list) colorscale: Color scale for dendrogram tree clusters
:param (function) distfun: Function to compute the pairwise distance
from the observations
:param (function) linkagefun: Function to compute the linkage matrix
from the pairwise distances
:rtype (tuple): Contains all the traces in the following order:
(a) trace_list: List of Plotly trace objects for dendrogram tree
(b) icoord: All X points of the dendrogram tree as array of arrays
with length 4
(c) dcoord: All Y points of the dendrogram tree as array of arrays
with length 4
(d) ordered_labels: leaf labels in the order they are going to
appear on the plot
(e) P['leaves']: left-to-right traversal of the leaves
"""
d = distfun(X)
Z = linkagefun(d)
P = sch.dendrogram(Z, orientation=self.orientation,
labels=self.labels, no_plot=True)
icoord = scp.array(P['icoord'])
dcoord = scp.array(P['dcoord'])
ordered_labels = scp.array(P['ivl'])
color_list = scp.array(P['color_list'])
colors = self.get_color_dict(colorscale)
trace_list = []
for i in range(len(icoord)):
# xs and ys are arrays of 4 points that make up the '∩' shapes
# of the dendrogram tree
if self.orientation in ['top', 'bottom']:
xs = icoord[i]
else:
xs = dcoord[i]
if self.orientation in ['top', 'bottom']:
ys = dcoord[i]
else:
ys = icoord[i]
color_key = color_list[i]
text_annotation = None
if annotation:
text_annotation = annotation[i]
trace = graph_objs.Scatter(
x=np.multiply(self.sign[self.xaxis], xs),
y=np.multiply(self.sign[self.yaxis], ys),
mode='lines',
marker=graph_objs.Marker(color=colors[color_key]),
text=text_annotation,
hoverinfo='text'
)
try:
x_index = int(self.xaxis[-1])
except ValueError:
x_index = ''
try:
y_index = int(self.yaxis[-1])
except ValueError:
y_index = ''
trace['xaxis'] = 'x' + x_index
trace['yaxis'] = 'y' + y_index
trace_list.append(trace)
return trace_list, icoord, dcoord, ordered_labels, P['leaves']
from gensim.matutils import jensen_shannon
from scipy import spatial as scs
from scipy.cluster import hierarchy as sch
from scipy.spatial.distance import pdist, squareform
# Per-topic parameters from the LDA state; row i belongs to topic i (the loop
# below treats the row count as the number of topics).
topic_dist = lda_fake.state.get_lambda()

# Collect the top `num_words` terms of every topic as a set.
num_words = 300
topic_terms = []
for topic_id in range(topic_dist.shape[0]):
    top_words = lda_fake.show_topic(topic_id, topn=num_words)
    topic_terms.append(set(word for word, _ in top_words))

# Number of terms listed in each hover annotation.
n_ann_terms = 10
def js_dist(X):
    """Return pairwise Jensen-Shannon distances between the rows of ``X``.

    Used as the dendrogram's ``distfun``; the result is scipy's condensed
    distance vector, exactly as produced by ``pdist``.
    """
    return pdist(X, metric=jensen_shannon)
def text_annotation(topic_dist, topic_terms, n_ann_terms):
    """Build the hover text for every link of the topic dendrogram.

    The hierarchy is recomputed the same way the plotted dendrogram builds it
    (single linkage over ``js_dist`` distances), so entry i of the result
    belongs to the i-th ``'icoord'`` link.

    :param topic_dist: matrix with one row per topic, clustered via ``js_dist``
    :param topic_terms: list of term sets indexed by topic number
    :param n_ann_terms: maximum number of terms shown per text
    :return: list of ``[t1, t2, t3, t4]`` texts, one per link, for the four
             points of its '∩' shape; the two upper corners get ``()``
    """
    tree = sch.dendrogram(sch.linkage(js_dist(topic_dist), 'single'),
                          orientation="bottom", no_plot=True)

    # Leaf x-positions in the dendrogram: 5, 15, 25, ...
    leaf_ticks = np.arange(5, len(tree['leaves']) * 10 + 5, 10)

    # x-position of a node -> (intersecting terms, different terms).
    # A leaf starts with its topic's full term set in both slots.
    node_terms = {}
    for topic_no, tick in zip(tree['leaves'], leaf_ticks):
        node_terms[tick] = (topic_terms[topic_no], topic_terms[topic_no])

    def summarize(node):
        """Two-line '+++ common / --- different' text for a cluster node."""
        common, differing = node
        return "<br>".join((
            ": ".join(("+++", str(list(common)[:n_ann_terms]))),
            ": ".join(("---", str(list(differing)[:n_ann_terms]))),
        ))

    annotations = []
    for link in tree['icoord']:
        left_x, right_x = link[0], link[2]
        left, right = node_terms[left_x], node_terms[right_x]

        # Leaves list their own terms; cluster nodes show +++/--- terms.
        if left_x in leaf_ticks:
            left_text = str(list(left[0])[:n_ann_terms])
        else:
            left_text = summarize(left)
        if right_x in leaf_ticks:
            right_text = str(list(right[0])[:n_ann_terms])
        else:
            right_text = summarize(right)
        annotations.append([left_text, (), (), right_text])

        # The merged node sits midway between its children and carries their
        # common / differing terms up to the next level of the tree.
        node_terms[(left_x + right_x) / 2] = (
            left[0] & right[0],
            left[0].symmetric_difference(right[0]),
        )
        # Children are fully annotated now; drop them.
        node_terms.pop(left_x, None)
        node_terms.pop(right_x, None)

    return annotations
# get text annotations
annotation = text_annotation(topic_dist, topic_terms, n_ann_terms)
# Plot dendrogram. Leaves are labelled 1..K, one label per topic; K is read
# from topic_dist instead of being hard-coded, so it follows the model.
topic_labels = range(1, topic_dist.shape[0] + 1)
dendro = create_dendrogram(topic_dist, distfun=js_dist, labels=topic_labels, annotation=annotation)
dendro['layout'].update({'width': 1000, 'height': 600})
py.iplot(dendro)
The x-axis (the leaves of the hierarchy) represents the topics of our LDA model, and the y-axis measures the distance between individual topics or clusters of topics. Essentially, the height at which two branches merge (measured from the bottom of the tree) indicates how similar they are: the lower the merge, the more similar the topics. For example, topics 4 and 30 are more similar to each other than to topic 32. In addition, topics 18 and 24 are more similar to 35 than topics 4 and 30 are to topic 32, because they merge at a lower height than 4/30 merges with 32.
The text annotations, visible when hovering over the cluster nodes, show the intersecting (+++) and different (---) terms of the node's two child nodes. Cluster nodes on the first level of the hierarchy compute these terms directly from the leaf topics, while higher-level nodes use the intersection (+++) of each child cluster as that child's topic terms.
This type of tree graph can help us see high-level cluster themes that might exist in our data, since the annotation on each cluster head shows the common and different terms of the topics it combines.
Now let's append the distance matrix of the topics below the dendrogram in the form of a heatmap, so that we can see the exact distances between all pairs of topics.
# get text annotations
annotation = text_annotation(topic_dist, topic_terms, n_ann_terms)

# Initialize figure by creating upper dendrogram; leaves are labelled 1..K,
# one label per topic.
figure = create_dendrogram(topic_dist, distfun=js_dist,
                           labels=range(1, topic_dist.shape[0] + 1),
                           annotation=annotation)
# Put the dendrogram on a second y-axis so it can sit above the heatmap.
for trace in figure['data']:
    trace['yaxis'] = 'y2'

# get distance matrix and its topic annotations
mdiff, annotation = lda_fake.diff(lda_fake, distance="jensen_shannon", normed=False)

# Leaf labels in left-to-right dendrogram order, and the matching 0-based
# topic indices.
dendro_leaves = list(figure['layout']['xaxis']['ticktext'])
leaf_index = [x - 1 for x in dendro_leaves]

# Reorder rows and columns of the distance matrix to follow the dendrogram.
heat_data = mdiff[np.ix_(leaf_index, leaf_index)]

# Heatmap hover text, reordered exactly like heat_data so that each cell's
# text describes the same topic pair as its z-value (the original kept the
# text in model order, so hover text did not match the reordered cells).
annotation_html = []
for i in leaf_index:
    row_html = []
    for j in leaf_index:
        int_tokens, diff_tokens = annotation[i][j]
        row_html.append("+++ {}<br>--- {}".format(", ".join(int_tokens), ", ".join(diff_tokens)))
    annotation_html.append(row_html)

# plot heatmap of distance matrix ('YlGnBu' is Plotly's yellow-green-blue
# scale; the original 'YIGnBu' spelling is not a valid scale name)
heatmap = go.Data([
    go.Heatmap(
        z=heat_data,
        colorscale='YlGnBu',
        text=annotation_html,
        hoverinfo='x+y+z+text'
    )
])

heatmap[0]['x'] = figure['layout']['xaxis']['tickvals']
heatmap[0]['y'] = figure['layout']['xaxis']['tickvals']

# Add Heatmap Data to Figure
figure['data'].extend(heatmap)

# Edit Layout
figure['layout'].update({'width': 800, 'height': 800,
                         'showlegend': False, 'hovermode': 'closest',
                         })

# Edit xaxis
figure['layout']['xaxis'].update({'domain': [.25, 1],
                                  'mirror': False,
                                  'showgrid': False,
                                  'showline': False,
                                  "showticklabels": True,
                                  "tickmode": "array",
                                  "ticktext": dendro_leaves,
                                  "tickvals": figure['layout']['xaxis']['tickvals'],
                                  'zeroline': False,
                                  'ticks': ""})
# Edit yaxis (same leaf order as the x-axis, since the matrix is symmetric)
figure['layout']['yaxis'].update({'domain': [0, 0.75],
                                  'mirror': False,
                                  'showgrid': False,
                                  'showline': False,
                                  "showticklabels": True,
                                  "tickmode": "array",
                                  "ticktext": dendro_leaves,
                                  "tickvals": figure['layout']['xaxis']['tickvals'],
                                  'zeroline': False,
                                  'ticks': ""})
# Edit yaxis2 (the dendrogram strip above the heatmap)
figure['layout'].update({'yaxis2': {'domain': [0.75, 1],
                                    'mirror': False,
                                    'showgrid': False,
                                    'showline': False,
                                    'zeroline': False,
                                    'showticklabels': False,
                                    'ticks': ""}})

py.iplot(figure)
The heatmap lets us see the exact distance between any two topics as the z-value of their corresponding cell, along with their intersecting and different terms in the +++/--- annotation. This also shows the distances between topics that are not directly connected in the dendrogram.