# MCMC: Gibbs sampling

Florent Leclercq,
Institute of Cosmology and Gravitation, University of Portsmouth,
[email protected]

In [1]:
import numpy as np
from scipy.integrate import quad
from scipy.stats import norm, gamma
from matplotlib import mlab as mlab
from matplotlib import pyplot as plt
from matplotlib import gridspec
from matplotlib.ticker import NullFormatter
from cycler import cycler
%matplotlib inline


## The target pdf

In [2]:
def target_joint(x,y):
    """Unnormalized joint target density, evaluated elementwise.

    Returns x^2 * exp(-x*y^2 - y^2 + 2*y - 4*x). Works on scalars or on
    numpy arrays of compatible shapes (e.g. meshgrid output).
    """
    exponent = -x*y*y -y*y +2.*y -4.*x
    prefactor = x*x
    return prefactor * np.exp(exponent)
def target_marginal_x(x):
    """Unnormalized marginal density of x (the joint integrated over y).

    The y-dependent part of the joint is exp(-(x+1)*y^2 + 2*y). Completing
    the square gives
        -(x+1)*(y - 1/(x+1))^2 + 1/(x+1),
    so the Gaussian integral over y yields sqrt(pi/(x+1)) * exp(+1/(x+1)).
    Dropping the constant sqrt(pi), the marginal is
        x^2 / sqrt(x+1) * exp(-4*x + 1/(x+1)).

    Note the positive sign of the 1/(x+1) term in the exponent.
    """
    return x*x/np.sqrt(x+1.) * np.exp(-4.*x + 1./(x+1.))
def target_marginal_y(y):
    """Unnormalized marginal density of y: exp(-y^2 + 2*y) / (y^2 + 4)^3."""
    gaussian_part = np.exp(-y*y+2.*y)
    power_part = (y*y+4.)**3
    return gaussian_part / power_part
def target_conditional_given_y(y):
    """Frozen scipy Gamma distribution used as the conditional of x given y.

    Shape parameter a=3 and scale 1/(y^2 + 4).
    """
    rate = y*y+4.
    return gamma(a=3., scale=1./rate)
def target_conditional_x_given_y(x,y):
    """Density of the conditional of x given y, evaluated at x."""
    conditional = target_conditional_given_y(y)
    return conditional.pdf(x)
def target_conditional_given_x(x):
    """Frozen scipy normal distribution used as the conditional of y given x.

    Mean 1/(x+1) and standard deviation sqrt(1/(2*(x+1))).
    """
    mean = 1./(x+1.)
    std = np.sqrt(1./(2.*(x+1.)))
    return norm(loc=mean, scale=std)
def target_conditional_y_given_x(y,x):
    """Density of the conditional of y given x, evaluated at y."""
    conditional = target_conditional_given_x(x)
    return conditional.pdf(y)

In [3]:
# Normalization of the marginals
# Nx and Ny are the integrals of the unnormalized marginals; the plotting
# cells divide by them so the curves can be compared with density histograms.
# x is integrated over [0, inf) (its conditional is a Gamma distribution),
# y over the whole real line (its conditional is a normal distribution).
Nx = quad(target_marginal_x, 0., np.inf)[0]
Ny = quad(target_marginal_y, -np.inf, np.inf)[0]

In [4]:
# Marginals and conditionals peaking at one
def target_marginal_xN(x):
    """Marginal of x rescaled so that its maximum over the supplied x is one."""
    values = target_marginal_x(x)
    return values/values.max()
def target_marginal_yN(y):
    """Marginal of y rescaled so that its maximum over the supplied y is one."""
    values = target_marginal_y(y)
    return values/values.max()
def target_conditional_x_given_yN(x,y):
    """Conditional of x given y, rescaled to peak at one over the supplied x."""
    values = target_conditional_x_given_y(x,y)
    return values/values.max()
def target_conditional_y_given_xN(y,x):
    """Conditional of y given x, rescaled to peak at one over the supplied y."""
    values = target_conditional_y_given_x(y,x)
    return values/values.max()

In [5]:
# Plotting window in the (x, y) plane
xmin, xmax = 0., 2.
ymin, ymax = -1., 2.5

# 1000-point axes, and the joint density tabulated on the resulting 2D grid
x = np.linspace(xmin, xmax, 1000)
y = np.linspace(ymin, ymax, 1000)
X, Y = np.meshgrid(x, y)
Z = target_joint(X, Y)

In [6]:
# Figure: joint pdf (central panel) with the x-marginal above and the
# y-marginal to the right.

# Panel rectangles, in figure-fraction units: the central panel spans the
# plotted (x, y) window, the marginal panels sit 0.1 away and are 1. thick.
left = 0.
bottom = 0.
width = xmax-xmin
height = ymax-ymin
left_h = left + width + 0.1
bottom_h = bottom + height + 0.1

rect_pdf = [left, bottom, width, height]
rect_pdfx = [left, bottom_h, width, 1.]
rect_pdfy = [left_h, bottom, 1., height]

plt.figure(1, figsize=(2, 2))

ax = plt.axes(rect_pdf)
axpdfx = plt.axes(rect_pdfx)
axpdfy = plt.axes(rect_pdfy)

# Hide tick labels on the axes shared with the central panel
nullfmt = NullFormatter()
axpdfx.xaxis.set_major_formatter(nullfmt)
axpdfy.yaxis.set_major_formatter(nullfmt)

# Central panel: joint pdf as an image with contours on top
ax.set_xlabel("$x$",size=15)
ax.set_ylabel("$y$",size=15)
ax.set_xlim(xmin,xmax)
ax.set_ylim(ymin,ymax)
ax.imshow(Z,extent=(xmin,xmax,ymin,ymax),origin='lower',cmap='YlGnBu')
ax.contour(X,Y,Z,extent=(xmin,xmax,ymin,ymax),origin='lower',linewidths=2,cmap='viridis_r')

# Top panel: marginal of x, divided by its normalization Nx
axpdfx.set_title("Joint pdf and marginals")
axpdfx.set_xlim(xmin,xmax)
axpdfx.set_ylim(0,1.4)
axpdfx.plot(x,target_marginal_x(x)/Nx,linewidth=2.,color='darkorchid')

# Right panel: marginal of y, divided by its normalization Ny (axes swapped)
axpdfy.set_xlim(0,0.8)
axpdfy.set_ylim(ymin,ymax)
axpdfy.plot(target_marginal_y(y)/Ny,y,linewidth=2.,color='seagreen')

plt.show()

In [7]:
# Conditionals of x for 11 values of y spanning [ymin, ymax], each rescaled
# to peak at one, overlaid with the (rescaled) marginal of x.
plt.rc('axes', prop_cycle=(cycler('color', [plt.cm.viridis(i) for i in np.linspace(0.1,0.9,11)])))
plt.ylim(0.,1.05)
plt.xlabel("$x$",size=15)
for this_y in np.linspace(ymin,ymax,11):
    # Fixed-precision label: str() on linspace values exposes floating-point
    # noise (e.g. 0.6000000000000001) in the legend.
    plt.plot(x,target_conditional_x_given_yN(x,y=this_y),label='y={:.2f}'.format(this_y),linewidth=2.)
plt.plot(x,target_marginal_xN(x),linewidth=2.,color='red',linestyle='--',label='marginal')
plt.title("Conditionals given $y$ and marginal")
plt.legend(frameon=False,loc='center left',bbox_to_anchor=(1, 0.5))
plt.show()

In [8]:
# Conditionals of y for 11 values of x spanning [xmin, xmax], each rescaled
# to peak at one, overlaid with the (rescaled) marginal of y.
plt.rc('axes', prop_cycle=(cycler('color', [plt.cm.viridis(i) for i in np.linspace(0.1,0.9,11)])))
plt.ylim(0.,1.05)
plt.xlabel("$y$",size=15)
for this_x in np.linspace(xmin,xmax,11):
    # Fixed-precision label: str() on linspace values exposes floating-point
    # noise (e.g. 0.6000000000000001) in the legend.
    plt.plot(y,target_conditional_y_given_xN(y,x=this_x),label='x={:.2f}'.format(this_x),linewidth=2.)
plt.plot(y,target_marginal_yN(y),linewidth=2.,color='red',linestyle='--',label='marginal')
plt.title("Conditionals given $x$ and marginal")
plt.legend(frameon=False,loc='center left',bbox_to_anchor=(1, 0.5))
plt.show()


## Gibbs sampling

In [9]:
def Gibbs_sampler(target_conditional_given_x,target_conditional_given_y,Nsamp,x_start,y_start):
    """Draw a chain of exactly Nsamp (x, y) states by Gibbs sampling.

    The chain starts at (x_start, y_start) and alternates single-coordinate
    updates, beginning with x: x is drawn from target_conditional_given_y(y),
    then y from target_conditional_given_x(x), and so on. The state is
    recorded after every single-coordinate update, so consecutive samples
    differ in at most one coordinate.

    Parameters
    ----------
    target_conditional_given_x : callable
        Maps x to an object with an ``rvs()`` method that draws y given x.
    target_conditional_given_y : callable
        Maps y to an object with an ``rvs()`` method that draws x given y.
    Nsamp : int
        Total number of recorded states, including the starting point.
    x_start, y_start : float
        Initial state of the chain.

    Returns
    -------
    samples_x, samples_y : list
        Two lists of length Nsamp.

    Raises
    ------
    ValueError
        If Nsamp is smaller than 1.
    """
    if Nsamp < 1:
        raise ValueError("Nsamp must be at least 1")
    x=x_start
    y=y_start
    samples_x=[x]
    samples_y=[y]
    update_x=True
    # A single loop on the exact target length: the previous two-updates-per-
    # iteration loop plus trailing x update returned Nsamp+1 states whenever
    # Nsamp was odd. For even Nsamp the sequence of updates is unchanged.
    while len(samples_x)<Nsamp:
        if update_x:
            # update x given y
            x=target_conditional_given_y(y).rvs()
        else:
            # update y given x
            y=target_conditional_given_x(x).rvs()
        samples_x.append(x)
        samples_y.append(y)
        update_x=not update_x
    return samples_x, samples_y

In [10]:
# Chain length and starting point
Nsamp=2000
x_start=1.8
y_start=-0.8
# Seed numpy's global random state, which scipy's rvs() draws from by default,
# so that the chain and the figures below are reproducible on a fresh kernel.
seed=123456
np.random.seed(seed)
samples_x, samples_y=Gibbs_sampler(target_conditional_given_x,target_conditional_given_y,Nsamp,x_start,y_start)

In [11]:
# Trace plots of the two coordinates along the chain
f, (ax1, ax2) = plt.subplots(1, 2, figsize=(10,5))
# Index by the actual chain length rather than Nsamp, so the plot does not
# fail with a shape mismatch if the sampler returned a different count.
sample_index = np.arange(len(samples_x))
ax1.set_xlim(0,Nsamp)
ax1.set_xlabel("sample index",size=15)
ax1.set_ylabel("$x$",size=15)
ax1.scatter(sample_index,samples_x,color='darkorchid',marker='.',linewidth=1.5)
ax2.set_xlim(0,Nsamp)
ax2.set_xlabel("sample index",size=15)
ax2.set_ylabel("$y$",size=15)
ax2.scatter(sample_index,samples_y,color='seagreen',marker='.',linewidth=1.5)
plt.show()

In [12]:
# Figure: Gibbs samples scattered over the joint pdf, with sample histograms
# compared to the normalized marginals in the side panels.
nullfmt = NullFormatter()         # no labels

# definitions for the axes
left, width = 0., xmax-xmin
bottom, height = 0., ymax-ymin
left_h = left + width + 0.1
bottom_h = bottom + height + 0.1

rect_pdf = [left, bottom, width, height]
rect_pdfx = [left, bottom_h, width, 1.]
rect_pdfy = [left_h, bottom, 1., height]

plt.figure(1, figsize=(2, 2))

ax = plt.axes(rect_pdf)
axpdfx = plt.axes(rect_pdfx)
axpdfy = plt.axes(rect_pdfy)

# no labels
axpdfx.xaxis.set_major_formatter(nullfmt)
axpdfy.yaxis.set_major_formatter(nullfmt)

# the scatter plot:
ax.set_xlim(xmin,xmax)
ax.set_ylim(ymin,ymax)
ax.set_xlabel("$x$",size=15)
ax.set_ylabel("$y$",size=15)
ax.imshow(Z,extent=(xmin,xmax,ymin,ymax),origin='lower',cmap='YlGnBu')
ax.contour(X,Y,Z,extent=(xmin,xmax,ymin,ymax),origin='lower',linewidths=2,cmap='viridis_r')
ax.scatter(samples_x,samples_y,marker='.',color='black')

# x-marginal: analytic curve and density-normalized histogram of the samples.
# `density=True` replaces the `normed=True` keyword, which newer matplotlib
# releases no longer accept.
axpdfx.set_xlim(xmin,xmax)
axpdfx.set_ylim(0,1.4)
axpdfx.plot(x,target_marginal_x(x)/Nx,linewidth=2.,color='darkorchid')
axpdfx.hist(samples_x,40,density=True,histtype='step',color='darkorchid',linewidth=1.5)

# y-marginal: same comparison, drawn horizontally
axpdfy.set_xlim(0,0.8)
axpdfy.set_ylim(ymin,ymax)
axpdfy.plot(target_marginal_y(y)/Ny,y,linewidth=2.,color='seagreen')
axpdfy.hist(samples_y,40,density=True,histtype='step',color='seagreen',linewidth=1.5,orientation='horizontal')

axpdfx.set_title("Gibbs sampling")
plt.show()