EasyVVUQ provides support for MCMC sampling with multiple chains in parallel.
import os
import easyvvuq as uq
import numpy as np
import chaospy as cp
import json
import matplotlib.pyplot as plt
import sys
We define a Rosenbrock function in 2 dimensions for testing purposes. This will be a stand-in for our probability density.
def rosenbrock(directory):
    """Evaluate the negated 2-D Rosenbrock function for one run directory.

    Reads ``x1``, ``x2`` and ``outfile`` from ``<directory>/input.json``,
    computes ``y = (1 - x1)**2 + 100 * (x2 - x1**2)**2`` and writes
    ``{"value": -y}`` as JSON to ``<directory>/<outfile>``.

    Parameters
    ----------
    directory : str
        Directory that contains ``input.json`` and receives the output file.

    Notes
    -----
    Terminates via ``sys.exit`` with an explanatory message when
    ``input.json`` is not present in ``directory``.
    """
    input_path = os.path.join(directory, 'input.json')
    if not os.path.isfile(input_path):
        sys.exit(input_path + " does not exist.")

    with open(input_path, "r") as handle:
        settings = json.load(handle)

    x1, x2 = float(settings['x1']), float(settings['x2'])
    result_path = os.path.join(directory, settings['outfile'])

    # The two Rosenbrock terms: distance from x1 = 1 and from the parabola.
    offset = 1.0 - x1
    curvature = x2 - x1 ** 2
    y = offset ** 2 + 100.0 * curvature ** 2

    # The sign is flipped on output: the stored value is -y.
    payload = {'value': -y}
    with open(result_path, 'w') as handle:
        json.dump(payload, handle)
Next we define a helper function to create a campaign, sample the search space and return the corresponding DataFrame.
def mcmc(tmp_path='.', n_iterations=1000):
    """Run a five-chain MCMC campaign on ``rosenbrock`` and return the results.

    Builds an EasyVVUQ campaign in ``tmp_path``, attaches an ``MCMCSampler``
    with five chains, advances the campaign iterator ``n_iterations`` times
    and returns the collated results.

    Parameters
    ----------
    tmp_path : str, optional
        Working directory passed to ``uq.Campaign`` (default: current
        directory).
    n_iterations : int, optional
        Number of times the campaign iterator is advanced. Defaults to 1000,
        the value that was previously hard-coded.

    Returns
    -------
    The object returned by ``campaign.get_collation_result()``
    (a DataFrame according to the surrounding tutorial text).
    """
    campaign = uq.Campaign(name="mcmc", work_dir=tmp_path)

    # Parameter space of the app; "chain_id" is declared here as well so it
    # is available as a column for the per-chain plots later on.
    params = {
        "x1": {"type": "float", "default": 0.0},
        "x2": {"type": "float", "default": 0.0},
        "out_file": {"type": "string", "default": "output.json"},
        "chain_id": {"type": "integer", "default": 0}
    }

    # NOTE: the template path is resolved against the current working
    # directory, not against tmp_path.
    encoder = uq.encoders.GenericEncoder(
        template_fname=os.path.abspath("rosenbrock.template"),
        delimiter="$",
        target_filename="input.json")
    decoder = uq.decoders.JSONDecoder("output.json", ["value"])
    campaign.add_app(name="mcmc", params=params, encoder=encoder, decoder=decoder)

    # One starting point per chain: five entries for each varied parameter,
    # matching n_chains=5 below.
    vary_init = {
        "x1": [-1.0, 0.0, 1.0, 0.5, 0.1],
        "x2": [1.0, 0.0, 0.5, 1.0, 0.2]
    }

    def q(x, b=1):
        """Proposal: joint of two normals centred on x['x1'], x['x2']."""
        return cp.J(cp.Normal(x['x1'], b), cp.Normal(x['x2'], b))

    # 'value' is the name of the decoder output handed to the sampler
    # (presumably the quantity it evaluates proposals on -- confirm against
    # the MCMCSampler documentation).
    sampler = uq.sampling.MCMCSampler(vary_init, q, 'value', n_chains=5)
    campaign.set_sampler(sampler)

    action = uq.actions.ExecutePython(rosenbrock)
    iterator = campaign.iterate(action, mark_invalid=True)
    for _ in range(n_iterations):
        next(iterator).start()

    return campaign.get_collation_result()
# Run the campaign with the default arguments (working directory '.') and
# keep the collated results for the plots below.
df = mcmc()
Next, we plot the five different chains.
import matplotlib.pyplot as plt

# Draw the (x1, x2) path of each of the five chains on the same axes.
for chain in range(5):
    in_chain = (df['chain_id'] == chain).values
    chain_rows = df[in_chain]
    plt.plot(chain_rows['x1'], chain_rows['x2'], alpha=0.5)
[<matplotlib.lines.Line2D at 0x124a943a0>]
Finally, let us plot a two-dimensional histogram of the sampled (x1, x2) points from all chains.
# 2-D histogram of all sampled (x1, x2) points on a 20-bin grid
# (bins=20 is passed straight to plt.hist2d).
# NOTE(review): `.T.values[0]` presumably flattens a single-column DataFrame
# (multi-level columns in the collation result) to a 1-D array -- confirm.
hist = plt.hist2d(df['x1'].T.values[0], df['x2'].T.values[0], bins=20)