MCMC: Gibbs sampling

Florent Leclercq,
Institute of Cosmology and Gravitation, University of Portsmouth,
[email protected]

In [1]:
import numpy as np
from scipy.integrate import quad
from scipy.stats import norm, gamma
from matplotlib import mlab as mlab
from matplotlib import pyplot as plt
from matplotlib import gridspec
from matplotlib.ticker import NullFormatter
from cycler import cycler
%matplotlib inline

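The target pdf

The target is the unnormalized joint density $p(x,y) \propto x^2 \exp(-xy^2 - y^2 + 2y - 4x)$ with $x \geq 0$. Both of its conditionals are standard distributions that can be sampled directly: $x \mid y \sim \mathrm{Gamma}(3,\ \mathrm{rate} = y^2+4)$ and $y \mid x \sim \mathcal{N}\left(\frac{1}{x+1},\ \frac{1}{2(x+1)}\right)$ (mean, variance). Being able to sample each conditional directly is exactly what Gibbs sampling requires.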

In [2]:
def target_joint(x, y):
    """Unnormalized joint density p(x, y) = x^2 * exp(-x y^2 - y^2 + 2y - 4x).

    Works elementwise on scalars or broadcastable numpy arrays.
    """
    exponent = -x * y * y - y * y + 2. * y - 4. * x
    return (x * x) * np.exp(exponent)
def target_marginal_x(x):
    """Unnormalized marginal density of x, i.e. the joint integrated over y.

    For fixed x the y-dependence of the joint is a Gaussian integral:
    -(x+1) y^2 + 2y = -(x+1) (y - 1/(x+1))^2 + 1/(x+1), so
        integral of p(x, y) dy = sqrt(pi) * x^2 / sqrt(x+1) * exp(-4x + 1/(x+1)).
    The constant sqrt(pi) is dropped; normalization is done separately.
    Note the PLUS sign on 1/(x+1): it comes from completing the square.
    """
    return x*x/np.sqrt(x+1.) * np.exp(-4.*x + 1./(x+1.))
def target_marginal_y(y):
    """Unnormalized marginal density of y: exp(2y - y^2) / (y^2 + 4)^3.

    Obtained by integrating the joint over x (a Gamma-type integral);
    the constant factor is dropped.
    """
    y_sq = y * y
    return np.exp(2. * y - y_sq) / (y_sq + 4.) ** 3
def target_conditional_given_y(y):
    """Frozen distribution of x given y: Gamma with shape 3 and rate y^2 + 4."""
    rate = y * y + 4.
    return gamma(a=3., scale=1. / rate)
def target_conditional_x_given_y(x, y):
    """Density of x given y, evaluated at x."""
    conditional = target_conditional_given_y(y)
    return conditional.pdf(x)
def target_conditional_given_x(x):
    """Frozen distribution of y given x: normal, mean 1/(x+1), variance 1/(2(x+1))."""
    shifted = x + 1.
    return norm(loc=1. / shifted, scale=np.sqrt(1. / (2. * shifted)))
def target_conditional_y_given_x(y, x):
    """Density of y given x, evaluated at y."""
    conditional = target_conditional_given_x(x)
    return conditional.pdf(y)
In [3]:
# Normalization constants of the unnormalized marginals, integrated over their
# full supports (x >= 0, the support of the Gamma conditional; y on the whole
# real line) instead of arbitrary finite cutoffs. quad handles infinite limits.
Nx=quad(target_marginal_x,0.,np.inf)[0]
Ny=quad(target_marginal_y,-np.inf,np.inf)[0]
In [4]:
# Marginals and conditionals rescaled so that their maximum over the evaluation
# points is one (used to compare shapes on a common vertical scale)
def _rescale_to_unit_peak(values):
    """Divide `values` by their maximum so that the largest entry becomes one.

    The density is evaluated only once by the callers; this helper just rescales.
    Raises ValueError if the maximum is not strictly positive (e.g. the density
    vanishes at every evaluation point), rather than silently returning NaN/inf.
    """
    values = np.asarray(values)
    peak = values.max()
    if not peak > 0:
        raise ValueError("cannot rescale to unit peak: maximum is {}".format(peak))
    return values / peak
def target_marginal_xN(x):
    """Marginal of x at the points `x`, rescaled so its maximum there is one."""
    return _rescale_to_unit_peak(target_marginal_x(x))
def target_marginal_yN(y):
    """Marginal of y at the points `y`, rescaled so its maximum there is one."""
    return _rescale_to_unit_peak(target_marginal_y(y))
def target_conditional_x_given_yN(x,y):
    """Density of x given y at the points `x`, rescaled so its maximum there is one."""
    return _rescale_to_unit_peak(target_conditional_x_given_y(x,y))
def target_conditional_y_given_xN(y,x):
    """Density of y given x at the points `y`, rescaled so its maximum there is one."""
    return _rescale_to_unit_peak(target_conditional_y_given_x(y,x))
In [5]:
# Plotting window and a 1000 x 1000 grid on which the joint density is evaluated
xmin, xmax = 0., 2.
ymin, ymax = -1., 2.5
x = np.linspace(xmin, xmax, 1000)
y = np.linspace(ymin, ymax, 1000)
X, Y = np.meshgrid(x, y)
Z = target_joint(X, Y)
In [6]:
nullfmt = NullFormatter()         # no labels

# definitions for the axes, in figure-fraction units: the joint panel is sized
# in proportion to the plotted data ranges, the marginal panels are 1 unit thick
# NOTE(review): these rectangles extend well beyond the nominal 2x2 figure; the
# layout appears to rely on the inline backend's tight bounding box — confirm.
left, width = 0., xmax-xmin
bottom, height = 0., ymax-ymin
left_h = left + width + 0.1
bottom_h = bottom + height + 0.1

rect_pdf = [left, bottom, width, height]
rect_pdfx = [left, bottom_h, width, 1.]
rect_pdfy = [left_h, bottom, 1., height]

# start a new figure (no fixed figure number, so an existing figure is never reused)
plt.figure(figsize=(2, 2))

ax = plt.axes(rect_pdf)
axpdfx = plt.axes(rect_pdfx)
axpdfy = plt.axes(rect_pdfy)

# no tick labels on the axes the marginal panels share with the joint panel
axpdfx.xaxis.set_major_formatter(nullfmt)
axpdfy.yaxis.set_major_formatter(nullfmt)

# the joint pdf, as an image with contours on top
# (contour reads the coordinates from X, Y, so it takes no extent/origin)
ax.set_xlim(xmin,xmax)
ax.set_ylim(ymin,ymax)
ax.set_xlabel("$x$",size=15)
ax.set_ylabel("$y$",size=15)
ax.imshow(Z,extent=(xmin,xmax,ymin,ymax),origin='lower',cmap='YlGnBu')
ax.contour(X,Y,Z,linewidths=2,cmap='viridis_r')

# normalized marginal of x, above the joint panel
axpdfx.set_xlim(xmin,xmax)
axpdfx.set_ylim(0,1.4)
axpdfx.plot(x,target_marginal_x(x)/Nx,linewidth=2.,color='darkorchid')

# normalized marginal of y, to the right of the joint panel (drawn sideways)
axpdfy.set_xlim(0,0.8)
axpdfy.set_ylim(ymin,ymax)
axpdfy.plot(target_marginal_y(y)/Ny,y,linewidth=2.,color='seagreen')

axpdfx.set_title("Joint pdf and marginals")
plt.show()
In [7]:
# Conditionals p(x|y) for 11 values of y, each rescaled to peak at one, against
# the rescaled marginal p(x). The color cycle is set on this axes only, so no
# global rcParams are modified for later cells.
fig, ax = plt.subplots()
ax.set_prop_cycle(cycler('color', [plt.cm.viridis(i) for i in np.linspace(0.1,0.9,11)]))
ax.set_ylim(0.,1.05)
ax.set_xlabel("$x$",size=15)
ax.set_ylabel("pdf (rescaled to peak at 1)")
for this_y in np.linspace(ymin,ymax,11):
    # format to 2 decimals: str() of linspace values shows float noise (e.g. 0.6000000000000001)
    ax.plot(x,target_conditional_x_given_yN(x,y=this_y),label='y={:.2f}'.format(this_y),linewidth=2.)
ax.plot(x,target_marginal_xN(x),linewidth=2.,color='red',linestyle='--',label='marginal')
ax.set_title("Conditionals given $y$ and marginal")
ax.legend(frameon=False,loc='center left',bbox_to_anchor=(1, 0.5))
plt.show()
In [8]:
# Conditionals p(y|x) for 11 values of x, each rescaled to peak at one, against
# the rescaled marginal p(y). The color cycle is set on this axes only, so no
# global rcParams are modified for later cells.
fig, ax = plt.subplots()
ax.set_prop_cycle(cycler('color', [plt.cm.viridis(i) for i in np.linspace(0.1,0.9,11)]))
ax.set_ylim(0.,1.05)
ax.set_xlabel("$y$",size=15)
ax.set_ylabel("pdf (rescaled to peak at 1)")
for this_x in np.linspace(xmin,xmax,11):
    # format to 2 decimals: str() of linspace values shows float noise (e.g. 0.6000000000000001)
    ax.plot(y,target_conditional_y_given_xN(y,x=this_x),label='x={:.2f}'.format(this_x),linewidth=2.)
ax.plot(y,target_marginal_yN(y),linewidth=2.,color='red',linestyle='--',label='marginal')
ax.set_title("Conditionals given $x$ and marginal")
ax.legend(frameon=False,loc='center left',bbox_to_anchor=(1, 0.5))
plt.show()

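Gibbs sampling

The sampler alternates between drawing $x$ from $p(x \mid y)$ and $y$ from $p(y \mid x)$, starting from $(x_\mathrm{start}, y_\mathrm{start})$. The state of the chain is recorded after every single-coordinate update, so consecutive samples differ in only one coordinate.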

In [9]:
def Gibbs_sampler(target_conditional_given_x,target_conditional_given_y,Nsamp,x_start,y_start):
    """Run a two-block Gibbs sampler and return exactly ``Nsamp`` recorded states.

    The chain alternates between drawing x from p(x|y) and y from p(y|x),
    starting with an x-update. The state (x, y) is recorded after every
    single-coordinate update, so consecutive samples differ in only one
    coordinate. The first recorded state is the starting point.

    Parameters
    ----------
    target_conditional_given_x : callable
        ``target_conditional_given_x(x)`` returns an object whose ``rvs()``
        draws y from p(y|x) (e.g. a frozen scipy.stats distribution).
    target_conditional_given_y : callable
        ``target_conditional_given_y(y)`` returns an object whose ``rvs()``
        draws x from p(x|y).
    Nsamp : int
        Number of recorded states to return, including the starting point.
    x_start, y_start : float
        Starting point of the chain.

    Returns
    -------
    samples_x, samples_y : list
        Lists of length ``Nsamp`` with the x and y coordinates of the chain.

    Raises
    ------
    ValueError
        If ``Nsamp`` is smaller than 1.
    """
    if Nsamp < 1:
        raise ValueError("Nsamp must be at least 1, got {}".format(Nsamp))
    x=x_start
    y=y_start
    samples_x=[x]
    samples_y=[y]
    update_x=True  # the chain starts with an x-update, then alternates
    # One single-coordinate update per recorded state, so the output has exactly
    # Nsamp entries for both even and odd Nsamp. The previous paired-update loop
    # returned Nsamp+1 samples when Nsamp was odd.
    while len(samples_x)<Nsamp:
        if update_x:
            x=target_conditional_given_y(y).rvs()
        else:
            y=target_conditional_given_x(x).rvs()
        samples_x.append(x)
        samples_y.append(y)
        update_x=not update_x
    return samples_x, samples_y
In [10]:
# Sampler settings: number of recorded states and starting point of the chain
Nsamp=2000
x_start=1.8
y_start=-0.8
# Seed NumPy's global random state, which scipy.stats .rvs() draws from when no
# random_state is given, so that re-running the notebook reproduces the same chain.
np.random.seed(2017)
samples_x, samples_y=Gibbs_sampler(target_conditional_given_x,target_conditional_given_y,Nsamp,x_start,y_start)
In [11]:
# Trace plots of x and y against position in the chain. Positions come from the
# actual number of samples rather than from Nsamp, so the plot cannot fail on a
# length mismatch.
steps = np.arange(len(samples_x))
f, (ax1, ax2) = plt.subplots(1, 2, figsize=(10,5))
ax1.set_xlim(0,len(steps))
ax1.set_xlabel("step",size=15)
ax1.set_ylabel("$x$",size=15)
ax1.scatter(steps,samples_x,color='darkorchid',marker='.',linewidth=1.5)
ax2.set_xlim(0,len(steps))
ax2.set_xlabel("step",size=15)
ax2.set_ylabel("$y$",size=15)
ax2.scatter(steps,samples_y,color='seagreen',marker='.',linewidth=1.5)
plt.show()
In [12]:
nullfmt = NullFormatter()         # no labels

# definitions for the axes, in figure-fraction units: the joint panel is sized
# in proportion to the plotted data ranges, the marginal panels are 1 unit thick
# NOTE(review): these rectangles extend well beyond the nominal 2x2 figure; the
# layout appears to rely on the inline backend's tight bounding box — confirm.
left, width = 0., xmax-xmin
bottom, height = 0., ymax-ymin
left_h = left + width + 0.1
bottom_h = bottom + height + 0.1

rect_pdf = [left, bottom, width, height]
rect_pdfx = [left, bottom_h, width, 1.]
rect_pdfy = [left_h, bottom, 1., height]

# start a new figure (no fixed figure number, so an existing figure is never reused)
plt.figure(figsize=(2, 2))

ax = plt.axes(rect_pdf)
axpdfx = plt.axes(rect_pdfx)
axpdfy = plt.axes(rect_pdfy)

# no tick labels on the axes the marginal panels share with the joint panel
axpdfx.xaxis.set_major_formatter(nullfmt)
axpdfy.yaxis.set_major_formatter(nullfmt)

# the joint pdf (image + contours) with the Gibbs samples scattered on top
# (contour reads the coordinates from X, Y, so it takes no extent/origin)
ax.set_xlim(xmin,xmax)
ax.set_ylim(ymin,ymax)
ax.set_xlabel("$x$",size=15)
ax.set_ylabel("$y$",size=15)
ax.imshow(Z,extent=(xmin,xmax,ymin,ymax),origin='lower',cmap='YlGnBu')
ax.contour(X,Y,Z,linewidths=2,cmap='viridis_r')
ax.scatter(samples_x,samples_y,marker='.',color='black')

# normalized marginal of x vs. a density-normalized histogram of the x samples
# (density=True replaces the `normed` keyword removed from matplotlib's hist)
axpdfx.set_xlim(xmin,xmax)
axpdfx.set_ylim(0,1.4)
axpdfx.plot(x,target_marginal_x(x)/Nx,linewidth=2.,color='darkorchid')
axpdfx.hist(samples_x,40,density=True,histtype='step',color='darkorchid',linewidth=1.5)

# normalized marginal of y vs. a density-normalized histogram of the y samples (drawn sideways)
axpdfy.set_xlim(0,0.8)
axpdfy.set_ylim(ymin,ymax)
axpdfy.plot(target_marginal_y(y)/Ny,y,linewidth=2.,color='seagreen')
axpdfy.hist(samples_y,40,density=True,histtype='step',color='seagreen',linewidth=1.5,orientation='horizontal')

axpdfx.set_title("Gibbs sampling")
plt.show()