This Jupyter notebook accompanies the paper Manifold MCMC methods for Bayesian inference in a wide class of diffusion models, providing a complete runnable example of applying the method described in the paper to perform inference in an example hypoelliptic diffusion model.
We first check if the notebook is being run on Binder or Google Colab and if so install the `sde` package and the other dependencies using `pip`.
import os

ON_BINDER = 'BINDER_SERVICE_HOST' in os.environ

try:
    import google.colab
    ON_COLAB = True
except ImportError:
    ON_COLAB = False

if ON_COLAB or ON_BINDER:
    !pip install git+https://github.com/thiery-lab/manifold-mcmc-for-diffusions.git#egg=sde[notebook]
We now import the modules we will use to simulate from the model and perform inference
import mici
import sde.mici_extensions as mici_extensions
import symnum
import symnum.diffops.symbolic as diffops
import symnum.numpy as snp
import numpy as onp
import jax
from jax import lax, config, numpy as jnp
import matplotlib.pyplot as plt
import arviz
import corner
config.update('jax_enable_x64', True)
config.update('jax_platform_name', 'cpu')
We also set a dictionary of style parameters to use with Matplotlib plots
plot_style = {
    'mathtext.fontset': 'cm',
    'font.family': 'serif',
    'axes.titlesize': 10,
    'axes.labelsize': 10,
    'xtick.labelsize': 6,
    'ytick.labelsize': 6,
    'legend.fontsize': 8,
    'legend.frameon': False,
    'axes.linewidth': 0.5,
    'lines.linewidth': 0.5,
    'axes.labelpad': 2.,
    'figure.dpi': 150,
}
As an illustration we will consider the hypoelliptic diffusion defined by the system of stochastic differential equations (SDEs)
$$\underbrace{\begin{bmatrix} \mathrm{d}x_0(\tau) \\ \mathrm{d}x_1(\tau) \end{bmatrix}}_{\mathrm{d}x(\tau)} = \underbrace{\begin{bmatrix} \frac{1}{\epsilon}\left(x_0(\tau) - x_0(\tau)^3 - x_1(\tau)\right) \\ \gamma x_0(\tau) - x_1(\tau) + \beta \end{bmatrix}}_{a(x(\tau),\,z)} \mathrm{d}\tau + \underbrace{\begin{bmatrix} 0 \\ \sigma \end{bmatrix}}_{B(x(\tau),\,z)} \mathrm{d}w(\tau)$$

with $x$ the $\mathcal{X} = \mathbb{R}^2$-valued diffusion process of interest, $w$ a univariate Wiener process and $z = [\sigma; \epsilon; \gamma; \beta] \in \mathcal{Z} = \mathbb{R}_{>0} \times \mathbb{R}_{>0} \times \mathbb{R} \times \mathbb{R}$ the model parameters.
This SDE system corresponds to a stochastic variant of the FitzHugh-Nagumo model, a simplified description of action potential generation within a neuronal axon.
We will use SymNum to symbolically define the drift $a$ and diffusion coefficient $B$ functions for the model in terms of the current state $x$ and parameters $z = [\sigma; \epsilon; \gamma; \beta]$. This will later allow us to automatically construct a function to numerically integrate the SDE system.
dim_x = 2
dim_w = 1
dim_z = 4

def drift_func(x, z):
    σ, ε, γ, β = z
    return snp.array([(x[0] - x[0]**3 - x[1]) / ε, γ * x[0] - x[1] + β])

def diff_coeff(x, z):
    σ, ε, γ, β = z
    return snp.array([[0], [σ]])
As in general exact simulation of the diffusion models of interest will be intractable, we define an approximate discrete time model based on numerical integration of the SDEs. Various numerical schemes for integrating SDE systems are available, with varying convergence properties and implementational complexity - see for example Numerical Solution of Stochastic Differential Equations (Kloeden and Platen, 1992) for an in-depth survey.
The simplest and most common scheme is the Euler-Maruyama method (corresponding to a strong order 0.5 Taylor approximation), which for a small time step $\delta > 0$ can be defined by a forward operator $f_\delta: \mathcal{Z} \times \mathcal{X} \times \mathcal{V} \to \mathcal{X}$

$$f_\delta(z, x, v) = x + \delta\, a(x, z) + \delta^{\frac{1}{2}} B(x, z)\, v$$

where $v \in \mathcal{V}$ is a vector of independent standard normal random variates of dimension equal to that of the Wiener process (here one).
The corresponding single step update can be defined using SymNum as:
def euler_maruyama_step(z, x, v, δ):
    return x + δ * drift_func(x, z) + δ**0.5 * diff_coeff(x, z) @ v
More accurate approximations can be derived by using higher-order terms from the stochastic Taylor expansion of the SDE system. For example, for an SDE model with additive noise, i.e. a diffusion coefficient $B$ which is independent of the state, $B(x, z) = B(z)$, a strong order 1.5 Taylor scheme can be defined by the forward operator
$$f_\delta(z, x, [v_1; v_2]) = x + \delta\, a(x, z) + \frac{\delta^2}{2} \partial_0 a(x, z)\, a(x, z) + \frac{\delta^2}{4} \left[ \operatorname{tr}\!\left(\partial_0^2 a_i(x, z)\, B(z) B(z)^{\mathrm{T}}\right) \right]_{i=0}^{X-1} + \delta^{\frac{1}{2}} B(z)\, v_1 + \frac{\delta^{\frac{3}{2}}}{2} \partial_0 a(x, z)\, B(z) \left(v_1 + v_2 / \sqrt{3}\right)$$

with $\partial_0$ denoting differentiation with respect to the first (state) argument of $a$, and with both $v_1$ and $v_2$ having the dimension of the Wiener process, so that the vector $v = [v_1; v_2]$ is twice the dimension of the Wiener process (therefore of dimension 2 here). This can be implemented using SymNum as follows
def strong_order_1p5_step(z, x, v, δ):
    a = drift_func(x, z)
    da_dx = diffops.jacobian(drift_func)(x, z)
    B = diff_coeff(x, z)
    dim_noise = B.shape[1]
    d2a_dx2_BB = diffops.matrix_hessian_product(drift_func)(x, z)(B @ B.T)
    v_1, v_2 = v[:dim_noise], v[dim_noise:]
    return (
        x + δ * a + (δ**2 / 2) * da_dx @ a + (δ**2 / 4) * d2a_dx2_BB +
        δ**0.5 * B @ v_1 + (δ**1.5 / 2) * da_dx @ B @ (v_1 + v_2 / snp.sqrt(3))
    )
We can use these symbolically defined single step updates to define corresponding numerical functions which take NumPy arrays as inputs using SymNum's `numpify_func` function. As well as the function to be transformed, `numpify_func` requires the shapes (dimensions) of all arguments to be specified. It also optionally allows specifying the module to use for the NumPy API calls; here we use the `jax.numpy` module from JAX as this will allow us to later automatically construct efficient derivative functions for inference. Below we define a forward operator function using the strong order 1.5 step; we can instead use the Euler-Maruyama discretisation simply by setting the `use_euler_maruyama` flag to `True`.
use_euler_maruyama = False

if use_euler_maruyama:
    forward_func = euler_maruyama_step
    dim_v = dim_w
else:
    forward_func = strong_order_1p5_step
    dim_v = 2 * dim_w

forward_func = symnum.numpify_func(
    forward_func, (dim_z,), (dim_x,), (dim_v,), None, numpy_module=jnp
)
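As a quick smoke test of the generated numerical forward operator, we can take a single integrator step with some arbitrary illustrative parameter, state and step size values (the specific numbers below have no particular significance); with a zero noise vector the update is purely deterministic.

z_test = jnp.array([0.5, 0.1, 1.0, 0.5])  # arbitrary illustrative [σ, ϵ, γ, β] values
x_test = jnp.array([-0.5, -0.5])          # arbitrary illustrative state
v_test = jnp.zeros(dim_v)                 # zero noise vector
print(forward_func(z_test, x_test, v_test, 0.01))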
Given a forward operator we can generate (approximate) samples of the state process at a series of discrete times. Here we assume that we use a fixed time increment $\delta > 0$ for all integrator steps and denote $x_s$ the approximation to $x(s\delta)$.

As in the paper we assume the simple case that the diffusion process is discretely observed at $T$ equally spaced times $\tau_t = t\Delta ~\forall t \in 1{:}T$. We use a fixed number of steps $S$ per interobservation interval with $\delta = \frac{\Delta}{S}$, so that the state at the $t$th observation time is $x_{St}$ and the whole sequence of states to be simulated is $x_{1:ST}$.

We assume the $Y = 1$ dimensional observations $y_{1:T}$ correspond to direct observation of the first state component, i.e. $y_t = h_t(x_{St})$ with $h_t(x) = x_0 ~\forall t \in 1{:}T$.
def obs_func(x_seq):
    return x_seq[..., 0:1]
As described in the paper, we use a non-centered parameterisation of the generative model for the parameters $z$, time-discretised diffusion $x_{0:ST}$ and observations $y_{1:T}$.
We use priors $x_0 \sim \mathcal{N}([-0.5; -0.5], \mathbb{I}_2)$, $\log\sigma \sim \mathcal{N}(-1, 0.5^2)$, $\log\epsilon \sim \mathcal{N}(-2, 0.5^2)$, $\gamma \sim \mathcal{N}(1, 0.5^2)$ and $\beta \sim \mathcal{N}(1, 0.5^2)$, which were roughly tuned so that with high probability state sequences $x_{1:ST}$ generated from the prior exhibited stable spiking dynamics and such that $\sigma$ and $\epsilon$ obey their positivity constraints.

We reparameterise the parameters $z$ and initial state $x_0$ in terms of vectors of standard normal variates, respectively $u$ and $v_0$, with the parameter and initial state generator functions then set to $g_z(u) = [\exp(0.5 u_0 - 1); \exp(0.5 u_1 - 2); 0.5 u_2 + 1; 0.5 u_3 + 1]$ and $g_{x_0}(v_0, z) = [v_{0,0} - 0.5; v_{0,1} - 0.5]$, with input distributions $\tilde{\mu} = \mathcal{N}(0, \mathbb{I}_4)$ and $\tilde{\nu} = \mathcal{N}(0, \mathbb{I}_2)$. We can implement these generator functions in Python using JAX NumPy API functions (to allow us to later algorithmically differentiate through the generative model) as follows.
def generate_z(u):
    """Generate parameters from prior given a standard normal vector."""
    return jnp.array([
        jnp.exp(0.5 * u[0] - 1),  # σ
        jnp.exp(0.5 * u[1] - 2),  # ϵ
        0.5 * u[2] + 1,  # γ
        0.5 * u[3] + 1,  # β
    ])

def generate_x_0(z, v_0):
    """Generate an initial state from prior given a standard normal vector."""
    return jnp.array([-0.5, -0.5]) + v_0
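As a quick check of the reparameterisation, a zero input vector should map to the prior medians of the parameters, approximately $(e^{-1}, e^{-2}, 1, 1)$, and to the prior mean $[-0.5; -0.5]$ of the initial state.

z_med = generate_z(jnp.zeros(dim_z))
print(z_med)                                  # ≈ [0.368, 0.135, 1.0, 1.0]
print(generate_x_0(z_med, jnp.zeros(dim_x)))  # [-0.5, -0.5]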
The overall joint generative model for $z$, $x_{0:ST}$ and $y_{1:T}$ in terms of the independent standard normal variates $u$ and $v_{0:ST}$ can then be summarised as

$$\begin{aligned}
u &\sim \mathcal{N}(0, \mathbb{I}_4), \\
v_s &\sim \mathcal{N}(0, \mathbb{I}_2) \quad \forall s \in 0{:}ST, \\
z &= g_z(u), \\
x_0 &= g_{x_0}(v_0, z), \\
x_s &= f_\delta(z, x_{s-1}, v_s) \quad \forall s \in 1{:}ST, \\
y_t &= h_t(x_{St}) \quad \forall t \in 1{:}T.
\end{aligned}$$

We collect all of the latent variables into a $6 + 2ST$ dimensional flat vector $q := [u; v_{0:ST}]$, with all components of $q$ a priori independent and standard normal distributed, i.e. $q \sim \mathcal{N}(0, \mathbb{I}_{6+2ST})$.
A function to sample from the overall generative model given a latent input vector $q$ can be implemented using the `generate_z`, `generate_x_0` and `forward_func` functions and the JAX `scan` operator (a differentiable loop / iterator construct) as follows
def generate_from_model(q, δ, dim_x, dim_z, dim_v, num_steps_per_obs):
    """Generate parameters and state + obs. sequences from model given q."""
    u, v_0, v_r = jnp.split(q, (dim_z, dim_z + dim_x))
    z = generate_z(u)
    x_0 = generate_x_0(z, v_0)
    v_seq = jnp.reshape(v_r, (-1, dim_v))

    # Define integrator step function to scan:
    # first argument is carried-forward state,
    # second argument is input from scanned sequence.
    def step_func(x, v):
        x_n = forward_func(z, x, v, δ)
        # Scan expects to return a tuple with the first element the carried-forward state
        # and second element a slice of the output sequence (here also the state)
        return x_n, x_n

    # Scan step_func over the noise sequence v_seq, initialising the carried state with x_0
    _, x_seq = lax.scan(step_func, x_0, v_seq)
    y_seq = obs_func(x_seq[num_steps_per_obs - 1 :: num_steps_per_obs])
    return x_seq, y_seq, z, x_0
In order to allow us to illustrate performing inference with the model, we first generate simulated observed data from the model itself, with the aim of then inferring the posterior distribution on the 'unknown' latent state given the simulated observations. We use $T = 100$ observation times with interobservation interval $\Delta = 0.5$ and $S = 25$ integrator steps per interobservation interval ($\delta = 0.02$), giving an overall latent dimension of $Q = 5006$ (if running on Binder we instead use $T = 20$ and $S = 10$ to reduce the CPU demand, in which case $Q = 406$).
obs_interval = 0.5

if not ON_BINDER:
    num_obs = 100
    num_steps_per_obs = 25
else:
    num_obs = 20
    num_steps_per_obs = 10

num_steps = num_obs * num_steps_per_obs
dim_q = dim_z + dim_x + num_obs * num_steps_per_obs * dim_v
δ = obs_interval / num_steps_per_obs
We seed a NumPy `RandomState` pseudo-random number generator object and use it to generate a latent input vector $q$ from its standard normal prior.
seed = 20200710
rng = onp.random.RandomState(seed)
q_ref = rng.standard_normal(size=dim_q)
Using the previously defined `generate_from_model` function we now generate simulated state and observation sequences, parameters and the initial state from the model, given the just generated latent input vector.
x_seq_ref, y_seq_ref, z_ref, x_0_ref = generate_from_model(
    q_ref, δ, dim_x, dim_z, dim_v, num_steps_per_obs
)
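As a simple consistency check, the simulated observation sequence should equal the first component of the simulated state sequence at the observation time indices, by the definitions of `obs_func` and `generate_from_model` above.

# Check observations match the first state component at the observation time indices
assert onp.allclose(
    y_seq_ref[:, 0], x_seq_ref[num_steps_per_obs - 1 :: num_steps_per_obs, 0]
)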
We can visualise the simulated state and observation sequences using Matplotlib. Below the blue and orange lines show the time courses of respectively the $x_0$ and $x_1$ state components, with the red crosses indicating the simulated discrete time observations of the $x_0$ component.
with plt.style.context(plot_style):
    fig, axes = plt.subplots(2, sharex=True, figsize=(6, 4), dpi=150)
    t_seq = (1 + onp.arange(num_steps)) * num_obs * obs_interval / num_steps
    obs_indices = (1 + onp.arange(num_obs)) * num_steps_per_obs - 1
    axes[0].plot(t_seq, x_seq_ref[:, 0], lw=0.5, color="C0")
    axes[1].plot(t_seq, x_seq_ref[:, 1], lw=0.5, color="C1")
    axes[0].plot(t_seq[obs_indices], y_seq_ref[:, 0], "x", ms=3, color="red")
    axes[0].set_ylabel(r"$\mathsf{x}_{0}$")
    axes[1].set_ylabel(r"$\mathsf{x}_{1}$")
    for ax in axes:
        ax.set_xlim(0, num_obs * obs_interval)
    _ = axes[1].set_xlabel("Time $\\tau$")
    fig.tight_layout()
To perform inference in the model given our simulated observed data, we use the manifold MCMC method implementations in the package Mici.
The key model-specific object required for inference in Mici is a Hamiltonian system instance. The Hamiltonian system encapsulates the various components of the Hamiltonian function for which the associated Hamiltonian dynamics are used as a proposal generating mechanism in an MCMC method. Mici includes various generic Hamiltonian system classes in the `mici.systems` module corresponding to common cases such as (unconstrained) systems with Euclidean and Riemannian metrics and constrained Hamiltonian systems with a constraint function with a dense Jacobian. Here we instead use a custom system class defined in the `sde.mici_extensions` module which defines a constrained Hamiltonian system corresponding to a generative model for a diffusion as defined above (see Sections 3 and 4 in the paper). In particular, our implementation exploits the sparsity induced in the Jacobian of the constraint function by artificially conditioning on the full state at a set of time points when sampling, as described in Section 5 in the paper. To construct an instance of this system class we pass in the variables defining the model dimensions defined earlier, the simulated observation sequence `y_seq_ref`, the generated `forward_func` implementing the strong order 1.5 numerical integration scheme for the model, the `generate_x_0` and `generate_z` generator functions and the `obs_func` observation function. This class expects the passed functions to be defined using JAX primitives, such as via calls to functions in the `jax.numpy` module, so that it can use JAX's automatic differentiation primitives to automatically construct the required derivative functions.
num_obs_per_subseq = 5  # Number of obs. in each fully conditioned subsequence

system = mici_extensions.ConditionedDiffusionConstrainedSystem(
    obs_interval,
    num_steps_per_obs,
    num_obs_per_subseq,
    y_seq_ref,
    dim_z,
    dim_x,
    dim_v,
    forward_func,
    generate_x_0,
    generate_z,
    obs_func,
    use_gaussian_splitting=True,
)
As well as the Hamiltonian system we also need to define an associated (symplectic) integrator to numerically simulate the associated Hamiltonian dynamics. Here we use the `mici.integrators.ConstrainedLeapfrogIntegrator` class, which corresponds to the constrained symplectic integrator described in Algorithm 1 in the paper (here using the Gaussian-specific Hamiltonian splitting described in Section 4.3.1 of the paper). We specify tolerances on both the norm of the constraint equation (`constraint_tol`) and the successive change in the position (`position_tol`) for the Newton iteration used to solve the non-linear system of constraint equations, and also set a maximum number of iterations `max_iters`. The tolerance for the reversibility check is set to `2 * position_tol`, motivated by the intuition that each of the forward and backward retraction / projection steps is solved to a position tolerance of `position_tol`, so if the errors accumulate linearly the overall error in a reversible step should be less than `2 * position_tol`.
max_iters = 50  # Maximum number of Newton iterations in retraction solver
constraint_tol = 1e-9  # Convergence tolerance in constraint (observation) space
position_tol = 1e-8  # Convergence tolerance in position (latent) space

integrator = mici.integrators.ConstrainedLeapfrogIntegrator(
    system,
    projection_solver=mici_extensions.jitted_solve_projection_onto_manifold_newton,
    reverse_check_tol=2 * position_tol,
    projection_solver_kwargs=dict(
        constraint_tol=constraint_tol, position_tol=position_tol, max_iters=max_iters
    ),
)
The final key object required for inference in Mici is an MCMC sampler class instance. Here we use an MCMC method which sequentially applies three Markov transition kernels, each leaving the (extended) target distribution invariant, on each iteration.

The first is a transition in which the momentum is independently resampled from its conditional distribution given the position (as described in Section 4.2 in the paper), as implemented by the `mici.transitions.IndependentMomentumTransition` class. We could instead, for example, use an instance of `mici.transitions.CorrelatedMomentumTransition`, which would give partial / correlated momentum resampling.
The second transition is the main Hamiltonian-dynamics driven transition which simulates the Hamiltonian dynamics associated with the passed `system` object, using the `integrator` object to generate proposed moves. Here we use `mici.transitions.MultinomialDynamicIntegrationTransition`, a dynamic integration time Hamiltonian Monte Carlo transition with multinomial sampling from the trajectory, analogous to the sampling algorithm used in the popular probabilistic programming framework Stan and as described in Appendix A of the article A Conceptual Introduction to Hamiltonian Monte Carlo (Betancourt, 2017).
The previous transition simulates the Hamiltonian dynamics for the conditioned diffusion system, i.e. with full conditioning on a subset of the states at the observation times. Therefore the third and final transition deterministically updates the set of observation time indices that are conditioned on in the Hamiltonian-dynamics integration based transition, here switching between two sets of observation time indices (partitions of the observation sequence) as described in Section 5 in the paper.
sampler = mici.samplers.MarkovChainMonteCarloMethod(
    rng,
    transitions={
        "momentum": mici.transitions.IndependentMomentumTransition(system),
        "integration": mici.transitions.MultinomialDynamicIntegrationTransition(
            system, integrator
        ),
        "switch_partition": mici_extensions.SwitchPartitionTransition(system),
    },
)
To generate a set of initial states satisfying the observation constraints, we use a linear interpolation based scheme. A set of parameters $z$ and an initial state $x_0$ are sampled from their prior distributions, and a sequence of diffusion states at the observation time indices $\tilde{x}_{1:T}$ is sampled consistent with the observed sequence $y_{1:T}$ (i.e. such that $y_t = h_t(\tilde{x}_t) ~\forall t \in 1{:}T$). The sequence of noise vectors $v_{1:ST}$ is then solved for such that it maps to a state sequence $x_{1:ST}$ which linearly interpolates between the states in $\tilde{x}_{1:T}$. This scheme requires that the forward function $f_\delta$ is linear in the noise vector argument $v$ and that the Jacobian of $f_\delta$ with respect to $v$ is of full row-rank.
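The key step in this scheme can be sketched as follows: as `forward_func` is linear in the noise vector argument, the noise vector mapping a state to a given target state can be recovered with a linear (least-squares) solve against the Jacobian of `forward_func` with respect to `v`. The helper below (with a hypothetical name) is only an illustrative sketch of this idea, not the implementation used in the `sde.mici_extensions` module.

def solve_noise_for_step(z, x, x_next, δ):
    """Illustrative sketch: solve for v such that forward_func(z, x, v, δ) = x_next."""
    v_zero = jnp.zeros(dim_v)
    # Deterministic (zero noise) part of the step and Jacobian of the step w.r.t. v
    x_det = forward_func(z, x, v_zero, δ)
    dx_dv = jax.jacobian(forward_func, argnums=2)(z, x, v_zero, δ)
    # As forward_func is linear in v, forward_func(z, x, v, δ) = x_det + dx_dv @ v,
    # so solve dx_dv @ v = x_next - x_det (requires dx_dv to be of full row-rank)
    return jnp.linalg.lstsq(dx_dv, x_next - x_det)[0]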
Due to the simple form of the observation function assumed here, to generate a diffusion state sequence $\tilde{x}_{1:T}$ consistent with the observations $y_{1:T}$ we simply sample values for the $x_1$ components from $\mathcal{N}(0, 0.5^2)$ and set the $x_0$ component values to the corresponding $y_{1:T}$ values. This is implemented in the function `generate_x_obs_seq_init` below.
def generate_x_obs_seq_init(rng):
    return jnp.concatenate(
        (y_seq_ref, rng.standard_normal(y_seq_ref.shape) * 0.5), -1
    )
We now generate a list of initial states, one for each of the chains to be run, using a helper function `find_initial_state_by_linear_interpolation` defined in the `sde.mici_extensions` module which implements the scheme described above.
num_chains = 2  # Number of independent Markov chains to run

init_states = [
    mici_extensions.find_initial_state_by_linear_interpolation(
        system, rng, generate_x_obs_seq_init
    )
    for _ in range(num_chains)
]
As a final step before sampling the chains we define a function which outputs the variables to be traced (recorded) on each chain iteration.
def trace_func(state):
    q = state.pos
    u, v_0, v_seq = onp.split(q, (dim_z, dim_z + dim_x))
    v_seq = v_seq.reshape((-1, dim_v))
    z = generate_z(u)
    x_0 = generate_x_0(z, v_0)
    return {"x_0": x_0, "σ": z[0], "ϵ": z[1], "γ": z[2], "β": z[3], "v_seq": v_seq}
We now use the constructed `sampler` object to (sequentially) sample `num_chains` Markov chains for `n_warm_up_iter + n_main_iter` iterations. The first `n_warm_up_iter` iterations are an adaptive warm-up stage used to tune the integrator step size (to give a target acceptance statistic of 0.9) and are not used when calculating estimates / statistics using the chain samples. We specify four statistics to be monitored during sampling: the average acceptance statistic (`accept_stat`), the proportion of integration transitions terminating due to non-convergence of the Newton iteration (`convergence_error`), the proportion of integration transitions terminating due to detection of a non-reversible step (`non_reversible_step`) and the number of integrator steps computed per transition (`n_step`).
Due to the just-in-time compilation of the JAX model functions, the first couple of chain iterations will take longer as each of the model functions is compiled on its first call (this happens for the first two iterations rather than just the first because the compiled model functions are specific to the partition / set of observation times conditioned on). During sampling, progress bars will be shown for each chain.
Note that as sampling the chains puts a high demand on the CPU, we default to sampling only very short chains if running on Binder to avoid creating excessive CPU load on their servers (chains will also run much more slowly on Binder servers due to the restricted CPU availability). We recommend running longer chains on your local machine; the default settings of 2 chains of 1000 samples took approximately 15 minutes to run on the laptop used for testing.
if not ON_BINDER:
    n_warm_up_iter = 250  # Number of chain samples in warm-up sampling phase
    n_main_iter = 750  # Number of chain samples in main sampling phase
else:
    n_warm_up_iter = 25
    n_main_iter = 75
final_states, traces, stats = sampler.sample_chains_with_adaptive_warm_up(
    n_warm_up_iter,
    n_main_iter,
    init_states,
    trace_funcs=[trace_func],
    adapters={
        "integration": [
            mici.adapters.DualAveragingStepSizeAdapter(
                log_step_size_reg_coefficient=0.1
            )
        ]
    },
    monitor_stats=[
        ("integration", "accept_stat"),
        ("integration", "convergence_error"),
        ("integration", "non_reversible_step"),
        ("integration", "n_step"),
    ],
)
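As a quick diagnostic we can inspect the recorded chain statistics directly. The snippet below assumes the returned `stats` dictionary is keyed first by transition name and then by statistic name, with one array per chain (matching the `monitor_stats` specification above); adjust the keys if this does not match the layout returned by your Mici version.

# Assumed layout: stats[transition_name][statistic_name] is a sequence of per-chain arrays
for c, accept_stats in enumerate(stats["integration"]["accept_stat"]):
    print(f"Chain {c}: mean accept_stat = {onp.mean(accept_stats):.3f}")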
Using ArviZ we can compute estimated effective sample sizes and split-$\hat{R}$ convergence diagnostics for each of the traced variables (excluding the latent noise variables `v_seq` $= v_{1:ST}$ due to their large number). We find that the estimated $\hat{R}$ values for all the variables checked are within the 1.01 threshold suggested in Rank-normalization, folding, and localization: An improved $\hat{R}$ for assessing convergence of MCMC (Vehtari et al., 2019).
arviz.summary(traces, var_names=["σ", "ϵ", "γ", "β", "x_0"])
|   | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| σ | 0.784 | 0.073 | 0.654 | 0.917 | 0.003 | 0.002 | 575.0 | 810.0 | 1.0 |
| ϵ | 0.206 | 0.028 | 0.160 | 0.261 | 0.001 | 0.001 | 560.0 | 494.0 | 1.0 |
| γ | 1.054 | 0.149 | 0.784 | 1.350 | 0.004 | 0.003 | 1284.0 | 1186.0 | 1.0 |
| β | 0.282 | 0.110 | 0.071 | 0.478 | 0.002 | 0.002 | 2269.0 | 1008.0 | 1.0 |
| x_0[0] | -0.202 | 0.970 | -2.068 | 1.604 | 0.023 | 0.027 | 1824.0 | 951.0 | 1.0 |
| x_0[1] | 1.289 | 0.574 | 0.250 | 2.402 | 0.011 | 0.010 | 2677.0 | 1014.0 | 1.0 |
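For comparison with the posterior summaries above, we can print the reference parameter and initial state values used to generate the simulated data.

# Print the 'true' values used to simulate the observed data
for name, value in zip(["σ", "ϵ", "γ", "β"], z_ref):
    print(f"{name} = {float(value):.3f}")
print("x_0 =", onp.asarray(x_0_ref))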
We can also use the Python package corner to visualise the pairwise posterior marginals estimated from the sampled chains.
num_var = dim_z + dim_x

with plt.style.context(plot_style):
    fig, axes = plt.subplots(num_var, num_var, figsize=(1.5 * num_var, 1.5 * num_var))
    _ = corner.corner(
        traces,
        var_names=["σ", "ϵ", "γ", "β", "x_0"],
        truths=list(z_ref) + list(x_0_ref),
        color="C0",
        truth_color="C1",
        show_titles=True,
        smooth=1.0,
        fig=fig,
    )
    for i in range(num_var):
        for j in range(i + 1):
            if i != num_var - 1:
                axes[i, j].xaxis.set_ticks_position("none")
            if j != 0 or (i == j == 0):
                axes[i, j].yaxis.set_ticks_position("none")
We can also plot estimated pairwise posterior marginals for a subset of the latent noise vectors $v_{1:ST}$, from which we see that they remain close in distribution to their independent standard normal priors.
plot_indices = [(0, 0), (0, 1), (1000, 0), (1000, 1), (2000, 0), (2000, 1)]
num_indices = len(plot_indices)
# Reference noise vectors: skip the dim_z parameter and dim_x initial state
# components at the start of the latent vector q
v_seq_ref = q_ref[dim_z + dim_x :].reshape((-1, dim_v))

with plt.style.context(plot_style):
    fig, axes = plt.subplots(
        num_indices, num_indices, figsize=(1.5 * num_indices, 1.5 * num_indices)
    )
    _ = corner.corner(
        traces,
        var_names=["v_seq"],
        coords={
            "v_seq_dim_0": [i[0] for i in plot_indices],
            "v_seq_dim_1": [i[1] for i in plot_indices],
        },
        truths=[v_seq_ref[i] for i in plot_indices],
        color="C0",
        truth_color="C1",
        show_titles=True,
        smooth=1.0,
        fig=fig,
    )
    for i in range(num_indices):
        for j in range(i + 1):
            if i != num_indices - 1:
                axes[i, j].xaxis.set_ticks_position("none")
            if j != 0 or (i == j == 0):
                axes[i, j].yaxis.set_ticks_position("none")
Trace plots of the parameters and initial state suggest both chains converged to stationarity within the warm up stage.
with plt.style.context(plot_style):
    arviz.plot_trace(
        traces,
        var_names=["σ", "ϵ", "γ", "β", "x_0"],
        figsize=(9, 2 * (dim_x + dim_z)),
        legend=True,
        compact=False,
    )