import numpy as np
from matplotlib import pyplot as plt
import jax.numpy as jnp
from jax import random, lax, jit, vmap
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.nested_sampling import NestedSampler
# Load data: the light curve and the tabulated trajectory terms used by `model`.
# NOTE(review): np.loadtxt returns one array row per line of the file, and the
# tuple-unpacking below iterates over that first axis. So lc.txt must contain
# exactly 3 lines (t, F, Ferr) and interp.txt exactly 5 lines. If the files
# are stored column-wise instead, use np.loadtxt(..., unpack=True).
# Confirm the on-disk layout.
lc, interp = (np.loadtxt(path) for path in ("lc.txt", "interp.txt"))
t, F, Ferr = lc
points, zeta_e, zeta_n, zeta_e_dot, zeta_n_dot = interp
def model(t, F, Ferr):
    """NumPyro model: single-lens light curve with a parallax-perturbed trajectory.

    Draws flux, timing, impact-parameter and parallax parameters from their
    priors. Builds the separation ``u(t)`` from the tabulated trajectory in the
    module-level arrays ``points``, ``zeta_e``, ``zeta_n``, ``zeta_e_dot`` and
    ``zeta_n_dot``. Finally conditions a Normal likelihood on the observed
    fluxes.

    Args:
        t: Observation times. Presumably in the same units as ``points`` and
            the ``t0`` prior; confirm.
        F: Observed fluxes, passed as ``obs`` to the ``"data_dist"`` site.
        Ferr: Per-point flux uncertainties, scaled by ``exp(ln_c)``.

    Returns:
        The value of the observed ``"data_dist"`` sample site.
    """
    # Priors. Sample sites are declared in a fixed order; the log-parameters
    # are exponentiated immediately.
    DeltaF = jnp.exp(numpyro.sample("ln_DeltaF", dist.Normal(4.0, 4.0)))
    Fbase = jnp.exp(numpyro.sample("ln_Fbase", dist.Normal(2.0, 4.0)))
    t0 = numpyro.sample("t0", dist.Normal(3630.0, 50.0))
    tE = jnp.exp(numpyro.sample("ln_tE", dist.Normal(3.0, 6.0)))
    u0 = numpyro.sample("u0", dist.Normal(0.0, 1.0))
    piEE = numpyro.sample("piEE", dist.Normal(0.0, 0.5))
    piEN = numpyro.sample("piEN", dist.Normal(0.0, 0.5))

    def offset_from_linear(track, rate):
        # Deviation of a tabulated track at `t` from its first-order expansion
        # about t0; identically zero at t == t0.
        # NOTE(review): jnp.interp holds the edge values outside `points`, so
        # a t0 drawn beyond the table is silently clamped.
        at_t = jnp.interp(t, points, track)
        at_t0 = jnp.interp(t0, points, track)
        rate_t0 = jnp.interp(t0, points, rate)
        return at_t - at_t0 - (t - t0) * rate_t0

    d_e = offset_from_linear(zeta_e, zeta_e_dot)
    d_n = offset_from_linear(zeta_n, zeta_n_dot)

    # Perpendicular (u0-like) and parallel ((t - t0)/tE-like) components of
    # the separation. The parallax terms are weighted by (piEE, piEN).
    u_perp = u0 + piEN * d_e - piEE * d_n
    u_para = (t - t0) / tE + piEE * d_e + piEN * d_n
    u = jnp.sqrt(u_perp ** 2 + u_para ** 2)

    def magnification(x):
        # Point-source point-lens magnification. The |x| lets it accept the
        # signed u0; for u (a sqrt, so non-negative) |u| == u.
        return (x ** 2 + 2) / (jnp.abs(x) * jnp.sqrt(x ** 2 + 4))

    # Normalised so that A == 1 at t == t0, where both offsets and
    # (t - t0)/tE vanish and u == |u0| (up to rounding). DeltaF is therefore
    # the flux above Fbase at t0.
    A = (magnification(u) - 1) / (magnification(u0) - 1)
    flux_model = DeltaF * A + Fbase

    # Exponential prior => ln_c >= 0, so the error inflation exp(ln_c) >= 1.
    ln_c = numpyro.sample("ln_c", dist.Exponential(1 / 2.0))
    sigma = jnp.exp(ln_c) * Ferr
    return numpyro.sample("data_dist", dist.Normal(flux_model, sigma), obs=F)
# Run nested sampling on `model`, conditioning on the light curve (t, F, Ferr).
ns = NestedSampler(model, num_live_points=1000)
rng_key = random.PRNGKey(0)
ns.run(rng_key, t, F, Ferr)
# A bare `ns.num_samples` expression is only displayed inside a notebook and
# is discarded when run as a script, so print it explicitly.
# Recorded run output: 65854
print(ns.num_samples)
ns.print_summary()
# NOTE(review): upstream numpyro's NestedSampler.diagnostics() takes no
# `cornerplot` argument; confirm against the installed version.
ns.diagnostics(cornerplot=False)  # plot_cornerplot is slow
# Output of ns.print_summary() from the recorded run:
#
#                 mean       std    median      5.0%     95.0%     n_eff     r_hat
#  ln_DeltaF      5.17      1.35      5.53      3.78      5.88     17.34      1.08
#   ln_Fbase      4.98      1.32      5.24      5.14      5.34     29.33      1.04
#       ln_c      2.28      0.76      2.01      1.87      3.13      7.05      1.27
#      ln_tE      4.13      2.55      4.13      3.12      9.47    102.16      1.00
#       piEE     -0.04      0.26     -0.03     -0.44      0.35     42.76      1.00
#       piEN      0.35      0.38      0.51     -0.42      0.70     10.72      1.20
#         t0   3628.39     28.25   3631.88   3555.03   3643.77     55.18      1.02
#         u0     -0.47      0.59     -0.71     -0.98      0.67     13.01      1.17
# Scatter the posterior samples of u0 against piEN to inspect their joint
# structure.
fig, ax = plt.subplots(figsize=(6, 5))
# Draw the samples once so that both coordinates come from the same draw.
# NOTE(review): in upstream numpyro, NestedSampler.get_samples takes
# (rng_key, num_samples) and resamples at random. Two separate calls, one per
# coordinate, could therefore pair u0 and piEN from different draws. Confirm
# the signature of the installed version.
samples = ns.get_samples()
ax.scatter(samples["u0"], samples["piEN"], color="black", alpha=0.1)
# NOTE(review): the recorded summary puts the 5% quantile of u0 at -0.98,
# outside this x-range, so some samples are clipped from view.
ax.set(xlim=(-0.8, 0.8), ylim=(-0.8, 0.8))
ax.set(xlabel=r"$u_0$", ylabel=r"$\pi_{EN}$")
ax.grid()
ns.print_summary()
# Output of ns.print_summary() from the recorded run:
#
#                 mean       std    median      5.0%     95.0%     n_eff     r_hat
#  ln_DeltaF      5.17      1.35      5.53      3.78      5.88     17.34      1.08
#   ln_Fbase      4.98      1.32      5.24      5.14      5.34     29.33      1.04
#       ln_c      2.28      0.76      2.01      1.87      3.13      7.05      1.27
#      ln_tE      4.13      2.55      4.13      3.12      9.47    102.16      1.00
#       piEE     -0.04      0.26     -0.03     -0.44      0.35     42.76      1.00
#       piEN      0.35      0.38      0.51     -0.42      0.70     10.72      1.20
#         t0   3628.39     28.25   3631.88   3555.03   3643.77     55.18      1.02
#         u0     -0.47      0.59     -0.71     -0.98      0.67     13.01      1.17