import math
import os
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import jax
import jaxlib
import jax.numpy as jnp
from jax import ops, random, vmap
from jax.scipy.special import expit
import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import print_summary, hpdi
from numpyro import handlers
from numpyro.distributions import constraints
from numpyro.infer import MCMC, NUTS, init_to_value
from numpyro.infer import Predictive
from sklearn.metrics import mean_squared_error
import seaborn as sns
import warnings
warnings.simplefilter("ignore")
plt.style.use('bmh')
rng_key = random.PRNGKey(0)
def sample(model, num_samples=1000, num_warmup=500, num_chains=2, **kwargs):
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel,
                num_samples=num_samples,
                num_warmup=num_warmup,
                num_chains=num_chains)
    mcmc.run(random.PRNGKey(0), **kwargs)
    samples = mcmc.get_samples()
    return mcmc, samples
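A side note on `num_chains=2` (my addition, not part of the original notebook): on a single-device CPU backend NumPyro runs the chains sequentially. To actually sample them in parallel, the host device count can be raised before the XLA backend is initialized, e.g. right after the imports:

numpyro.set_host_device_count(2)  # must run before any JAX computation touches the backend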
with numpyro.handlers.seed(rng_seed=123):
    N = 200
    # latent X and its (signed) measurement error; the observed X is their sum
    X_true = numpyro.sample('X_true', dist.Normal(loc=0., scale=3.).expand([N]))
    sigma_X = numpyro.sample('sigma_X', dist.Normal(loc=0., scale=2.5).expand([N]))
    X = X_true + sigma_X
    # regression parameters: slope and intercept
    beta = jnp.array([4.5, -1.5])
    # latent y on the regression line (with unit intrinsic scatter); observed y adds measurement error
    y_true = numpyro.sample('y_true', dist.Normal(loc=beta[0] * X_true + beta[1], scale=1.))
    sigma_y = numpyro.sample('sigma_y', dist.Normal(loc=0., scale=2.5).expand([N]))
    y = y_true + sigma_y

# the absolute values of the per-point noise draws are reported as the known error bars
data = {}
data['X_true'] = X_true
data['y_true'] = y_true
data['X'] = X
data['sigma_X'] = jnp.fabs(sigma_X)
data['y'] = y
data['sigma_y'] = jnp.fabs(sigma_y)
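Before fitting anything, a quick look at the simulated data (this cell is my addition) shows how strongly the measurement noise scatters the observed X and y around the latent values:

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].scatter(data['X_true'], data['X'], s=10, alpha=0.5)
axes[0].set(xlabel='X_true (latent)', ylabel='X (observed)')
axes[1].scatter(data['X'], data['y'], s=10, alpha=0.5)
axes[1].set(xlabel='X (observed)', ylabel='y (observed)')
plt.show()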
def model_reg(params, X):
    return params[0] * X + params[1]


def model_noerror(X, y=None):
    beta = numpyro.sample('beta', dist.Normal(loc=0., scale=100.).expand([2]))
    sigma = numpyro.sample('sigma', dist.Exponential(1.))
    # model
    mu = numpyro.deterministic('mu', model_reg(beta, X))
    obs = numpyro.sample('obs', dist.Normal(loc=mu, scale=sigma), obs=y)
data_args = dict(X=data['X'], y=data['y'])
num_samples, num_warmup = 1000, 500
num_chains = 2
mcmc_ne, samples_ne = sample(model=model_noerror,
                             num_samples=num_samples,
                             num_warmup=num_warmup,
                             num_chains=num_chains,
                             **data_args)
az_data_noerror = az.from_numpyro(mcmc_ne)
az.summary(az_data_noerror, var_names=['beta', 'sigma'])
|         | mean   | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_mean | ess_sd | ess_bulk | ess_tail | r_hat |
|---------|--------|-------|--------|---------|-----------|---------|----------|--------|----------|----------|-------|
| beta[0] | 2.770  | 0.149 | 2.504  | 3.054   | 0.003     | 0.002   | 2021.0   | 2020.0 | 2010.0   | 1527.0   | 1.0   |
| beta[1] | -1.783 | 0.586 | -2.958 | -0.746  | 0.014     | 0.010   | 1858.0   | 1813.0 | 1858.0   | 1588.0   | 1.0   |
| sigma   | 8.852  | 0.440 | 8.045  | 9.655   | 0.010     | 0.007   | 1830.0   | 1820.0 | 1838.0   | 1335.0   | 1.0   |
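The slope estimate (~2.77) sits well below the true value of 4.5: ignoring the error in X attenuates the slope toward zero, and the inflated sigma (~8.9) absorbs the unmodelled scatter. A rough check of the classical attenuation factor, using the simulation scales above (this calculation is my addition):

# lambda = var(X_true) / (var(X_true) + var(noise in X))
var_x_true, var_x_noise = 3.0 ** 2, 2.5 ** 2
attenuation = var_x_true / (var_x_true + var_x_noise)
print(attenuation, 4.5 * attenuation)  # ~0.59 and ~2.66, in line with the ~2.77 posterior mean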
az.plot_trace(az_data_noerror, var_names=['beta', 'sigma']);
def model_xerror(X, sigma_X, y=None):
    N = X.shape[0]
    beta = numpyro.sample('beta', dist.Normal(loc=0., scale=100.).expand([2]))
    sigma = numpyro.sample('sigma', dist.Exponential(1.))
    # latent (true) x, with a hierarchical prior; observed X is a noisy measurement of it
    X_true_mu = numpyro.sample("X_true_mu", dist.Normal(loc=0., scale=100.))
    X_true_eta = numpyro.sample("X_true_eta", dist.HalfNormal(scale=20.))
    X_true = numpyro.sample('X_true', dist.Normal(loc=X_true_mu, scale=X_true_eta).expand([N]))
    obs_X = numpyro.sample('obs_X', dist.Normal(loc=X_true, scale=sigma_X), obs=X)
    # model
    mu = numpyro.deterministic('mu', model_reg(beta, X_true))
    obs = numpyro.sample('obs', dist.Normal(loc=mu, scale=sigma), obs=y)
data_args = dict(X=data['X'], sigma_X=data['sigma_X'], y=data['y'])
mcmc_xe, samples_xe = sample(model=model_xerror,
                             num_samples=num_samples,
                             num_warmup=num_warmup,
                             num_chains=num_chains,
                             **data_args)
az_data_xe = az.from_numpyro(mcmc_xe)
az.summary(az_data_xe, var_names=['beta', 'sigma'])
|         | mean   | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_mean | ess_sd | ess_bulk | ess_tail | r_hat |
|---------|--------|-------|--------|---------|-----------|---------|----------|--------|----------|----------|-------|
| beta[0] | 4.511  | 0.131 | 4.255  | 4.749   | 0.007     | 0.005   | 345.0    | 345.0  | 344.0    | 657.0    | 1.01  |
| beta[1] | -2.063 | 0.377 | -2.705 | -1.293  | 0.016     | 0.012   | 527.0    | 527.0  | 531.0    | 793.0    | 1.00  |
| sigma   | 2.537  | 0.382 | 1.803  | 3.265   | 0.028     | 0.020   | 188.0    | 188.0  | 188.0    | 310.0    | 1.02  |
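With the measurement error modelled, beta[0] (~4.51) recovers the true slope of 4.5. The model also infers the latent X_true values, so one optional check (my addition) is to compare their posterior means with the simulated truth:

X_true_hat = samples_xe['X_true'].mean(axis=0)   # posterior mean of each latent X
plt.figure(figsize=(6, 6))
plt.scatter(data['X_true'], X_true_hat, s=10, alpha=0.5)
plt.plot([-10, 10], [-10, 10], 'k--', lw=1)      # identity line for reference
plt.xlabel('simulated X_true')
plt.ylabel('posterior mean X_true')
plt.show()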
az.plot_trace(az_data_xe, var_names=['beta', 'sigma']);
_, ax = plt.subplots(figsize=(10, 4))
az.plot_forest([mcmc_ne, mcmc_xe],
               var_names=['beta', 'sigma'],
               model_names=['noError', 'xError'],
               combined=True,
               ax=ax);
def plot_data(data, ax=None, **kwargs):
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(14, 6))
    idx = jnp.argsort(data['X'])
    x = data['X'][idx]
    y = data['y'][idx]
    ax.plot(x, y, 'kx', alpha=0.5, zorder=0)
    ax.errorbar(x, y,
                xerr=data['sigma_X'][idx],
                yerr=data['sigma_y'][idx],
                fmt='none',
                color='k',
                **kwargs)
    return ax
def plot_regression(X, y_mean, y_hpdi, ax=None, **kwargs):
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(14, 6))
    color = next(ax._get_lines.prop_cycler)['color']
    idx = jnp.argsort(X)
    x = X[idx]
    ax.plot(x, y_mean[idx], '-', lw=3, color=color, **kwargs)
    ax.fill_between(x,
                    y_hpdi[:, idx][0],
                    y_hpdi[:, idx][1],
                    color=color,
                    alpha=0.5)
    return ax
# Compute empirical posterior distribution over mu
pred_fn = vmap(lambda samples: model_reg(samples, data['X']))
y_pred_ne = pred_fn(samples_ne['beta'])
y_pred_xe = pred_fn(samples_xe['beta'])
_, ax = plt.subplots(1, 1, figsize=(14, 6))
plot_data(data, ax, alpha=0.2)
ax.plot(data['X_true'], model_reg(beta, data['X_true']), '--', lw=4, label='true generating model')
plot_regression(data['X'], y_pred_ne.mean(0), hpdi(y_pred_ne, 0.9), ax, label='no error')
plot_regression(data['X'], y_pred_xe.mean(0), hpdi(y_pred_xe, 0.9), ax, label='x error')
ax.set(xlabel='x', ylabel='y');
plt.legend(loc=2)
plt.show()
rng_key, rng_key_ = random.split(rng_key)
predictions_ne = Predictive(model_noerror, samples_ne)(rng_key_, X=data['X'])['obs']
mean_pred_ne = jnp.mean(predictions_ne, axis=0)
hpdi_pred_ne = hpdi(predictions_ne, 0.9)
# ===========
predictions_xe = Predictive(model_xerror, samples_xe)(rng_key_,
                                                      X=data['X'],
                                                      sigma_X=data['sigma_X'])['obs']
mean_pred_xe = jnp.mean(predictions_xe, axis=0)
hpdi_pred_xe = hpdi(predictions_xe, 0.9)
_, ax = plt.subplots(1, 1, figsize=(14, 6))
plot_data(data, ax, alpha=0.2)
plot_regression(data['X'], mean_pred_ne, hpdi_pred_ne, ax, label='no error')
plot_regression(data['X'], mean_pred_xe, hpdi_pred_xe, ax, label='x error')
plt.legend(loc=2)
plt.show()
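As a final rough check (my addition, not in the original notebook), the fraction of observed points falling inside each model's 90% posterior predictive interval gives a crude calibration summary:

def coverage(y, interval):
    # interval is the (2, N) lower/upper array returned by hpdi
    return float(jnp.mean((y >= interval[0]) & (y <= interval[1])))

print('no-error model:', coverage(data['y'], hpdi_pred_ne))
print('x-error model :', coverage(data['y'], hpdi_pred_xe))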
%reload_ext watermark
%watermark -n -u -v -iv -w
jax       0.1.73
jaxlib    0.1.51
numpy     1.19.3
pandas    1.1.4
numpyro   0.4.1
arviz     0.10.0
seaborn   0.10.1
last updated: Wed Nov 11 2020
CPython 3.7.3
IPython 7.19.0
watermark 2.0.2