Dirichlet Distribution, $Dir(\alpha)$

In this Jupyter notebook we generate a plot for Wikipedia, illustrating the graph of a few probability density functions for the Dirichlet distribution, corresponding to different parameter vectors $\alpha$.

In [1]:
import os

import numpy as np
import matplotlib.tri as tri
import scipy.stats as st
import cmocean #http://matplotlib.org/cmocean/

We deal with the Dirichlet distribution defined on the open simplex $\{(x_1, x_2, x_3)\:|\: x_1+x_2+x_3=1, x_k\in(0,1)\}$.

$(x_1, x_2, x_3)$ are interpreted as the barycentric coordinates of the points in a planar triangle.

We take an equilateral triangle and subdivide it uniformly and recursively, by a procedure of type 1-to-4 split:

In [2]:
def cartesian2baric(verts,  point, dist=1.e-15):
    """Convert a 2D cartesian point to barycentric coordinates.

    The formula projects the point onto each median of the triangle, so it
    is only valid for an equilateral triangle (where medians are altitudes).

    Parameters
    ----------
    verts : array-like of shape (3, 2)
        Cartesian coordinates of the triangle vertices.
    point : array-like of shape (2,)
        Cartesian coordinates of the point to convert.
    dist : float
        Margin used to clip each coordinate into [dist, 1 - dist], which
        keeps points off the simplex boundary.

    Returns
    -------
    numpy.ndarray of shape (3,)
        The clipped barycentric coordinates.
    """
    # Use the triangle passed in, not a notebook-level global.
    verts = np.asarray(verts, dtype=float)
    point = np.asarray(point, dtype=float)

    baric = []
    for i in range(3):
        # Midpoint of the side opposite vertex i.
        midpt = (verts[(i + 1) % 3] + verts[(i + 2) % 3]) / 2.0
        median = verts[i] - midpt
        # Normalize by the squared median length (0.75 for a unit-side
        # triangle) so the result is correct for any side length.
        baric.append(np.dot(median, point - midpt) / np.dot(median, median))
    return np.clip(baric, dist, 1.0 - dist)

def uniftriang(vertices, subdiv_level=7):
    """Return a uniform refinement of the triangle given by `vertices`.

    vertices     -- array whose columns 0 and 1 hold the x and y coordinates
    subdiv_level -- recursion depth passed to the refiner as `subdiv`

    The returned triangulation exposes `.triangles` (the simplices) and
    `.x`, `.y` (cartesian coordinates of the triangulation vertices).
    """
    xs, ys = vertices[:, 0], vertices[:, 1]
    refiner = tri.UniformTriRefiner(tri.Triangulation(xs, ys))
    return refiner.refine_triangulation(subdiv=subdiv_level)

Define the vertices of an equilateral triangle, subdivide it, and compute the barycentric coordinates of the triangulation points:

In [3]:
# Equilateral triangle with unit side, base on the x-axis.
vertices = np.array([[0, 0], [1, 0], [0.5, np.sqrt(3)/2]])
# Uniformly refined triangulation of that triangle.
triangul = uniftriang(vertices, subdiv_level=7)
# One barycentric triple per vertex of the refined triangulation.
baric_coords = [cartesian2baric(vertices, (px, py))
                for px, py in zip(triangul.x, triangul.y)]

We plot a surface representing the Dirichlet probability density function as a trisurf. Below are the functions that define the coloring method and the trisurf as a Plotly Mesh3d object:

In [4]:
import plotly.plotly as py
from plotly.graph_objs import *
from plotly import tools as tls
In [23]:
def map_z2color(zval, colormap, vmin, vmax):
    """Map a scalar value to a Plotly 'rgb(r, g, b)' color string.

    zval     -- the value to color
    colormap -- callable taking a number in [0, 1] and returning a sequence
                whose first three entries are r, g, b in [0, 1]
                (e.g. a matplotlib colormap)
    vmin, vmax -- range used to normalize zval; must satisfy vmin < vmax

    Raises ValueError when vmin >= vmax.
    """
    if vmin >= vmax:
        raise ValueError('incorrect relation between vmin and vmax')
    t = (zval - vmin) / float(vmax - vmin)  # normalize zval
    # Scale to 0..255 and truncate to 8-bit channels. Converting to plain
    # Python ints avoids indexing a lazy `map` object (Python 3) and keeps
    # numpy scalar reprs out of the color string.
    channels = (np.asarray(colormap(t)[:3]) * 255).astype(np.uint8)
    red, green, blue = (int(c) for c in channels)
    # Plotly color code, e.g. 'rgb(12, 34, 56)'
    return 'rgb({}, {}, {})'.format(red, green, blue)
In [24]:
def plotly_trisurf(x, y, z, simplices, colormap=cmocean.cm.bathy, scene='scene1'):
    """Build a Plotly Mesh3d trace of a triangulated surface, colored per face.

    x, y, z   -- coordinates of the surface vertices
    simplices -- numpy array of shape (no_triangles, 3) with the vertex
                 indices of each triangle
    colormap  -- colormap handed to map_z2color
    scene     -- accepted for the caller's convenience; not used in the body

    Each face is colored by the mean z of its three vertices, normalized
    between the smallest and largest such mean. Inputs are not type-checked.
    """
    vertex_xyz = np.vstack((x, y, z)).T

    # Mean z over the three vertices of every triangle.
    face_z = vertex_xyz[simplices][:, :, 2].mean(-1)
    z_lo, z_hi = np.min(face_z), np.max(face_z)

    face_colors = [map_z2color(fz, colormap, z_lo, z_hi) for fz in face_z]

    # Split the (no_triangles, 3) index array into its three columns.
    first, second, third = zip(*simplices)

    return Mesh3d(x=x, y=y, z=z,
                  facecolor=face_colors,
                  i=first, j=second, k=third,
                  name='')

Define a list of parameters $\alpha$ for the Dirichlet distributions to be plotted:

In [16]:
# Dirichlet parameter vectors, laid out like the subplot grid:
# one inner list per row of plots.
alpha = [
    [(1.3, 1.3, 1.3), (3, 3, 3), (7, 7, 7)],
    [(2, 6, 11), (14, 9, 5), (6, 2, 6)],
]
# Grid shape: m rows, n columns.
m, n = len(alpha), len(alpha[0])
In [17]:
# m x n grid of 3D subplots. `specs` is generated from m and n (instead of
# being written out for a 2 x 3 grid) so it always matches rows/cols.
fig = tls.make_subplots(rows=m, cols=n, vertical_spacing=0.0075, horizontal_spacing=0.025,
                        specs=[[{'is_3d': True} for _ in range(n)] for _ in range(m)],
                       )
This is the format of your plot grid:
[ (1,1) scene1 ]  [ (1,2) scene2 ]  [ (1,3) scene3 ]
[ (2,1) scene4 ]  [ (2,2) scene5 ]  [ (2,3) scene6 ]

In [18]:
# Scene names in row-major order: scenes[row][col] == 'scene<row*n + col + 1>'.
scenes = [['scene{}'.format(row * n + col + 1) for col in range(n)]
          for row in range(m)]
In [19]:
# Styling shared by the x, y and z axes of a scene.
axis = {
    'showbackground': True,
    'backgroundcolor': "rgb(230, 230,230)",
    'gridcolor': "rgb(255, 255, 255)",
    'zerolinecolor': "rgb(255, 255, 255)",
    'tickfont': {'size': 11},
}

# Scene template: same axis styling on all three axes, flattened in z.
scene = Scene(
    xaxis=XAxis(axis),
    yaxis=YAxis(axis),
    zaxis=ZAxis(axis),
    aspectratio={'x': 1, 'y': 1, 'z': 0.25},
)
In [20]:
cmap = cmocean.cm.bathy

# One surface per parameter vector, placed at grid cell (i+1, j+1).
for i in range(m):
    for j in range(n):
        scene_name = scenes[i][j]
        dirichlet_rv = st.dirichlet(np.array(alpha[i][j]))
        # Density evaluated at every triangulation vertex.
        pdf_vals = [dirichlet_rv.pdf(bc) for bc in baric_coords]
        zmax = max(pdf_vals)
        trace = plotly_trisurf(triangul.x, triangul.y, pdf_vals,
                               triangul.triangles, cmap, scene=scene_name)
        fig.append_trace(trace, i + 1, j + 1)
        fig['layout'][scene_name].update(scene)
        # Only two z ticks: half of the peak density and the peak itself.
        z_ticks = [round(zmax/2, 1), round(zmax, 1)]
        fig['layout'][scene_name]['zaxis'].update(tickvals=z_ticks)

title_text = ('Dirichlet distribution over an open 2-simplex'
              '<br> alpha=(1.3, 1.3, 1.3), (3, 3, 3), (7, 7, 7), '
              '<br>(2, 6, 11), (14, 9, 5), (6, 2, 6) ')
fig['layout'].update(
    title=title_text,
    font={'family': 'Georgia, serif', 'size': 14},
    margin={'t': 135},
    height=900,
    width=1000,
    showlegend=False,
)
In [ ]:
# Read the Plotly credentials from the environment so that no API key is
# stored in the notebook. PLOTLY_API_KEY must be set (KeyError otherwise);
# PLOTLY_USERNAME falls back to the account used for the published figure.
py.sign_in(os.environ.get('PLOTLY_USERNAME', 'empet'), os.environ['PLOTLY_API_KEY'])
py.plot(fig, filename='Dirichlet-distr')
In [12]:
from IPython.display import HTML
# Show a hosted Plotly figure inline (presumably the published version of
# the figure built above -- confirm the plot id).
iframe_markup = '<iframe src=https://plot.ly/~empet/13886/ width=900 height=700></iframe>'
HTML(iframe_markup)
Out[12]:
In [13]:
from IPython.core.display import HTML
def css_styling():
    """Read ./custom.css and return its contents wrapped in an HTML object.

    Raises IOError/OSError if the file cannot be opened.
    """
    # The context manager closes the file handle, which the bare
    # open(...).read() left to the garbage collector.
    with open("./custom.css", "r") as css_file:
        styles = css_file.read()
    return HTML(styles)
# Last expression of the cell: the HTML object built from ./custom.css
# becomes the cell output.
css_styling()
Out[13]: