MCMC: Slice sampling

Florent Leclercq,
Imperial Centre for Inference and Cosmology, Imperial College London,
[email protected]

In [1]:
import numpy as np
from scipy.integrate import quad
from scipy.stats import norm, cauchy
from math import pi
from matplotlib import pyplot as plt
from matplotlib.ticker import NullFormatter
np.random.seed(123456)
%matplotlib inline
plt.rcParams.update({'lines.linewidth': 2})

Slice sampling tutorial

In [2]:
gaussian=norm(0,1)
a=-5.
b=5.
x_arr=np.linspace(a,b,100)

1- Drawing $y$

Starting from an initial $x$, we draw $y$ uniformly in $[0,f(x)]$.

In [3]:
x=1.3
f_of_x=gaussian.pdf(x)
y=np.random.uniform(0.,f_of_x)
In [4]:
plt.figure(figsize=(10,6))
plt.xlim([a,b])
plt.plot(x_arr,gaussian.pdf(x_arr))
plt.plot([x,x],[0.,f_of_x],color='black')
plt.plot([a,b],[f_of_x,f_of_x],color='black')
plt.plot([a,b],[y,y],color='black',linestyle='--')
plt.fill_between([a,b],0.,f_of_x,facecolor='grey',alpha=0.3, linewidth=0.)
plt.ylim(bottom=0.)
plt.show()

2- Drawing a new $x$

We draw $x$ uniformly in the "slice" where $f(x)\geq y$. In the case of the Gaussian distribution $G(0,1)$, this is drawing $x$ uniformly in $[-x_0,x_0]$ where $x_0=\sqrt{-2 \ln(y \sqrt{2\pi})}$. This proposed sample is accepted or rejected according to the usual Metropolis-Hastings rule.

In [5]:
x0=np.sqrt(-2*np.log(y*np.sqrt(2*pi)))
x_new=np.random.uniform(-x0,x0)
In [6]:
x_slice_arr=np.linspace(-x0,x0,100)
plt.figure(figsize=(10,6))
plt.xlim([-5.,5.])
plt.plot(x_arr,gaussian.pdf(x_arr))
plt.plot([-5.,5.],[y,y],color='black')
plt.plot([x0,x0],[0.,y],color='black')
plt.plot([-x0,-x0],[0.,y],color='black')
plt.plot([x_new,x_new],[0.,y],color='black',linestyle='--')
plt.fill_between(x_slice_arr,0.,y,facecolor='grey',alpha=0.3, linewidth=0.)
plt.ylim(bottom=0.)
plt.show()

3- Iterate!

In [7]:
def slice_sampler_gaussian(Nsteps,x_start):
    Naccepted=0
    samples=np.zeros(Nsteps+1)
    samples[0]=x_start
    x=x_start
    for i in range(Nsteps):
        f_of_x=gaussian.pdf(x)
        y=np.random.uniform(0.,f_of_x)
        x0=np.sqrt(-2*np.log(y*np.sqrt(2*pi)))
        x_p=np.random.uniform(-x0,x0)
        a = min(1, gaussian.pdf(x_p)/gaussian.pdf(x))
        u = np.random.uniform()
        if u < a:
            Naccepted+=1
            x=x_p
        samples[i+1]=x
    return Naccepted,samples
In [8]:
x_start=2.
Nsteps=99
Naccepted,samples=slice_sampler_gaussian(Nsteps,x_start)
In [9]:
fraction_accepted=float(Naccepted)/Nsteps
fraction_accepted
Out[9]:
0.8383838383838383
In [10]:
plt.figure(figsize=(10,6))
plt.xlim([-5.,5.])
plt.plot(x_arr,gaussian.pdf(x_arr))
markerline, stemlines, baseline = plt.stem(samples,gaussian.pdf(samples),linefmt='-k',markerfmt='k.')
baseline.set_visible(False)
plt.title("Slice sampling")
plt.ylim(bottom=0.)
plt.show()

Slice sampling in practice and multimodal distributions

In [11]:
def target_pdf(x):
    return cauchy(scale=0.5,loc=0.8).pdf(x)+0.5*norm(2.8,0.3).pdf(x)
target_pdf=np.vectorize(target_pdf)
In [12]:
a=-2.
b=5.
x_arr=np.linspace(a,b,200)
f_arr=target_pdf(x_arr)
In [13]:
plt.figure(figsize=(10,6))
plt.xlim([a,b])
plt.plot(x_arr,f_arr,color='C2')
plt.title("Target pdf")
plt.ylim(bottom=0.)
plt.show()
In [14]:
def slice_sampler(target_pdf,Nsteps,x_start,x_width):
    Naccepted=0
    samples=np.zeros(Nsteps+1)
    samples[0]=x_start
    x=x_start
    for i in range(Nsteps):
        y=np.random.uniform(0, target_pdf(x))
        lb=x
        rb=x
        # we build the approximate slice by expanding around the current x
        while y<target_pdf(lb):
            lb-=x_width
        while y<target_pdf(rb):
            rb+=x_width
        # we draw a new x
        x_p=np.random.uniform(lb,rb)
        if target_pdf(x_p)>y:
            # x_p was in the slice, we keep it as a proposed sample
            # slice sampling satisfies detailed balance
            a = min(1, target_pdf(x_p)/target_pdf(x))
            u = np.random.uniform()
            if u < a:
                Naccepted+=1
                x=x_p
            samples[i+1]=x
        else:
            # x was not in the slice, we adjust the boundaries of the approximate slice
            if np.abs(x-lb)<np.abs(x-rb):
                lb = x
            else:
                rb = x
    return Naccepted,samples
In [15]:
x_width=0.5 # a good width to explore the multimodal distribution
Nsteps=150
x_start=3.2
Naccepted,samples=slice_sampler(target_pdf,Nsteps,x_start,x_width)
In [16]:
fraction_accepted=float(Naccepted)/Nsteps
fraction_accepted
Out[16]:
0.5733333333333334
In [17]:
plt.figure(figsize=(10,6))
plt.xlim([a,b])
plt.xlim(a,b)
plt.plot(x_arr,f_arr,color='C2')
markerline, stemlines, baseline = plt.stem(samples,target_pdf(samples),linefmt='-k',markerfmt='k.')
baseline.set_visible(False)
plt.title("Slice sampling")
plt.ylim(bottom=0.)
plt.show()

Two-dimensional slice sampling

The target pdf

In [18]:
def target_joint(x,y):
    return x*x * np.exp(-x*y*y -y*y +2.*y -4.*x)
def target_marginal_x(x):
    return x*x/np.sqrt(x+1) * np.exp(-4.*x -1./(x+1.))
def target_marginal_y(y):
    return np.exp(-y*y+2.*y) / (y*y+4.)**3
In [19]:
# Normalization of the marginals
Nx=quad(target_marginal_x,0.0001,100.)[0]
Ny=quad(target_marginal_y,-100.,100.)[0]
In [20]:
xmin=0.
xmax=2.
ymin=-1.
ymax=2.5
x=np.linspace(xmin,xmax,1000)
y=np.linspace(ymin,ymax,1000)
X,Y=np.meshgrid(x,y)
Z=target_joint(X,Y)
In [21]:
nullfmt = NullFormatter()         # no labels

# definitions for the axes
left, width = 0., xmax-xmin
bottom, height = 0., ymax-ymin
left_h = left + width + 0.1
bottom_h = bottom + height + 0.1

rect_pdf = [left, bottom, width, height]
rect_pdfx = [left, bottom_h, width, 1.]
rect_pdfy = [left_h, bottom, 1., height]

# start with a rectangular Figure
plt.figure(1, figsize=(2, 2))

ax = plt.axes(rect_pdf)
axpdfx = plt.axes(rect_pdfx)
axpdfy = plt.axes(rect_pdfy)

# no labels
axpdfx.xaxis.set_major_formatter(nullfmt)
axpdfy.yaxis.set_major_formatter(nullfmt)

# the scatter plot:
ax.set_xlim(xmin,xmax)
ax.set_ylim(ymin,ymax)
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.imshow(Z,extent=(xmin,xmax,ymin,ymax),origin='lower',cmap='YlGnBu')
ax.contour(X,Y,Z,extent=(xmin,xmax,ymin,ymax),origin='lower',cmap='viridis_r')

# the histograms:
axpdfx.set_xlim(xmin,xmax)
axpdfx.set_ylim([0,1.4])
axpdfx.plot(x,target_marginal_x(x)/Nx,color='C4')

axpdfy.set_xlim([0,0.8])
axpdfy.set_ylim(ymin,ymax)
axpdfy.plot(target_marginal_y(y)/Ny,y,color='C2')

axpdfx.set_title("Joint pdf and marginals")
plt.show()

The sampler

In [22]:
def slice_sampler_2D(target_pdf,Nsteps,x_start,xmin,xmax,y_start,ymin,ymax):
    Naccepted=0
    samples_x=np.zeros(Nsteps+1)
    samples_y=np.zeros(Nsteps+1)
    samples_x[0]=x_start
    samples_y[0]=y_start
    x=x_start
    y=y_start
    for i in range(Nsteps):
        z=np.random.uniform(0, target_pdf(x,y))
        
        # we draw a new (x,y) uniformly in the rectangle ([xmin,xmax],[ymin,ymax])
        # this may be inefficient. alternatively, one could adaptively define a 2D rectangle
        # (cf. hyperrectangle slice sampling)
        x_p=np.random.uniform(xmin,xmax)
        y_p=np.random.uniform(ymin,ymax)
        # we keep only points that are in the slice
        while target_pdf(x_p,y_p)<z:
            x_p=np.random.uniform(xmin,xmax)
            y_p=np.random.uniform(ymin,ymax)

        # (x_p,y_p) was in the slice, we keep it as a proposed sample
        # slice sampling satisfies detailed balance
        a = min(1, target_pdf(x_p,y_p)/target_pdf(x,y))
        u = np.random.uniform()
        if u < a:
            Naccepted+=1
            x=x_p
            y=y_p
        samples_x[i+1]=x
        samples_y[i+1]=y
    return Naccepted,samples_x,samples_y
In [23]:
Nsteps=2000
x_start=1.8
y_start=1.4
Naccepted,samples_x,samples_y=slice_sampler_2D(target_joint,Nsteps,x_start,xmin,xmax,y_start,ymin,ymax)
In [24]:
fraction_accepted=float(Naccepted)/Nsteps
fraction_accepted
Out[24]:
0.781
In [25]:
f, (ax1, ax2) = plt.subplots(1, 2, figsize=(10,5))
f.subplots_adjust(wspace=0.2)
ax1.set_xlim([0,Nsteps])
ax1.set_xlabel("sample number")
ax1.set_ylabel("$x$")
ax1.scatter(np.arange(Nsteps+1),samples_x,color='C4',marker='.')
ax2.set_xlim([0,Nsteps])
ax2.set_xlabel("sample number")
ax2.set_ylabel("$y$")
ax2.scatter(np.arange(Nsteps+1),samples_y,color='C2',marker='.')
plt.show()
In [26]:
nullfmt = NullFormatter()         # no labels

# definitions for the axes
left, width = 0., xmax-xmin
bottom, height = 0., ymax-ymin
left_h = left + width + 0.1
bottom_h = bottom + height + 0.1

rect_pdf = [left, bottom, width, height]
rect_pdfx = [left, bottom_h, width, 1.]
rect_pdfy = [left_h, bottom, 1., height]

# start with a rectangular Figure
plt.figure(1, figsize=(2, 2))

ax = plt.axes(rect_pdf)
axpdfx = plt.axes(rect_pdfx)
axpdfy = plt.axes(rect_pdfy)

# no labels
axpdfx.xaxis.set_major_formatter(nullfmt)
axpdfy.yaxis.set_major_formatter(nullfmt)

# the scatter plot:
ax.set_xlim(xmin,xmax)
ax.set_ylim(ymin,ymax)
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.imshow(Z,extent=(xmin,xmax,ymin,ymax),origin='lower',cmap='YlGnBu')
ax.contour(X,Y,Z,extent=(xmin,xmax,ymin,ymax),origin='lower',cmap='viridis_r')
ax.scatter(samples_x,samples_y,marker='.',color='black')

# the histograms:
axpdfx.set_xlim(xmin,xmax)
axpdfx.set_ylim(0,1.4)
axpdfx.plot(x,target_marginal_x(x)/Nx,color='C4')
axpdfx.hist(samples_x,40,density=True,histtype='step',color='C4',linewidth=2.)

axpdfy.set_xlim(0,0.8)
axpdfy.set_ylim(ymin,ymax)
axpdfy.plot(target_marginal_y(y)/Ny,y,color='C2')
axpdfy.hist(samples_y,40,density=True,histtype='step',color='C2',linewidth=2.,orientation='horizontal')

axpdfx.set_title("Slice sampling")
plt.show()