This notebook uses NumPyro and replicates the experiments in reference [1], which evaluates the performance of NUTS on various frameworks. The benchmark is run with CUDA 10.1 on an NVIDIA RTX 2070.
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro
import time
import numpy as np
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import COVTYPE, load_dataset
from numpyro.infer import HMC, MCMC, NUTS
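# the results below were obtained with NumPyro 0.8.0; installing from the git
# master branch may give a newer version, in which case this assertion fails
assert numpyro.__version__.startswith("0.8.0")
# NB: replace "gpu" with "cpu" to run this notebook on CPU
numpyro.set_platform("gpu")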
We apply the same preprocessing steps as the source code of reference [1]:
_, fetch = load_dataset(COVTYPE, shuffle=False)
features, labels = fetch()
# normalize features and add intercept
features = (features - features.mean(0)) / features.std(0)
features = jnp.hstack([features, jnp.ones((features.shape[0], 1))])
# make binary labels; note that argmax returns an index into the sorted
# unique labels, not a label value
_, counts = np.unique(labels, return_counts=True)
specific_category = jnp.argmax(counts)
labels = labels == specific_category
N, dim = features.shape
print("Data shape:", features.shape)
print(
    "Label distribution: {} has label 1, {} has label 0".format(
        labels.sum(), N - labels.sum()
    )
)
Downloading - https://d2hg8soec8ck9v.cloudfront.net/datasets/covtype.zip.
Download complete.
Data shape: (581012, 55)
Label distribution: 211840 has label 1, 369172 has label 0
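The preprocessing steps can be checked on a small synthetic array. This is a minimal NumPy sketch; the helper name `preprocess` and the toy data are hypothetical, not part of NumPyro or reference [1]. It mirrors the three steps: standardize each column, append an intercept column of ones, and compare the labels with the argmax of the class counts. Because `np.unique` returns the sorted unique values, `argmax(counts)` is an index into that sorted array, so the comparison selects the most frequent class only when the sorted unique labels are exactly 0, 1, ..., K-1.

```python
import numpy as np


def preprocess(features, labels):
    """Hypothetical helper mirroring the notebook's preprocessing steps."""
    # standardize each column to zero mean and unit standard deviation
    features = (features - features.mean(0)) / features.std(0)
    # append a column of ones so that the last coefficient acts as an intercept
    features = np.hstack([features, np.ones((features.shape[0], 1))])
    # np.unique returns the sorted unique labels and how often each occurs
    _, counts = np.unique(labels, return_counts=True)
    # argmax gives an *index* into the sorted unique labels, not a label value
    specific_category = np.argmax(counts)
    return features, labels == specific_category


rng = np.random.default_rng(0)
toy_features = rng.normal(size=(6, 2))
toy_labels = np.array([1, 2, 2, 2, 3, 1])  # class 2 is the most frequent
X, y = preprocess(toy_features, toy_labels)
print(X.shape)  # (6, 3)
print(y.astype(int))  # [1 0 0 0 0 1]: argmax(counts) is 1, so label 1 is selected
```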
Now, we construct the model:
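def model(data, labels):
    # standard normal prior on each coefficient (including the intercept)
    coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
    logits = jnp.dot(data, coefs)
    # Bernoulli likelihood parameterized by logits
    return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)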
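For reference, the model above is a Bayesian logistic regression with a standard normal prior on each coefficient, so its unnormalized log posterior is $\sum_j \log \mathcal{N}(\beta_j \mid 0, 1) + \sum_i [y_i z_i - \log(1 + e^{z_i})]$ with logits $z = X\beta$. The NumPy sketch below writes this density out by hand; `log_joint` and the toy data are hypothetical, and HMC and NUTS in NumPyro work with (the negative of) the equivalent density derived from the `numpyro.sample` statements.

```python
import numpy as np


def log_joint(coefs, data, labels):
    """Unnormalized log posterior of Bayesian logistic regression (hypothetical helper)."""
    # log density of the Normal(0, 1) prior, summed over coefficients
    log_prior = np.sum(-0.5 * coefs**2 - 0.5 * np.log(2 * np.pi))
    logits = data @ coefs
    # Bernoulli log-likelihood with logits: y * z - log(1 + exp(z)),
    # using logaddexp(0, z) as a numerically stable log(1 + exp(z))
    log_lik = np.sum(labels * logits - np.logaddexp(0.0, logits))
    return log_prior + log_lik


data = np.array([[0.5, 1.0], [-1.0, 1.0], [2.0, 1.0]])
labels = np.array([1.0, 0.0, 1.0])
# at coefs = 0 every logit is 0, so each observation contributes -log(2)
print(log_joint(np.zeros(2), data, labels))
```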
step_size = jnp.sqrt(0.5 / N)
# fixed step size and trajectory length, i.e. 10 leapfrog steps per trajectory
kernel = HMC(
    model,
    step_size=step_size,
    trajectory_length=(10 * step_size),
    adapt_step_size=False,
)
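The HMC kernel above uses step size $\sqrt{0.5/N}$ and trajectory length $10 \times$ step size, which amounts to 10 leapfrog steps per trajectory, and the timings below are reported per leapfrog step. As a rough illustration of what one such step does, here is a minimal NumPy sketch of the leapfrog integrator. It assumes an identity mass matrix (by default NumPyro's HMC also adapts a diagonal mass matrix during warmup), and the names `potential_grad`, `leapfrog`, and the toy data are hypothetical, not NumPyro's API.

```python
import numpy as np


def potential_grad(coefs, data, labels):
    """Gradient of the potential energy U = -log posterior (up to a constant)."""
    probs = 1.0 / (1.0 + np.exp(-(data @ coefs)))
    # prior term contributes coefs; likelihood term contributes -X^T (y - p)
    return coefs - data.T @ (labels - probs)


def leapfrog(coefs, momentum, step_size, num_steps, data, labels):
    """Run num_steps leapfrog steps of HMC with an identity mass matrix."""
    grad = potential_grad(coefs, data, labels)
    for _ in range(num_steps):
        momentum = momentum - 0.5 * step_size * grad  # half step for the momentum
        coefs = coefs + step_size * momentum  # full step for the position
        grad = potential_grad(coefs, data, labels)  # one gradient per leapfrog step
        momentum = momentum - 0.5 * step_size * grad  # second half step
    return coefs, momentum


# hypothetical toy data with an intercept column, standing in for the covtype features
rng = np.random.default_rng(0)
N, dim = 200, 3
data = np.hstack([rng.normal(size=(N, dim - 1)), np.ones((N, 1))])
labels = (rng.uniform(size=N) < 0.5).astype(float)

step_size = np.sqrt(0.5 / N)  # same heuristic as the notebook
trajectory_length = 10 * step_size
num_steps = int(round(trajectory_length / step_size))  # 10 leapfrog steps
new_coefs, new_momentum = leapfrog(
    np.zeros(dim), rng.normal(size=dim), step_size, num_steps, data, labels
)
```

Each leapfrog step costs one gradient evaluation, i.e. a pass over all `N` rows of `data`; with 581012 rows in the real dataset, this is the work that the per-step timings measure.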
mcmc = MCMC(kernel, num_warmup=500, num_samples=500, progress_bar=False)
# warm up first; run() then skips warmup, so only sampling is timed below
mcmc.warmup(random.PRNGKey(2019), features, labels, extra_fields=("num_steps",))
# fetch the result so that the asynchronously dispatched warmup has finished
mcmc.get_extra_fields()["num_steps"].sum().copy()
tic = time.time()
mcmc.run(random.PRNGKey(2020), features, labels, extra_fields=["num_steps"])
num_leapfrogs = mcmc.get_extra_fields()["num_steps"].sum().copy()
toc = time.time()
print("number of leapfrog steps:", num_leapfrogs)
print("avg. time for each step :", (toc - tic) / num_leapfrogs)
mcmc.print_summary()
number of leapfrog steps: 5000
avg. time for each step : 0.0015881952285766601

                 mean       std    median      5.0%     95.0%     n_eff     r_hat
   coefs[0]      1.99      0.00      1.99      1.98      1.99      4.53      1.49
   coefs[1]     -0.03      0.00     -0.03     -0.03     -0.03      4.26      1.49
   coefs[2]     -0.12      0.00     -0.12     -0.12     -0.12      5.57      1.10
   coefs[3]     -0.29      0.00     -0.29     -0.29     -0.29      4.77      1.40
   coefs[4]     -0.09      0.00     -0.09     -0.10     -0.09      5.13      1.04
   coefs[5]     -0.15      0.00     -0.15     -0.15     -0.15      2.61      3.11
   coefs[6]     -0.02      0.00     -0.02     -0.02     -0.02      2.68      2.54
   coefs[7]     -0.50      0.00     -0.50     -0.50     -0.50     11.32      1.00
   coefs[8]      0.27      0.00      0.27      0.27      0.27      3.25      2.03
   coefs[9]     -0.02      0.00     -0.02     -0.02     -0.02      6.34      1.42
  coefs[10]     -0.23      0.00     -0.23     -0.23     -0.22      3.76      1.50
  coefs[11]     -0.31      0.00     -0.31     -0.31     -0.31      3.51      1.40
  coefs[12]     -0.54      0.00     -0.54     -0.54     -0.54      2.64      2.52
  coefs[13]     -1.94      0.00     -1.94     -1.94     -1.93      2.54      2.75
  coefs[14]      0.24      0.00      0.24      0.24      0.24      9.69      1.08
  coefs[15]     -1.07      0.00     -1.07     -1.07     -1.07      3.85      1.85
  coefs[16]     -1.26      0.00     -1.26     -1.26     -1.26      5.80      1.07
  coefs[17]     -0.22      0.00     -0.22     -0.22     -0.22      4.45      1.33
  coefs[18]     -0.08      0.00     -0.08     -0.08     -0.08      2.45      2.88
  coefs[19]     -0.68      0.00     -0.68     -0.69     -0.68      2.72      2.12
  coefs[20]     -0.13      0.00     -0.13     -0.13     -0.13      2.79      2.30
  coefs[21]     -0.02      0.00     -0.02     -0.02     -0.02      8.65      1.15
  coefs[22]      0.02      0.00      0.02      0.02      0.02      2.73      2.32
  coefs[23]     -0.15      0.00     -0.15     -0.15     -0.15      2.75      2.56
  coefs[24]     -0.12      0.00     -0.12     -0.12     -0.12      3.92      1.31
  coefs[25]     -0.32      0.00     -0.32     -0.32     -0.32      5.25      1.31
  coefs[26]     -0.17      0.00     -0.17     -0.17     -0.17      4.08      1.13
  coefs[27]     -1.19      0.00     -1.19     -1.19     -1.19      3.22      1.85
  coefs[28]     -0.05      0.00     -0.05     -0.05     -0.05      7.87      1.01
  coefs[29]     -0.03      0.00     -0.03     -0.03     -0.03      7.36      1.17
  coefs[30]     -0.04      0.00     -0.04     -0.04     -0.04      2.88      2.06
  coefs[31]     -0.06      0.00     -0.06     -0.06     -0.06      6.43      1.23
  coefs[32]     -0.02      0.00     -0.02     -0.02     -0.02      6.80      1.03
  coefs[33]     -0.03      0.00     -0.03     -0.03     -0.03      6.47      1.26
  coefs[34]      0.11      0.00      0.11      0.10      0.11      6.67      1.22
  coefs[35]      0.08      0.00      0.08      0.08      0.08      2.49      2.80
  coefs[36]     -0.00      0.00     -0.00     -0.00     -0.00      6.23      1.31
  coefs[37]     -0.07      0.00     -0.07     -0.07     -0.07      2.72      2.36
  coefs[38]     -0.03      0.00     -0.03     -0.03     -0.03      3.97      1.52
  coefs[39]     -0.06      0.00     -0.06     -0.06     -0.06      6.16      1.26
  coefs[40]     -0.01      0.00     -0.01     -0.01     -0.01      2.86      2.07
  coefs[41]     -0.06      0.00     -0.06     -0.06     -0.06      3.02      1.98
  coefs[42]     -0.39      0.00     -0.39     -0.40     -0.39      2.67      2.45
  coefs[43]     -0.27      0.00     -0.27     -0.27     -0.27      5.15      1.33
  coefs[44]     -0.07      0.00     -0.07     -0.07     -0.07      5.75      1.30
  coefs[45]     -0.25      0.00     -0.25     -0.26     -0.25      2.57      2.50
  coefs[46]     -0.09      0.00     -0.09     -0.09     -0.09      8.72      1.00
  coefs[47]     -0.12      0.00     -0.12     -0.12     -0.12      3.10      1.73
  coefs[48]     -0.15      0.00     -0.15     -0.15     -0.15      4.95      1.33
  coefs[49]     -0.05      0.00     -0.05     -0.05     -0.05      2.99      2.32
  coefs[50]     -0.94      0.00     -0.94     -0.94     -0.94     10.08      1.00
  coefs[51]     -0.32      0.00     -0.32     -0.32     -0.32      3.90      1.75
  coefs[52]     -0.29      0.00     -0.29     -0.30     -0.29     13.85      1.05
  coefs[53]     -0.31      0.00     -0.31     -0.31     -0.31      8.21      1.01
  coefs[54]     -1.76      0.00     -1.76     -1.76     -1.76      3.24      1.54

Number of divergences: 0
On CPU, we get an average time per leapfrog step of 0.02782863507270813 s (about 27.8 ms).
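The timing pattern used here (warm up first, wait for the result, then divide the wall-clock time of `mcmc.run` by the total number of leapfrog steps) can be wrapped in a small helper. This is a hypothetical sketch, not a NumPyro utility: `time_per_leapfrog_step` takes any callable that runs the sampler and returns the per-sample `num_steps` array, and it uses `time.perf_counter` instead of `time.time`. The `dummy_sampler` in the usage example is a stand-in for `mcmc.run`.

```python
import time

import numpy as np


def time_per_leapfrog_step(run_sampler):
    """Return (total leapfrog steps, seconds per leapfrog step) for one sampler run.

    `run_sampler` is assumed to run MCMC (after warmup) and return the per-sample
    number of leapfrog steps, e.g. mcmc.get_extra_fields()["num_steps"].
    """
    tic = time.perf_counter()
    # np.asarray copies the result to the host, so with JAX arrays this call
    # waits for the asynchronously dispatched computation to finish
    num_steps = np.asarray(run_sampler())
    toc = time.perf_counter()
    total_steps = int(num_steps.sum())
    return total_steps, (toc - tic) / total_steps


def dummy_sampler():
    # stand-in for mcmc.run(...): 500 samples with 10 leapfrog steps each
    time.sleep(0.01)
    return np.full(500, 10)


total, per_step = time_per_leapfrog_step(dummy_sampler)
print("number of leapfrog steps:", total)  # 5000
print("avg. time for each step :", per_step)
```

With the notebook's objects, `run_sampler` would call `mcmc.run(...)` and then return `mcmc.get_extra_fields()["num_steps"]`.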
# NUTS adapts the step size during warmup and chooses the trajectory length on the fly
mcmc = MCMC(NUTS(model), num_warmup=50, num_samples=50, progress_bar=False)
mcmc.warmup(random.PRNGKey(2019), features, labels, extra_fields=("num_steps",))
mcmc.get_extra_fields()["num_steps"].sum().copy()
tic = time.time()
mcmc.run(random.PRNGKey(2020), features, labels, extra_fields=["num_steps"])
num_leapfrogs = mcmc.get_extra_fields()["num_steps"].sum().copy()
toc = time.time()
print("number of leapfrog steps:", num_leapfrogs)
print("avg. time for each step :", (toc - tic) / num_leapfrogs)
mcmc.print_summary()
number of leapfrog steps: 47406
avg. time for each step : 0.0022662237908313812

                 mean       std    median      5.0%     95.0%     n_eff     r_hat
   coefs[0]      1.97      0.01      1.97      1.95      1.98     74.56      1.05
   coefs[1]     -0.04      0.00     -0.04     -0.05     -0.03     59.26      0.99
   coefs[2]     -0.07      0.01     -0.06     -0.08     -0.05     35.80      1.12
   coefs[3]     -0.30      0.00     -0.30     -0.31     -0.29     54.31      1.00
   coefs[4]     -0.09      0.00     -0.09     -0.10     -0.09     38.45      0.99
   coefs[5]     -0.14      0.00     -0.14     -0.15     -0.14     26.25      1.12
   coefs[6]      0.23      0.04      0.24      0.19      0.30     11.98      1.18
   coefs[7]     -0.65      0.02     -0.65     -0.69     -0.62     17.16      1.16
   coefs[8]      0.57      0.04      0.57      0.48      0.62     12.71      1.18
   coefs[9]     -0.01      0.00     -0.01     -0.02     -0.01     58.92      0.99
  coefs[10]      0.71      0.84      0.67     -0.76      2.04      7.17      0.98
  coefs[11]      0.08      0.38      0.06     -0.57      0.68      7.18      0.98
  coefs[12]      0.39      0.84      0.35     -1.09      1.72      7.18      0.98
  coefs[13]     -1.54      0.53     -1.56     -2.20     -0.65     10.23      0.99
  coefs[14]     -0.48      0.52     -0.45     -1.25      0.25     16.10      0.98
  coefs[15]     -1.83      0.31     -1.80     -2.34     -1.48      5.35      0.98
  coefs[16]     -1.06      0.52     -0.96     -1.88     -0.19     31.52      1.00
  coefs[17]     -0.17      0.08     -0.15     -0.30     -0.06     15.07      1.38
  coefs[18]     -0.64      0.64     -0.59     -1.50      0.25     18.98      1.03
  coefs[19]     -0.74      0.57     -0.71     -1.66      0.07     12.04      1.11
  coefs[20]     -1.04      0.64     -1.14     -1.80     -0.10     16.18      1.00
  coefs[21]     -0.01      0.01     -0.01     -0.02      0.01     12.68      1.42
  coefs[22]      0.03      0.02      0.04     -0.00      0.07     15.54      1.37
  coefs[23]     -0.10      0.12     -0.07     -0.27      0.09     15.48      1.39
  coefs[24]     -0.09      0.08     -0.07     -0.21      0.02     15.48      1.36
  coefs[25]     -0.26      0.12     -0.24     -0.46     -0.10     15.62      1.37
  coefs[26]     -0.12      0.09     -0.10     -0.25      0.03     15.71      1.37
  coefs[27]     -1.11      0.47     -1.11     -1.83     -0.30     17.62      1.08
  coefs[28]     -0.83      0.70     -0.54     -2.04      0.02     34.06      0.99
  coefs[29]     -0.01      0.04      0.00     -0.06      0.05     15.94      1.36
  coefs[30]     -0.02      0.04     -0.00     -0.08      0.04     15.02      1.44
  coefs[31]     -0.05      0.03     -0.04     -0.09      0.00     16.46      1.28
  coefs[32]      0.01      0.04      0.02     -0.06      0.07     15.28      1.36
  coefs[33]      0.04      0.07      0.05     -0.06      0.14     15.73      1.37
  coefs[34]      0.11      0.02      0.11      0.08      0.14     14.67      1.33
  coefs[35]      0.13      0.12      0.16     -0.05      0.32     15.43      1.38
  coefs[36]      0.07      0.16      0.11     -0.16      0.32     15.53      1.37
  coefs[37]      0.00      0.10      0.02     -0.16      0.14     15.53      1.38
  coefs[38]     -0.04      0.02     -0.04     -0.06     -0.02     17.43      1.33
  coefs[39]     -0.05      0.04     -0.04     -0.10      0.01     15.25      1.40
  coefs[40]      0.01      0.02      0.02     -0.02      0.05     15.66      1.35
  coefs[41]     -0.04      0.02     -0.04     -0.08     -0.00     11.32      1.38
  coefs[42]     -0.31      0.21     -0.26     -0.61      0.03     15.56      1.38
  coefs[43]     -0.20      0.12     -0.18     -0.40     -0.04     15.60      1.38
  coefs[44]     -0.01      0.11      0.02     -0.17      0.16     15.52      1.38
  coefs[45]     -0.15      0.15     -0.11     -0.37      0.09     15.46      1.38
  coefs[46]     -0.02      0.14      0.00     -0.23      0.20     15.83      1.37
  coefs[47]     -0.12      0.03     -0.11     -0.16     -0.07     16.20      1.38
  coefs[48]     -0.12      0.03     -0.12     -0.17     -0.08     16.26      1.36
  coefs[49]     -0.04      0.01     -0.04     -0.05     -0.03     14.31      1.28
  coefs[50]     -0.98      0.44     -0.94     -1.71     -0.33     12.09      0.98
  coefs[51]     -0.26      0.09     -0.24     -0.40     -0.14     15.53      1.38
  coefs[52]     -0.25      0.08     -0.23     -0.36     -0.12     15.81      1.37
  coefs[53]     -0.26      0.06     -0.25     -0.36     -0.16     15.99      1.36
  coefs[54]     -1.98      0.13     -1.96     -2.16     -1.81     44.87      0.98

Number of divergences: 0
On CPU, we get an average time per leapfrog step of 0.028006251705287415 s (about 28.0 ms).
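The printed step counts also show how long the trajectories are. A quick check in plain Python, using only the numbers printed above: HMC made 5000 leapfrog steps over 500 samples, while NUTS made 47406 leapfrog steps over 50 samples.

```python
# leapfrog step counts and numbers of post-warmup samples printed above
hmc_steps, hmc_samples = 5000, 500
nuts_steps, nuts_samples = 47406, 50

print("HMC steps per trajectory: ", hmc_steps / hmc_samples)  # 10.0
print("NUTS steps per trajectory:", nuts_steps / nuts_samples)  # 948.12
```

So each NUTS trajectory here is, on average, roughly 95 times longer than an HMC trajectory.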
Average time per leapfrog step:

| | HMC | NUTS |
|---|---|---|
| Edward2 (CPU) | | 56.1 ms |
| Edward2 (GPU) | | 9.4 ms |
| Pyro (CPU) | 35.4 ms | 35.3 ms |
| Pyro (GPU) | 3.5 ms | 4.2 ms |
| NumPyro (CPU) | 27.8 ms | 28.0 ms |
| NumPyro (GPU) | 1.6 ms | 2.2 ms |
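To read the table as GPU speedups, divide the CPU time by the GPU time for each framework and sampler. A small Python sketch using only the Pyro and NumPyro numbers from the table (rounded to 0.1 ms as shown there, so the ratios are approximate):

```python
# (CPU ms, GPU ms) per leapfrog step, copied from the table above
timings = {
    ("Pyro", "HMC"): (35.4, 3.5),
    ("Pyro", "NUTS"): (35.3, 4.2),
    ("NumPyro", "HMC"): (27.8, 1.6),
    ("NumPyro", "NUTS"): (28.0, 2.2),
}

# speedup = CPU time / GPU time for each (framework, sampler) pair
speedups = {key: cpu / gpu for key, (cpu, gpu) in timings.items()}
for (framework, sampler), speedup in speedups.items():
    print(f"{framework:8s} {sampler:5s} GPU speedup: {speedup:5.1f}x")
```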
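Note that in some situations (e.g. Pyro on CPU), HMC is slower per leapfrog step than NUTS. The reason is that the number of leapfrog steps in each HMC trajectory is fixed to $10$, while it is not fixed in NUTS: here NUTS takes 47406 leapfrog steps for 50 samples, so any fixed per-trajectory overhead is spread over many more steps.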
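The NUTS trajectory length varies because NUTS keeps extending the trajectory until it starts to turn back on itself. In its basic form, with an identity mass matrix, the no-U-turn condition stops a trajectory with endpoints $(\theta^-, p^-)$ and $(\theta^+, p^+)$ when $(\theta^+ - \theta^-) \cdot p^- < 0$ or $(\theta^+ - \theta^-) \cdot p^+ < 0$. The sketch below implements only that check; `is_u_turn` is a hypothetical name, and NumPyro's NUTS uses a refined version of the criterion, so this is an illustration rather than its actual code.

```python
import numpy as np


def is_u_turn(theta_minus, theta_plus, p_minus, p_plus):
    """Basic no-U-turn check with an identity mass matrix (illustrative only)."""
    span = theta_plus - theta_minus
    # the trajectory is turning back if either end's momentum points against the span
    return bool(span @ p_minus < 0 or span @ p_plus < 0)


# both ends still moving outward: keep extending the trajectory
print(is_u_turn(np.array([0.0]), np.array([1.0]), np.array([1.0]), np.array([1.0])))  # False
# the forward end now moves back toward the start: stop
print(is_u_turn(np.array([0.0]), np.array([1.0]), np.array([1.0]), np.array([-0.5])))  # True
```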
Some takeaways: