In this notebook, we will learn how to visualize topic models using network graphs. Networks are a useful way to explore topic models: we can use them to see how topics belonging to one context relate to topics in another context, and to discover the factors they have in common. We can also use them to find communities of similar topics, to pinpoint the most influential topic (the one with the largest number of connections), or to run any of the other workflows designed for network analysis.
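As a small illustration of the last two workflows, the sketch below uses NetworkX on a hypothetical six-topic graph (the node ids and edges are made up; in this notebook the real graph is built later from topic distances). Degree, the number of connections, picks the most connected topic, and `greedy_modularity_communities` finds groups of densely linked topics.

```python
import networkx as nx
from networkx.algorithms.community import greedy_modularity_communities

# hypothetical topic graph: two tightly linked groups joined by a single edge
G = nx.Graph()
G.add_edges_from([(0, 1), (0, 2), (1, 2),   # group A
                  (3, 4), (3, 5), (4, 5),   # group B
                  (2, 3)])                  # bridge between the groups

# "most influential" topic = most connections (ties keep the first node in order)
degrees = dict(G.degree())
hub = max(degrees, key=degrees.get)

# communities of similar topics via greedy modularity maximization
communities = [sorted(c) for c in greedy_modularity_communities(G)]
print(hub, sorted(communities))  # 2 [[0, 1, 2], [3, 4, 5]]
```

Nodes 2 and 3 both have three connections; `max` returns the first one it meets. On a weighted topic graph you would typically pass edge weights to these functions as well.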
from gensim.models.ldamodel import LdaModel
from gensim.corpora import Dictionary
import pandas as pd
import re
from gensim.parsing.preprocessing import remove_stopwords, strip_punctuation
import numpy as np
We'll use the fake news dataset from Kaggle for this notebook. The first step is to preprocess the data and train our topic model using LDA. You can also refer to this notebook for tips and suggestions on preprocessing text data and on training an LDA model that gives good results.
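For intuition about the `Dictionary`, `filter_extremes(no_below=2, no_above=0.4)` and `doc2bow` steps used below, here is a simplified, stdlib-only sketch of the same idea. It is not gensim's implementation: the helper `filter_and_bow` and the toy documents are hypothetical, token ids are assigned alphabetically here, and gensim additionally caps the vocabulary size with `keep_n`.

```python
from collections import Counter

def filter_and_bow(texts, no_below=2, no_above=0.4):
    """Hypothetical helper: document-frequency filtering plus bag-of-words."""
    # document frequency: in how many documents each token appears
    doc_freq = Counter(tok for doc in texts for tok in set(doc))
    max_docs = no_above * len(texts)
    # keep tokens seen in at least `no_below` docs and at most `no_above` of all docs
    vocab = sorted(t for t, df in doc_freq.items() if no_below <= df <= max_docs)
    token2id = {t: i for i, t in enumerate(vocab)}
    # bag-of-words: sorted (token_id, count) pairs, filtered tokens dropped
    corpus = [sorted(Counter(token2id[t] for t in doc if t in token2id).items())
              for doc in texts]
    return token2id, corpus

docs = [['vote', 'election', 'vote', 'the'],
        ['election', 'fraud', 'the'],
        ['vote', 'poll', 'the'],
        ['market', 'stocks', 'the'],
        ['market', 'rally', 'the']]
token2id, corpus = filter_and_bow(docs)
print(token2id)  # {'election': 0, 'market': 1, 'vote': 2}
print(corpus)    # [[(0, 1), (2, 2)], [(0, 1)], [(2, 1)], [(1, 1)], [(1, 1)]]
```

"the" occurs in every document (above 40%) and words like "fraud" occur in only one, so both kinds are dropped before the bag-of-words step.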
df_fake = pd.read_csv('fake.csv')
df_fake[['title', 'text', 'language']].head()
df_fake = df_fake.loc[(pd.notnull(df_fake.text)) & (df_fake.language=='english')]
# remove stopwords and punctuations
def preprocess(row):
    return strip_punctuation(remove_stopwords(row.lower()))
df_fake['text'] = df_fake['text'].apply(preprocess)
# Convert data to required input format by LDA
texts = []
for line in df_fake.text:
    lowered = line.lower()
    # re.LOCALE cannot be used with str patterns in Python 3.6+, so only UNICODE is set
    words = re.findall(r'\w+', lowered, flags=re.UNICODE)
    texts.append(words)
# Create a dictionary representation of the documents.
dictionary = Dictionary(texts)
# Filter out words that occur in fewer than 2 documents or in more than 40% of the documents.
dictionary.filter_extremes(no_below=2, no_above=0.4)
# Bag-of-words representation of the documents.
corpus_fake = [dictionary.doc2bow(text) for text in texts]
lda_fake = LdaModel(corpus=corpus_fake, id2word=dictionary, num_topics=35, chunksize=1500, iterations=200, alpha='auto')
lda_fake.save('lda_35')
lda_fake = LdaModel.load('lda_35')
First, we calculate a distance matrix that stores the distance between every pair of topics. The nodes of the network graph will represent topics, and the edges between them will be created based on the distance between the two topics they connect.
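The distance matrix is built with SciPy's `pdist`, which returns a *condensed* vector holding one value per topic pair, and `squareform`, which expands it into the symmetric n x n matrix indexed as `topic_distance[i, j]`. The sketch below shows that layout on hypothetical data, using Euclidean distance as a stand-in for the Jensen-Shannon metric the notebook uses; `condensed_index` is a hypothetical helper.

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform

# hypothetical "topics": 4 rows over a 3-word vocabulary
X = np.array([[1., 0., 0.],
              [0., 1., 0.],
              [0., 0., 1.],
              [1., 1., 0.]])

condensed = pdist(X, 'euclidean')   # one value per pair (i < j): n*(n-1)/2 entries
square = squareform(condensed)      # symmetric n x n matrix with a zero diagonal

def condensed_index(i, j, n):
    # position of pair (i, j), with i < j, inside the condensed vector
    return n * i - i * (i + 1) // 2 + (j - i - 1)

print(condensed.shape, square.shape)  # (6,) (4, 4)
```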
# get topic distributions
topic_dist = lda_fake.state.get_lambda()
# get topic terms
num_words = 50
topic_terms = [{w for (w, _) in lda_fake.show_topic(topic, topn=num_words)} for topic in range(topic_dist.shape[0])]
To draw the edges, we can use any of the distance metrics available in gensim to calculate the distance between every topic pair. Next, we define a distance threshold such that topic pairs whose distance is at or above it do not get connected.
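Because the `distance` helper takes any callable of two rows, swapping metrics is a one-argument change. gensim's `matutils` module provides `jensen_shannon` and `hellinger`, among others; the sketch below is a NumPy-only re-implementation of both for intuition, not gensim's code. Normalizing each row first is our assumption (the rows of `get_lambda()` are unnormalized pseudo-counts), and the topic rows are hypothetical.

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform

def _normalize(v):
    # turn a row of pseudo-counts into a probability distribution
    return v / v.sum()

def js_divergence(u, v):
    p, q = _normalize(u), _normalize(v)
    m = 0.5 * (p + q)
    def kl(a, b):
        mask = a > 0            # terms with a == 0 contribute 0
        return np.sum(a[mask] * np.log(a[mask] / b[mask]))
    return 0.5 * (kl(p, m) + kl(q, m))

def hellinger(u, v):
    p, q = _normalize(u), _normalize(v)
    return np.sqrt(0.5 * np.sum((np.sqrt(p) - np.sqrt(q)) ** 2))

# hypothetical topic-word rows: topics 0 and 1 are similar, topic 2 is not
X = np.array([[10., 5., 1., 1.],
              [9., 6., 1., 1.],
              [1., 1., 8., 10.]])
for metric in (js_divergence, hellinger):
    D = squareform(pdist(X, metric))
    print(metric.__name__, np.round(D[0, 1], 3), np.round(D[0, 2], 3))
```

The two metrics live on different scales (the JS divergence is bounded by ln 2, Hellinger by 1), so a fixed threshold has to be re-chosen when the metric changes; a percentile threshold adapts automatically.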
from scipy.spatial.distance import pdist, squareform
from gensim.matutils import jensen_shannon
import networkx as nx
import itertools as itt
# calculate distance matrix using the input distance metric
def distance(X, dist_metric):
    return squareform(pdist(X, dist_metric))
topic_distance = distance(topic_dist, jensen_shannon)
# store edges between every topic pair along with their distance
edges = [(i, j, {'weight': topic_distance[i, j]})
         for i, j in itt.combinations(range(topic_dist.shape[0]), 2)]
# keep edges with distance below the threshold value
k = np.percentile(np.array([e[2]['weight'] for e in edges]), 20)
edges = [e for e in edges if e[2]['weight'] < k]
Now that we have our edges, let's plot the annotated network graph. Hovering over a node shows the topic id along with its top words, and hovering over an edge shows the words the two connected topics share (+++) and the words that differ between them (---).
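The edge hover text and the way it is attached to an edge can be sketched on their own. Plotly shows hover text at data points, so each edge is sampled at several points between its endpoints, and a trailing `None` breaks the line before the next edge. The helper names and topic terms below are hypothetical; the words are sorted here for a deterministic result, whereas the notebook keeps set order.

```python
import numpy as np

def edge_annotation(terms_a, terms_b, n_terms=10):
    # words both topics share (+++) and words only one of them has (---)
    shared = sorted(terms_a & terms_b)[:n_terms]
    differ = sorted(terms_a ^ terms_b)[:n_terms]
    return "+++: " + str(shared) + "<br>---: " + str(differ)

def edge_hover_points(p0, p1, text, n_points=10):
    # sample points along the edge so hovering anywhere on it shows `text`;
    # the trailing None separates this edge from the next one in the trace
    xs = list(np.linspace(p0[0], p1[0], n_points)) + [None]
    ys = list(np.linspace(p0[1], p1[1], n_points)) + [None]
    return xs, ys, [text] * n_points + [None]

a = {'election', 'vote', 'trump', 'clinton'}   # hypothetical topic terms
b = {'election', 'vote', 'russia'}
label = edge_annotation(a, b)
print(label)  # +++: ['election', 'vote']<br>---: ['clinton', 'russia', 'trump']
xs, ys, texts = edge_hover_points((0.0, 0.0), (1.0, 0.5), label)
```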
import plotly.offline as py
import plotly.graph_objs as go

py.init_notebook_mode()

# add nodes and edges to graph layout
G = nx.Graph()
G.add_nodes_from(range(topic_dist.shape[0]))
G.add_edges_from(edges)
graph_pos = nx.spring_layout(G)

# initialize traces for drawing nodes and edges; plain dicts keep x/y/text/color
# as mutable lists, so they can be filled in the loops below
node_trace = dict(
    type='scatter',
    x=[],
    y=[],
    text=[],
    mode='markers',
    hoverinfo='text',
    marker=dict(
        showscale=True,
        colorscale='YlGnBu',
        reversescale=True,
        color=[],
        size=10,
        colorbar=dict(
            thickness=15,
            xanchor='left'
        ),
        line=dict(width=2)))

edge_trace = dict(
    type='scatter',
    x=[],
    y=[],
    text=[],
    line=dict(width=0.5, color='#888'),
    hoverinfo='text',
    mode='lines')

# number of terms to display in each annotation
n_ann_terms = 10

# add edge trace with annotations
for edge in G.edges():
    x0, y0 = graph_pos[edge[0]]
    x1, y1 = graph_pos[edge[1]]

    # words shared by both topics (+++) and words found in only one of them (---)
    pos_tokens = topic_terms[edge[0]] & topic_terms[edge[1]]
    neg_tokens = topic_terms[edge[0]].symmetric_difference(topic_terms[edge[1]])
    pos_tokens = list(pos_tokens)[:n_ann_terms]
    neg_tokens = list(neg_tokens)[:n_ann_terms]
    annotation = "<br>".join(("+++: " + str(pos_tokens), "---: " + str(neg_tokens)))

    # sample 10 points along the edge so hovering anywhere on it shows the
    # annotation; None breaks the line before the next edge
    x_trace = list(np.linspace(x0, x1, 10)) + [None]
    y_trace = list(np.linspace(y0, y1, 10)) + [None]
    text_annotation = [annotation] * 10 + [None]

    edge_trace['x'] += x_trace
    edge_trace['y'] += y_trace
    edge_trace['text'] += text_annotation

# add node trace with annotations
for node in G.nodes():
    x, y = graph_pos[node]
    node_trace['x'].append(x)
    node_trace['y'].append(y)
    node_info = ''.join((str(node + 1), ': ', str(list(topic_terms[node])[:n_ann_terms])))
    node_trace['text'].append(node_info)
    # color node according to its number of connections
    # (G.adjacency_list() was removed in networkx 2.0)
    node_trace['marker']['color'].append(G.degree(node))

fig = go.Figure(data=[edge_trace, node_trace],
                layout=go.Layout(showlegend=False,
                                 hovermode='closest',
                                 xaxis=dict(showgrid=True, zeroline=False, showticklabels=True),
                                 yaxis=dict(showgrid=True, zeroline=False, showticklabels=True)))

py.iplot(fig)
For the above graph, we simply used the 20th percentile of all the distance values as the threshold. We can also experiment with a few different values so that the graph doesn't become too crowded or too sparse, and so that we get the most useful information about similar topics and any interesting relations between different topics.
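One way to compare candidate thresholds before plotting is to count, for each percentile, how many edges survive and how many connected components (separate clusters, isolated topics included) the graph falls into. A sketch with NumPy and SciPy's `connected_components`; the helper `sweep_thresholds` and the five-topic distance matrix are hypothetical, and edges are kept strictly below the cutoff as in the notebook.

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

def sweep_thresholds(dist_matrix, percentiles):
    """For each percentile, return (edges kept, number of connected components)."""
    n = dist_matrix.shape[0]
    upper = np.triu_indices(n, k=1)          # each topic pair once
    pair_dists = dist_matrix[upper]
    summary = {}
    for pct in percentiles:
        cutoff = np.percentile(pair_dists, pct)
        # connect topic pairs strictly below the cutoff (no self-loops)
        adj = (dist_matrix < cutoff) & ~np.eye(n, dtype=bool)
        n_comp, _ = connected_components(csr_matrix(adj.astype(float)), directed=False)
        summary[pct] = (int(adj[upper].sum()), int(n_comp))
    return summary

# hypothetical distances: five topics placed on a line at these positions
positions = np.array([0., 1., 3., 7., 15.])
dist = np.abs(positions[:, None] - positions[None, :])
print(sweep_thresholds(dist, [20, 50, 80]))
# {20: (2, 3), 50: (5, 2), 80: (8, 1)}
```

A very low percentile leaves many isolated topics, while a high one merges everything into a single component; a useful threshold usually sits in between.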
Alternatively, we can get an idea of the threshold from the dendrogram (with the 'single' linkage function). You can refer to this notebook for more details on topic dendrogram visualization. The y-values in the dendrogram represent the metric distances, and if we choose a certain y-value, only the topics that are clustered below it would be connected. So let's plot the dendrogram now to see the sequential clustering process with increasing distance values.
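Why reading a threshold off a single-linkage dendrogram works: cutting the tree at height t gives exactly the connected components of the graph that joins every pair of topics with distance at most t. The sketch below checks this on hypothetical 1-D data with SciPy's `fcluster(..., criterion='distance')` and `connected_components`; `same_partition` is a hypothetical helper.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
from scipy.spatial.distance import squareform

def same_partition(a, b):
    # two labelings describe the same grouping if their labels pair up one-to-one
    return len(set(zip(a, b))) == len(set(a)) == len(set(b))

positions = np.array([0., 1., 3., 7., 15.])          # hypothetical topics
dist = np.abs(positions[:, None] - positions[None, :])
t = 2.5                                               # cut height (y-value)

# clusters from cutting the single-linkage dendrogram at height t
Z = linkage(squareform(dist), method='single')
dendro_labels = fcluster(Z, t=t, criterion='distance')

# components of the graph connecting every pair with distance <= t
adj = (dist <= t) & ~np.eye(len(positions), dtype=bool)
_, graph_labels = connected_components(csr_matrix(adj.astype(float)), directed=False)

print(same_partition(dendro_labels, graph_labels))  # True
```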
# This input cell contains the modified code from Plotly[1].
# It can be removed after PR (https://github.com/plotly/plotly.py/pull/807) gets merged.
# [1] https://github.com/plotly/plotly.py/blob/master/plotly/figure_factory/_dendrogram.py
from collections import OrderedDict
from plotly import exceptions, optional_imports
import plotly.graph_objs as graph_objs
# Optional imports, may be None for users that only use our core functionality.
np = optional_imports.get_module('numpy')
scp = optional_imports.get_module('scipy')
sch = optional_imports.get_module('scipy.cluster.hierarchy')
scs = optional_imports.get_module('scipy.spatial')
def create_dendrogram(X, orientation="bottom", labels=None,
                      colorscale=None, distfun=None,
                      linkagefun=lambda x: sch.linkage(x, 'single'),
                      annotation=None):
    """
    BETA function that returns a dendrogram Plotly figure object.

    :param (ndarray) X: Matrix of observations as array of arrays
    :param (str) orientation: 'top', 'right', 'bottom', or 'left'
    :param (list) labels: List of axis category labels(observation labels)
    :param (list) colorscale: Optional colorscale for dendrogram tree
    :param (function) distfun: Function to compute the pairwise distance from
        the observations
    :param (function) linkagefun: Function to compute the linkage matrix from
        the pairwise distances
    :param (list[list]) annotation: List of hover-text annotations for the
        constituent traces of the dendrogram clusters

    Example 1: Simple bottom oriented dendrogram
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_dendrogram
    import numpy as np
    X = np.random.rand(10,10)
    dendro = create_dendrogram(X)
    plot_url = py.plot(dendro, filename='simple-dendrogram')
    ```

    Example 2: Dendrogram to put on the left of the heatmap
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_dendrogram
    import numpy as np
    X = np.random.rand(5,5)
    names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark']
    dendro = create_dendrogram(X, orientation='right', labels=names)
    dendro['layout'].update({'width':700, 'height':500})
    py.iplot(dendro, filename='vertical-dendrogram')
    ```

    Example 3: Dendrogram with Pandas
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_dendrogram
    import numpy as np
    import pandas as pd
    Index= ['A','B','C','D','E','F','G','H','I','J']
    df = pd.DataFrame(abs(np.random.randn(10, 10)), index=Index)
    fig = create_dendrogram(df, labels=Index)
    url = py.plot(fig, filename='pandas-dendrogram')
    ```
    """
    if not scp or not scs or not sch:
        raise ImportError("FigureFactory.create_dendrogram requires scipy, "
                          "scipy.spatial and scipy.cluster.hierarchy")

    s = X.shape
    if len(s) != 2:
        raise exceptions.PlotlyError("X should be 2-dimensional array.")

    if distfun is None:
        distfun = scs.distance.pdist

    dendrogram = _Dendrogram(X, orientation, labels, colorscale,
                             distfun=distfun, linkagefun=linkagefun,
                             annotation=annotation)

    return {'layout': dendrogram.layout,
            'data': dendrogram.data}
class _Dendrogram(object):
    """Refer to FigureFactory.create_dendrogram() for docstring."""

    def __init__(self, X, orientation='bottom', labels=None, colorscale=None,
                 width="100%", height="100%", xaxis='xaxis', yaxis='yaxis',
                 distfun=None,
                 linkagefun=lambda x: sch.linkage(x, 'single'),
                 annotation=None):
        self.orientation = orientation
        self.labels = labels
        self.xaxis = xaxis
        self.yaxis = yaxis
        self.data = []
        self.leaves = []
        self.sign = {self.xaxis: 1, self.yaxis: 1}
        self.layout = {self.xaxis: {}, self.yaxis: {}}

        if self.orientation in ['left', 'bottom']:
            self.sign[self.xaxis] = 1
        else:
            self.sign[self.xaxis] = -1

        if self.orientation in ['right', 'bottom']:
            self.sign[self.yaxis] = 1
        else:
            self.sign[self.yaxis] = -1

        if distfun is None:
            distfun = scs.distance.pdist

        (dd_traces, xvals, yvals,
            ordered_labels, leaves) = self.get_dendrogram_traces(X, colorscale, distfun,
                                                                 linkagefun, annotation)

        self.labels = ordered_labels
        self.leaves = leaves
        yvals_flat = yvals.flatten()
        xvals_flat = xvals.flatten()

        # leaf positions: points of the tree that sit at zero height
        self.zero_vals = []
        for i in range(len(yvals_flat)):
            if yvals_flat[i] == 0.0 and xvals_flat[i] not in self.zero_vals:
                self.zero_vals.append(xvals_flat[i])

        self.zero_vals.sort()

        self.layout = self.set_figure_layout(width, height)
        # a plain list of traces; graph_objs.Data is deprecated in newer Plotly versions
        self.data = list(dd_traces)
    def get_color_dict(self, colorscale):
        """
        Returns colorscale used for dendrogram tree clusters.

        :param (list) colorscale: Colors to use for the plot in rgb format.
        :rtype (dict): A dict of default colors mapped to the user colorscale.
        """

        # These are the color codes returned for dendrograms
        # We're replacing them with nicer colors
        d = {'r': 'red',
             'g': 'green',
             'b': 'blue',
             'c': 'cyan',
             'm': 'magenta',
             'y': 'yellow',
             'k': 'black',
             'w': 'white'}
        default_colors = OrderedDict(sorted(d.items(), key=lambda t: t[0]))

        if colorscale is None:
            colorscale = [
                'rgb(0,116,217)',  # blue
                'rgb(35,205,205)',  # cyan
                'rgb(61,153,112)',  # green
                'rgb(40,35,35)',  # black
                'rgb(133,20,75)',  # magenta
                'rgb(255,65,54)',  # red
                'rgb(255,255,255)',  # white
                'rgb(255,220,0)']  # yellow

        for i in range(len(default_colors.keys())):
            k = list(default_colors.keys())[i]  # PY3 won't index keys
            if i < len(colorscale):
                default_colors[k] = colorscale[i]

        # newer SciPy versions label links with Matplotlib cycle codes
        # ('C0'..'C9') instead of single letters; map them onto the same colorscale
        for i in range(10):
            default_colors['C%d' % i] = colorscale[i % len(colorscale)]

        return default_colors
    def set_axis_layout(self, axis_key):
        """
        Sets and returns default axis object for dendrogram figure.

        :param (str) axis_key: E.g., 'xaxis', 'xaxis1', 'yaxis', 'yaxis1', etc.
        :rtype (dict): An axis_key dictionary with set parameters.
        """
        axis_defaults = {
            'type': 'linear',
            'ticks': 'outside',
            'mirror': 'allticks',
            'rangemode': 'tozero',
            'showticklabels': True,
            'zeroline': False,
            'showgrid': False,
            'showline': True,
        }

        if len(self.labels) != 0:
            axis_key_labels = self.xaxis
            if self.orientation in ['left', 'right']:
                axis_key_labels = self.yaxis
            if axis_key_labels not in self.layout:
                self.layout[axis_key_labels] = {}
            self.layout[axis_key_labels]['tickvals'] = \
                [zv*self.sign[axis_key] for zv in self.zero_vals]
            self.layout[axis_key_labels]['ticktext'] = self.labels
            self.layout[axis_key_labels]['tickmode'] = 'array'

        self.layout[axis_key].update(axis_defaults)

        return self.layout[axis_key]
    def set_figure_layout(self, width, height):
        """
        Sets and returns default layout object for dendrogram figure.
        """
        self.layout.update({
            'showlegend': False,
            'autosize': False,
            'hovermode': 'closest',
            'width': width,
            'height': height
        })

        self.set_axis_layout(self.xaxis)
        self.set_axis_layout(self.yaxis)

        return self.layout
    def get_dendrogram_traces(self, X, colorscale, distfun, linkagefun, annotation):
        """
        Calculates all the elements needed for plotting a dendrogram.

        :param (ndarray) X: Matrix of observations as array of arrays
        :param (list) colorscale: Color scale for dendrogram tree clusters
        :param (function) distfun: Function to compute the pairwise distance
                                   from the observations
        :param (function) linkagefun: Function to compute the linkage matrix
                                      from the pairwise distances
        :rtype (tuple): Contains all the traces in the following order:
            (a) trace_list: List of Plotly trace objects for dendrogram tree
            (b) icoord: All X points of the dendrogram tree as array of arrays
                with length 4
            (c) dcoord: All Y points of the dendrogram tree as array of arrays
                with length 4
            (d) ordered_labels: leaf labels in the order they are going to
                appear on the plot
            (e) P['leaves']: left-to-right traversal of the leaves
        """
        d = distfun(X)
        Z = linkagefun(d)
        P = sch.dendrogram(Z, orientation=self.orientation,
                           labels=self.labels, no_plot=True)

        # np.array instead of the deprecated scipy.array alias
        icoord = np.array(P['icoord'])
        dcoord = np.array(P['dcoord'])
        ordered_labels = np.array(P['ivl'])
        color_list = np.array(P['color_list'])
        colors = self.get_color_dict(colorscale)

        trace_list = []

        for i in range(len(icoord)):
            # xs and ys are arrays of 4 points that make up the '∩' shapes
            # of the dendrogram tree
            if self.orientation in ['top', 'bottom']:
                xs = icoord[i]
            else:
                xs = dcoord[i]

            if self.orientation in ['top', 'bottom']:
                ys = dcoord[i]
            else:
                ys = icoord[i]
            color_key = color_list[i]
            text_annotation = None
            if annotation:
                text_annotation = annotation[i]
            trace = graph_objs.Scatter(
                x=np.multiply(self.sign[self.xaxis], xs),
                y=np.multiply(self.sign[self.yaxis], ys),
                mode='lines',
                marker=dict(color=colors[color_key]),
                text=text_annotation,
                hoverinfo='text'
            )

            try:
                x_index = int(self.xaxis[-1])
            except ValueError:
                x_index = ''

            try:
                y_index = int(self.yaxis[-1])
            except ValueError:
                y_index = ''

            # str() so that axis names such as 'xaxis2' give 'x2' instead of a TypeError
            trace['xaxis'] = 'x' + str(x_index)
            trace['yaxis'] = 'y' + str(y_index)
            trace_list.append(trace)

        return trace_list, icoord, dcoord, ordered_labels, P['leaves']
from gensim.matutils import jensen_shannon
import scipy as scp
from scipy.cluster import hierarchy as sch
from scipy import spatial as scs
# get topic distributions
topic_dist = lda_fake.state.get_lambda()
# get topic terms
num_words = 300
topic_terms = [{w for (w, _) in lda_fake.show_topic(topic, topn=num_words)} for topic in range(topic_dist.shape[0])]
# no. of terms to display in annotation
n_ann_terms = 10
# use Jensen-Shannon distance metric in dendrogram
def js_dist(X):
    return pdist(X, jensen_shannon)
# calculate text annotations
def text_annotation(topic_dist, topic_terms, n_ann_terms):
    # get dendrogram hierarchy data
    linkagefun = lambda x: sch.linkage(x, 'single')
    d = js_dist(topic_dist)
    Z = linkagefun(d)
    P = sch.dendrogram(Z, orientation="bottom", no_plot=True)

    # store topic no.(leaves) corresponding to the x-ticks in dendrogram
    # (SciPy places the leaves at x = 5, 15, 25, ...)
    x_ticks = np.arange(5, len(P['leaves']) * 10 + 5, 10)
    x_topic = dict(zip(P['leaves'], x_ticks))

    # store {x position: (intersecting terms, different terms)};
    # each leaf starts with its own topic terms in both slots
    topic_vals = dict()
    for key, val in x_topic.items():
        topic_vals[val] = (topic_terms[key], topic_terms[key])

    text_annotations = []
    # loop through every trace (scatter plot) in dendrogram
    for trace in P['icoord']:
        fst_topic = topic_vals[trace[0]]
        scnd_topic = topic_vals[trace[2]]

        # annotation for two ends of current trace
        pos_tokens_t1 = list(fst_topic[0])[:n_ann_terms]
        neg_tokens_t1 = list(fst_topic[1])[:n_ann_terms]
        pos_tokens_t4 = list(scnd_topic[0])[:n_ann_terms]
        neg_tokens_t4 = list(scnd_topic[1])[:n_ann_terms]

        t1 = "<br>".join((": ".join(("+++", str(pos_tokens_t1))), ": ".join(("---", str(neg_tokens_t1)))))
        # the two upper corners of each link get an empty hover text
        t2 = t3 = ''
        t4 = "<br>".join((": ".join(("+++", str(pos_tokens_t4))), ": ".join(("---", str(neg_tokens_t4)))))

        # show topic terms in leaves
        if trace[0] in x_ticks:
            t1 = str(list(topic_vals[trace[0]][0])[:n_ann_terms])
        if trace[2] in x_ticks:
            t4 = str(list(topic_vals[trace[2]][0])[:n_ann_terms])

        text_annotations.append([t1, t2, t3, t4])

        # calculate intersecting/diff for upper level
        intersecting = fst_topic[0] & scnd_topic[0]
        different = fst_topic[0].symmetric_difference(scnd_topic[0])

        center = (trace[0] + trace[2]) / 2
        topic_vals[center] = (intersecting, different)

        # remove trace value after it is annotated
        topic_vals.pop(trace[0], None)
        topic_vals.pop(trace[2], None)

    return text_annotations
# get text annotations
annotation = text_annotation(topic_dist, topic_terms, n_ann_terms)
# Plot dendrogram
dendro = create_dendrogram(topic_dist, distfun=js_dist, labels=range(1, topic_dist.shape[0] + 1), annotation=annotation)
dendro['layout'].update({'width': 1000, 'height': 600})
py.iplot(dendro)
From this dendrogram, we can try threshold values between 0.3 and 0.35 for the network graph, since the topics are clustered into distinct groups below those heights, which should give separate clusters of related topics in the network graph.
But then why do we need a network graph if the dendrogram already shows the topic clusters, with a clear sequence of how topics joined one after the other? The problem is that the dendrogram doesn't show the direct relation between two topics unless they are paired directly at the first level of the hierarchy. The network graph lets us explore the inter-topic distances and, at the same time, observe clusters of closely related topics.
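The limitation described above can be made concrete with the cophenetic distance: in a dendrogram, two topics appear related only through the height at which their clusters merge, and for single linkage that height can be much smaller than their actual distance. A hedged sketch with SciPy's `cophenet` on hypothetical 1-D data:

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, cophenet
from scipy.spatial.distance import squareform

positions = np.array([0., 1., 3., 7., 15.])          # hypothetical topics
dist = np.abs(positions[:, None] - positions[None, :])
Z = linkage(squareform(dist), method='single')

# cophenetic distance: the height at which two topics first share a cluster
coph = squareform(cophenet(Z))
print(dist[0, 4], coph[0, 4])  # 15.0 8.0
```

The dendrogram places topics 0 and 4 together at height 8, even though their direct distance is 15; a network graph built from `dist` itself keeps that direct distance visible.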