Dirichlet Distribution, $Dir(\alpha)$

In this Jupyter notebook we generate a plot for Wikipedia, illustrating the graph of a few probability density functions for the Dirichlet distribution, corresponding to different parameter vectors $\alpha$.

In [1]:
import numpy as np
import matplotlib.tri as tri
import scipy.stats as st

The Dirichlet distribution is defined on the open simplex $\{(x_1, x_2, x_3)\:|\: x_1+x_2+x_3=1, x_k\in(0,1)\}$.

$(x_1, x_2, x_3)$ are interpreted as the barycentric coordinates of the points in a planar triangle.

We take an equilateral triangle and subdivide it uniformly and recursively, by a procedure of type 1-to-4 split:

In [2]:
def cartesian2baric(point, M, dist=1.e-15):
    """Convert a 2d cartesian point to clipped barycentric coordinates.

    Parameters
    ----------
    point : list, tuple or 1d array of two floats
        Cartesian coordinates (x, y) of a 2d point.
    M : array with three columns
        Matrix of the transformation from cartesian to barycentric
        coordinates; it is applied to the homogeneous vector (x, y, 1).
    dist : float
        Clipping margin: every returned coordinate lies in [dist, 1 - dist].

    Returns
    -------
    numpy.ndarray
        The barycentric coordinates, clipped so that the point belongs to
        the open simplex (no coordinate is exactly 0 or 1).
    """
    # Build (x, y, 1) with numpy instead of `point + (1,)`: tuple concatenation
    # raises TypeError for a list and silently broadcasts for an ndarray.
    homogeneous = np.append(np.asarray(point, dtype=float), 1.0)
    baric = np.dot(M, homogeneous)
    # clip the coordinates to force the point to belong to the open simplex
    return np.clip(baric, dist, 1.0 - dist)

def uniftriang(vertices, subdiv_level=7):
    """Define a uniform triangulation of the triangle of vertices `vertices`.

    `vertices` is indexed as vertices[:, 0] (x values) and vertices[:, 1]
    (y values); `subdiv_level` is forwarded to the refiner as `subdiv`.

    The returned triangulation exposes:
      .triangles - the simplices of the triangulation
      .x, .y     - the cartesian coordinates of the triangulation vertices
    """
    xs, ys = vertices[:, 0], vertices[:, 1]
    coarse = tri.Triangulation(xs, ys)
    refiner = tri.UniformTriRefiner(coarse)
    return refiner.refine_triangulation(subdiv=subdiv_level)

Define the vertices of an equilateral triangle, subdivide it, and compute the barycentric coordinates of the triangulation points:

In [3]:
# Corners of the equilateral triangle, one (x, y) pair per row
tri_vertices = np.array([[0, 0], [1, 0], [0.5, np.sqrt(3)/2]])
# Transformation matrix from barycentric to cartesian coords: the corner
# x values, the corner y values, and a row of ones
A = np.vstack((tri_vertices.T, np.ones(3)))
invA = np.linalg.inv(A)  # cartesian -> barycentric
triangul = uniftriang(tri_vertices)
# Barycentric coordinates of every vertex of the refined triangulation
baric_coords = [cartesian2baric((px, py), invA) for px, py in zip(triangul.x, triangul.y)]

We define and plot the surface representing a Dirichlet probability density function as a Plotly Mesh3d:

In [4]:
import plotly.plotly as py
import plotly.graph_objs as go
from plotly import tools as tls
In [5]:
def plotly_triangular_mesh(vertices, simplices, intensities=None, colorscale="Viridis",
                           flatshading=False, showscale=False, reversescale=False, plot_edges=False):
    """Build the dict that describes a Plotly mesh3d trace for a triangulated surface.

    vertices    - vertices of the triangulation; a numpy array of shape (n_vertices, 3)
    simplices   - simplices (subtriangles) of the triangulation; a numpy array of
                  shape (n_simplices, 3)
    intensities - either a function of (x, y, z), or a list / numpy array of values;
                  if it is None the intensity is z
    colorscale, flatshading, showscale, reversescale - copied unchanged into the trace
    plot_edges  - accepted, but not used by this function

    Raises ValueError when intensities is none of the supported kinds.
    """
    xs, ys, zs = vertices.T
    first, second, third = simplices.T

    is_function = hasattr(intensities, '__call__')
    if not (intensities is None or is_function or isinstance(intensities, (list, np.ndarray))):
        raise ValueError("intensities can be either a function or a list, np.array")

    if intensities is None:
        values = zs
    elif is_function:
        values = intensities(xs, ys, zs)
    else:
        values = intensities  # intensities are given explicitly

    return {
        'type': 'mesh3d',
        'x': xs,
        'y': ys,
        'z': zs,
        'colorscale': colorscale,
        'reversescale': reversescale,
        'intensity': values,
        'flatshading': flatshading,
        'i': first,
        'j': second,
        'k': third,
        'name': '',
        'showscale': showscale,
    }

Define a list of parameters $\alpha$ for the Dirichlet distributions to be plotted:

In [6]:
# Parameter vectors laid out as a grid: one inner list per row of subplots
alpha = [
    [(1.3, 1.3, 1.3), (3, 3, 3), (7, 7, 7)],
    [(2, 6, 11), (14, 9, 5), (6, 2, 6)],
]
m, n = len(alpha), len(alpha[0])  # number of grid rows, number of grid columns
In [7]:
# Grid of 3d scenes, one per parameter vector in alpha.
# specs is derived from m and n (instead of a hard-coded 2x3 literal) so the
# grid keeps matching alpha when its shape changes; every cell gets its own dict.
fig = tls.make_subplots(rows=m, cols=n, vertical_spacing=0.0075, horizontal_spacing=0.025,
                        specs=[[{'is_3d': True} for _ in range(n)] for _ in range(m)])
This is the format of your plot grid:
[ (1,1) scene1 ]  [ (1,2) scene2 ]  [ (1,3) scene3 ]
[ (2,1) scene4 ]  [ (2,2) scene5 ]  [ (2,3) scene6 ]

In [8]:
# Scene names in row-major order over the m x n grid: scene1, scene2, ...
scenes = [['scene{}'.format(row * n + col) for col in range(1, n + 1)] for row in range(m)]
scenes
Out[8]:
[['scene1', 'scene2', 'scene3'], ['scene4', 'scene5', 'scene6']]
In [9]:
# Styling shared by the x, y and z axes of the scenes
axis = {'showbackground': True,
        'backgroundcolor': "rgb(230, 230,230)",
        'gridcolor': "rgb(255, 255, 255)",
        'zerolinecolor': "rgb(255, 255, 255)",
        'tickfont': {'size': 11},
        'titlefont': {'size': 12}}

# Each axis receives its own copy of the shared settings
scene = {axis_name: dict(axis) for axis_name in ('xaxis', 'yaxis', 'zaxis')}
scene['aspectratio'] = {'x': 1, 'y': 1, 'z': 0.25}
fig.update_scenes(scene);
In [10]:
# Colorscale with 11 equally spaced stops (0.0, 0.1, ..., 1.0), from light to dark
pl_deep = [[step / 10.0, color] for step, color in enumerate([
    'rgb(253, 253, 204)',
    'rgb(201, 235, 177)',
    'rgb(145, 216, 163)',
    'rgb(102, 194, 163)',
    'rgb(81, 168, 162)',
    'rgb(72, 141, 157)',
    'rgb(64, 117, 152)',
    'rgb(61, 90, 146)',
    'rgb(65, 64, 123)',
    'rgb(55, 44, 80)',
    'rgb(39, 26, 44)',
])]
In [11]:
# Evaluate each Dirichlet pdf on the triangulation and add its surface to the
# subplot in the matching grid cell.
for i in range(m):
    for j in range(n):
        X = st.dirichlet(np.array(alpha[i][j]))
        C = [X.pdf(point) for point in baric_coords]  # pdf value at every triangulation vertex
        zmax = max(C)
        surf_vertices = np.vstack((triangul.x, triangul.y, C)).T # vertices of the surface triangulation
        trace = plotly_triangular_mesh(surf_vertices, triangul.triangles, intensities=None, colorscale=pl_deep)
        fig.append_trace(trace, i+1, j+1)
        # Restrict the tick update to the scene of this grid cell. Without row/col
        # the update is applied to every scene, so all subplots would end up with
        # the tick values of the last distribution plotted.
        fig.update_scenes({'zaxis': {'tickvals': [round(zmax/2,1), round(zmax,1)]}},
                          row=i+1, col=j+1)
        
# Figure-level layout: title listing the plotted parameter vectors, font and size
fig.layout.update(
    title=('Dirichlet distribution over an open 2-simplex'
           '<br> alpha=(1.3, 1.3, 1.3), (3, 3, 3), (7, 7, 7), '
           '<br>(2, 6, 11), (14, 9, 5), (6, 2, 6) '),
    font={'family': 'Georgia, serif', 'size': 14},
    margin={'t': 135},
    height=800,
    width=900,
    showlegend=False
);
In [13]:
# Wrap the assembled figure in a FigureWidget; the bare name on the last line
# makes the widget the output of this cell.
fw = go.FigureWidget(fig)
fw
In [14]:
from IPython.display import IFrame
# Embed the page at this plot.ly URL in a 900x800 frame (same size as the figure layout).
# NOTE(review): presumably the published copy of the figure built above - confirm.
IFrame('https://plot.ly/~empet/13886/',  width=900, height=800)
Out[14]:
In [ ]: