Figure 2.29: (left) Overall negative log probability for the original model and the model with learned guess probabilities. The lower red bar indicates that learning the guess probabilities gives a substantially better model, according to this metric. (right) Negative log probability for each skill, showing that the improvement varies from skill to skill.
log_density
of the model it is the negative log probability of the ground truth. For a participant with $skill_i$ the negative log probability is $-\left[truth_i \log \theta_i + (1 - truth_i) \log(1 - \theta_i)\right]$, where $truth_i$ is an indicator variable of having $skill_i$ and the probability of each skill is $p(skill_i) \sim Bernoulli(\theta_i)$.
Further details from the text:
A common metric to use is the probability of the ground truth values under the inferred distributions. Sometimes it is convenient to take the logarithm of the probability, since this gives a more manageable number when the probability is very small. When we use the logarithm of the probability, the metric is referred to as the log probability. So, if the inferred probability of a person having a particular skill is $p$, then the log probability is $log(p)$ if the person has the skill and $log(1−p)$ if they don’t. If the person does have the skill then the best possible prediction is $p=1.0$, which gives log probability of $log(1.0)=0$ (the logarithm of one is zero). A less confident prediction, such as $p=0.8$ will give a log probability with a negative value, in this case $log(0.8)=−0.097$. The worst possible prediction of $p=0.0$ gives a log probability of negative infinity. ...
import operator
from functools import reduce
from typing import Callable, Dict, List
import arviz as az
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
from jax.scipy.special import logsumexp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs, log_likelihood
from numpyro.infer.util import log_density, potential_energy
%matplotlib inline
%reload_ext autoreload
%autoreload 2
%load_ext watermark
%watermark -v -m -p arviz,jax,matplotlib,numpy,pandas,scipy,numpyro
Python implementation: CPython Python version : 3.8.11 IPython version : 7.18.1 arviz : 0.11.2 jax : 0.2.19 matplotlib: 3.4.3 numpy : 1.20.3 pandas : 1.3.2 scipy : 1.6.2 numpyro : 0.7.2 Compiler : GCC 7.5.0 OS : Linux Release : 4.19.193-1-MANJARO Machine : x86_64 Processor : CPU cores : 4 Architecture: 64bit
%watermark -gb
Git hash: 307321cc497d1542d2908d60950823b102b16219 Git branch: master
def neg_log_proba_score(theta: np.ndarray, y_true: np.ndarray) -> np.ndarray:
    """
    Calculate the negative log probability of the ground truth
    (the self-assessed skills) under inferred Bernoulli probabilities.

    :param theta: array of Bernoulli success probabilities, same shape as ``y_true``
    :param y_true: integer array (0/1) of skill indicator variables for participants
    :return: element-wise negative log probability; probabilities of exactly
        zero are clamped to machine epsilon so the result stays finite
    """
    assert theta.shape == y_true.shape
    assert np.issubdtype(y_true.dtype, np.integer)
    score = scipy.stats.bernoulli(theta).pmf(y_true)
    # A pmf of exactly 0 would give -log(0) = inf; clamp to eps instead.
    score[score == 0.0] = np.finfo(float).eps
    return -np.log(score)
def plot_bars(
    data: np.ndarray,
    columns: List[str],
    index: List[str],
    ax=None,
    tick_step=0.5,
    **kwargs,
):
    """
    Draw a grouped bar plot of ``data``.

    :param data: 2D array; one row per ``index`` entry, one column per ``columns`` entry
    :param columns: legend labels, one per column of ``data``
    :param index: x-axis group labels, one per row of ``data``
    :param ax: existing matplotlib axes to draw on; a new figure is created when None
    :param tick_step: spacing between y-axis ticks
    :param kwargs: forwarded to ``pandas.DataFrame.plot``
    :return: the axes the bars were drawn on
    """
    if ax is None:
        # Previously the figure handle was captured but never used.
        _, ax = plt.subplots()
    pd.DataFrame(data, columns=columns, index=index).plot(
        kind="bar", color=["b", "r"], ax=ax, zorder=3, **kwargs
    )
    # Grid behind the bars (bars drawn with zorder=3 above).
    ax.grid(zorder=0, axis="y")
    ax.yaxis.set_ticks(np.arange(0, data.max(), tick_step))
    return ax
signature: log_likelihood(model, posterior_samples, *args, parallel=False, batch_ndims=1, **kwargs)
log_likelihood
# `log_likelihood` is from numpyro.infer; this helper follows the numpyro
# "Example: Baseball Batting Average".
def log_ppd(
    model: Callable,
    posterior_samples: Dict,
    *args,
    parallel=False,
    batch_ndims=1,
    **kwargs
):
    """
    Log pointwise predictive density.

    :param model Callable: Python callable containing Pyro primitives
    :param posterior_samples Dict: dictionary of samples from the posterior.
    :param args: model arguments
    :param parallel bool: passed to `log_likelihood` from numpyro.infer
    :param batch_ndims Union[0, 1, 2]: passed to `log_likelihood` from numpyro.infer, see `log_likelihood` for details
    :param kwargs: model kwargs
    :return: per-observation log density, averaged in log space over
        posterior samples (log mean exp of the pointwise log likelihoods)
    """
    post_loglik = log_likelihood(
        model,
        posterior_samples,
        *args,
        parallel=parallel,
        batch_ndims=batch_ndims,
        **kwargs
    )
    # Stack per-site log likelihoods: one column per observed sample site.
    post_loglik_res = np.concatenate(
        [obs[:, None] for obs in post_loglik.values()], axis=1
    )
    # log(1/S * sum_s exp(loglik_s)) computed stably with logsumexp.
    exp_log_density = logsumexp(post_loglik_res, axis=0) - jnp.log(
        jnp.shape(post_loglik_res)[0]
    )
    return exp_log_density
# Fixed PRNG key so the MCMC runs below are reproducible.
rng_key = jax.random.PRNGKey(2)

# Raw quiz responses; first row/column are metadata, columns 1-7 hold the
# participants' self-assessed skills.
raw_data = pd.read_csv(
    "http://www.mbmlbook.com/Downloads/LearningSkills_Real_Data_Experiments-Original-Inputs-RawResponsesAsDictionary.csv"
)
self_assessed = raw_data.iloc[1:, 1:8].copy().astype(int)

# Question-to-skill mask: row q is a 0/1 indicator of which skills question
# q requires.
skills_key = pd.read_csv(
    "http://www.mbmlbook.com/Downloads/LearningSkills_Real_Data_Experiments-Original-Inputs-Quiz-SkillsQuestionsMask.csv",
    header=None,
)
# For each question, the list of indices of the skills it needs.
skills_needed = [
    [i for i, x in enumerate(row) if x] for _, row in skills_key.iterrows()
]

# Graded responses: 1 where the participant answered correctly.
responses = pd.read_csv(
    "http://www.mbmlbook.com/Downloads/LearningSkills_Real_Data_Experiments-Original-Inputs-IsCorrect.csv",
    header=None,
).astype("int32")
def model_00(
    graded_responses, skills_needed: List[List[int]], prob_mistake=0.1, prob_guess=0.2
):
    """Skills model with fixed mistake and guess probabilities.

    :param graded_responses: (n_questions, n_participants) 0/1 array of correct answers
    :param skills_needed: per question, the indices of the skills it requires
    :param prob_mistake: probability of answering wrong despite having all skills
    :param prob_guess: probability of guessing right without the skills
    """
    n_questions, n_participants = graded_responses.shape
    # Skills are numbered 0..max index appearing in skills_needed.
    n_skills = max(map(max, skills_needed)) + 1
    participants_plate = numpyro.plate("participants_plate", n_participants)
    with participants_plate:
        skills = []
        for s in range(n_skills):
            # Uninformative 50/50 prior on each participant having skill s.
            skills.append(numpyro.sample("skill_{}".format(s), dist.Bernoulli(0.5)))
    for q in range(n_questions):
        # 1 iff the participant has every skill the question needs.
        has_skills = reduce(operator.mul, [skills[i] for i in skills_needed[q]])
        prob_correct = has_skills * (1 - prob_mistake) + (1 - has_skills) * prob_guess
        # Observed correctness for question q across all participants;
        # to_event(1) treats the participant dimension as an event here.
        isCorrect = numpyro.sample(
            "isCorrect_{}".format(q),
            dist.Bernoulli(prob_correct).to_event(1),
            obs=graded_responses[q],
        )
# Run inference for model_00. DiscreteHMCGibbs Gibbs-samples the discrete
# Bernoulli skill sites while NUTS handles any continuous sites
# (model_00 has none).
nuts_kernel = NUTS(model_00)
kernel = DiscreteHMCGibbs(nuts_kernel, modified=True)
mcmc_00 = MCMC(
    kernel, num_warmup=200, num_samples=1000, num_chains=4, jit_model_args=False
)
mcmc_00.run(
    rng_key,
    jnp.array(responses),
    skills_needed,
    # Extra per-sample diagnostics; "hmc_state.potential_energy" (the
    # negative log joint density of each sample) is used further below.
    extra_fields=(
        "z",
        "hmc_state.potential_energy",
        "hmc_state.z",
        "rng_key",
        "hmc_state.rng_key",
    ),
)
mcmc_00.print_summary()
/home/benda/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/infer/mcmc.py:269: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(4)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`. warnings.warn( sample: 100%|██████████| 1200/1200 [02:02<00:00, 9.83it/s, 1 steps of size 1.19e+37. acc. prob=1.00] sample: 100%|██████████| 1200/1200 [02:03<00:00, 9.70it/s, 1 steps of size 1.19e+37. acc. prob=1.00] sample: 100%|██████████| 1200/1200 [02:03<00:00, 9.73it/s, 1 steps of size 1.19e+37. acc. prob=1.00] sample: 100%|██████████| 1200/1200 [02:02<00:00, 9.83it/s, 1 steps of size 1.19e+37. acc. prob=1.00]
mean std median 5.0% 95.0% n_eff r_hat skill_0[0] 1.00 0.00 1.00 1.00 1.00 nan nan skill_0[1] 1.00 0.00 1.00 1.00 1.00 nan nan skill_0[2] 0.01 0.10 0.00 0.00 0.00 4093.80 1.00 skill_0[3] 1.00 0.00 1.00 1.00 1.00 nan nan skill_0[4] 1.00 0.00 1.00 1.00 1.00 nan nan skill_0[5] 0.00 0.00 0.00 0.00 0.00 nan nan skill_0[6] 1.00 0.00 1.00 1.00 1.00 nan nan skill_0[7] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_0[8] 0.99 0.12 1.00 1.00 1.00 4130.52 1.00 skill_0[9] 1.00 0.00 1.00 1.00 1.00 nan nan skill_0[10] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_0[11] 0.97 0.18 1.00 1.00 1.00 4311.42 1.00 skill_0[12] 1.00 0.00 1.00 1.00 1.00 nan nan skill_0[13] 0.66 0.47 1.00 0.00 1.00 12611.01 1.00 skill_0[14] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_0[15] 1.00 0.00 1.00 1.00 1.00 nan nan skill_0[16] 1.00 0.00 1.00 1.00 1.00 nan nan skill_0[17] 1.00 0.00 1.00 1.00 1.00 nan nan skill_0[18] 1.00 0.00 1.00 1.00 1.00 nan nan skill_0[19] 1.00 0.00 1.00 1.00 1.00 nan nan skill_0[20] 1.00 0.00 1.00 1.00 1.00 nan nan skill_0[21] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[0] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[1] 1.00 0.07 1.00 1.00 1.00 4043.37 1.00 skill_1[2] 0.00 0.07 0.00 0.00 0.00 3397.09 1.00 skill_1[3] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[4] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[5] 0.00 0.06 0.00 0.00 0.00 4039.58 1.00 skill_1[6] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[7] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[8] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[9] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[10] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[11] 0.94 0.24 1.00 1.00 1.00 3595.70 1.00 skill_1[12] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_1[13] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_1[14] 0.97 0.17 1.00 1.00 1.00 3986.16 1.00 skill_1[15] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[16] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[17] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[18] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[19] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[20] 1.00 0.00 1.00 1.00 1.00 
nan nan skill_1[21] 1.00 0.00 1.00 1.00 1.00 nan nan skill_2[0] 0.98 0.15 1.00 1.00 1.00 4195.02 1.00 skill_2[1] 0.59 0.49 1.00 0.00 1.00 27054.31 1.00 skill_2[2] 0.04 0.20 0.00 0.00 0.00 4370.98 1.00 skill_2[3] 0.58 0.49 1.00 0.00 1.00 27689.66 1.00 skill_2[4] 0.98 0.13 1.00 1.00 1.00 4140.87 1.00 skill_2[5] 0.00 0.00 0.00 0.00 0.00 nan nan skill_2[6] 0.98 0.14 1.00 1.00 1.00 3888.96 1.00 skill_2[7] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_2[8] 0.98 0.13 1.00 1.00 1.00 4056.76 1.00 skill_2[9] 0.04 0.20 0.00 0.00 0.00 4393.83 1.00 skill_2[10] 0.58 0.49 1.00 0.00 1.00 29403.65 1.00 skill_2[11] 0.98 0.14 1.00 1.00 1.00 3875.56 1.00 skill_2[12] 0.98 0.14 1.00 1.00 1.00 4166.72 1.00 skill_2[13] 0.98 0.15 1.00 1.00 1.00 4175.88 1.00 skill_2[14] 0.58 0.49 1.00 0.00 1.00 37696.08 1.00 skill_2[15] 0.59 0.49 1.00 0.00 1.00 24106.45 1.00 skill_2[16] 0.58 0.49 1.00 0.00 1.00 31530.27 1.00 skill_2[17] 0.98 0.14 1.00 1.00 1.00 4186.04 1.00 skill_2[18] 0.98 0.14 1.00 1.00 1.00 4171.18 1.00 skill_2[19] 1.00 0.00 1.00 1.00 1.00 nan nan skill_2[20] 0.98 0.13 1.00 1.00 1.00 4110.53 1.00 skill_2[21] 0.58 0.49 1.00 0.00 1.00 24811.19 1.00 skill_3[0] 0.99 0.08 1.00 1.00 1.00 3895.34 1.00 skill_3[1] 0.99 0.08 1.00 1.00 1.00 4060.72 1.00 skill_3[2] 0.99 0.09 1.00 1.00 1.00 3623.58 1.00 skill_3[3] 0.99 0.09 1.00 1.00 1.00 3235.95 1.00 skill_3[4] 0.78 0.41 1.00 0.00 1.00 6003.05 1.00 skill_3[5] 0.00 0.00 0.00 0.00 0.00 nan nan skill_3[6] 0.99 0.10 1.00 1.00 1.00 4033.33 1.00 skill_3[7] 0.78 0.42 1.00 0.00 1.00 7574.41 1.00 skill_3[8] 1.00 0.00 1.00 1.00 1.00 nan nan skill_3[9] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_3[10] 1.00 0.00 1.00 1.00 1.00 nan nan skill_3[11] 0.00 0.04 0.00 0.00 0.00 4022.87 1.00 skill_3[12] 0.99 0.09 1.00 1.00 1.00 3973.96 1.00 skill_3[13] 0.99 0.09 1.00 1.00 1.00 3833.68 1.00 skill_3[14] 0.99 0.09 1.00 1.00 1.00 4075.71 1.00 skill_3[15] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_3[16] 0.99 0.08 1.00 1.00 1.00 4055.73 1.00 skill_3[17] 0.99 0.09 1.00 1.00 1.00 4076.82 
1.00 skill_3[18] 1.00 0.00 1.00 1.00 1.00 nan nan skill_3[19] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_3[20] 1.00 0.00 1.00 1.00 1.00 nan nan skill_3[21] 1.00 0.00 1.00 1.00 1.00 nan nan skill_4[0] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_4[1] 0.15 0.35 0.00 0.00 1.00 5860.67 1.00 skill_4[2] 0.59 0.49 1.00 0.00 1.00 21820.18 1.00 skill_4[3] 0.15 0.36 0.00 0.00 1.00 5949.44 1.00 skill_4[4] 1.00 0.00 1.00 1.00 1.00 nan nan skill_4[5] 0.00 0.03 0.00 0.00 0.00 nan 1.00 skill_4[6] 1.00 0.06 1.00 1.00 1.00 4036.51 1.00 skill_4[7] 0.15 0.36 0.00 0.00 1.00 5670.10 1.00 skill_4[8] 1.00 0.00 1.00 1.00 1.00 nan nan skill_4[9] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_4[10] 0.86 0.35 1.00 0.00 1.00 5786.31 1.00 skill_4[11] 1.00 0.07 1.00 1.00 1.00 3728.74 1.00 skill_4[12] 0.87 0.34 1.00 0.00 1.00 4455.87 1.00 skill_4[13] 0.15 0.36 0.00 0.00 1.00 5889.51 1.00 skill_4[14] 0.86 0.34 1.00 0.00 1.00 5095.11 1.00 skill_4[15] 0.86 0.35 1.00 0.00 1.00 5561.46 1.00 skill_4[16] 1.00 0.07 1.00 1.00 1.00 4041.21 1.00 skill_4[17] 1.00 0.06 1.00 1.00 1.00 4035.86 1.00 skill_4[18] 0.86 0.34 1.00 0.00 1.00 5361.01 1.00 skill_4[19] 0.99 0.08 1.00 1.00 1.00 4053.58 1.00 skill_4[20] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_4[21] 1.00 0.06 1.00 1.00 1.00 4035.14 1.00 skill_5[0] 1.00 0.00 1.00 1.00 1.00 nan nan skill_5[1] 0.00 0.00 0.00 0.00 0.00 nan nan skill_5[2] 0.78 0.42 1.00 0.00 1.00 7042.74 1.00 skill_5[3] 0.99 0.09 1.00 1.00 1.00 4074.01 1.00 skill_5[4] 0.99 0.08 1.00 1.00 1.00 3818.77 1.00 skill_5[5] 0.00 0.00 0.00 0.00 0.00 nan nan skill_5[6] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_5[7] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_5[8] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_5[9] 0.99 0.08 1.00 1.00 1.00 3627.43 1.00 skill_5[10] 0.78 0.41 1.00 0.00 1.00 7519.46 1.00 skill_5[11] 0.00 0.02 0.00 0.00 0.00 nan 1.00 skill_5[12] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_5[13] 0.78 0.42 1.00 0.00 1.00 7063.90 1.00 skill_5[14] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_5[15] 1.00 0.00 1.00 1.00 1.00 nan nan 
skill_5[16] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_5[17] 0.78 0.41 1.00 0.00 1.00 7747.84 1.00 skill_5[18] 0.78 0.42 1.00 0.00 1.00 6926.57 1.00 skill_5[19] 1.00 0.00 1.00 1.00 1.00 nan nan skill_5[20] 1.00 0.00 1.00 1.00 1.00 nan nan skill_5[21] 1.00 0.00 1.00 1.00 1.00 nan nan skill_6[0] 1.00 0.00 1.00 1.00 1.00 nan nan skill_6[1] 0.99 0.10 1.00 1.00 1.00 4020.57 1.00 skill_6[2] 0.50 0.50 0.00 0.00 1.00 -7660.98 1.00 skill_6[3] 1.00 0.00 1.00 1.00 1.00 nan nan skill_6[4] 1.00 0.00 1.00 1.00 1.00 nan nan skill_6[5] 0.50 0.50 0.00 0.00 1.00 -4056.58 1.00 skill_6[6] 1.00 0.00 1.00 1.00 1.00 nan nan skill_6[7] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_6[8] 1.00 0.00 1.00 1.00 1.00 nan nan skill_6[9] 1.00 0.00 1.00 1.00 1.00 nan nan skill_6[10] 1.00 0.00 1.00 1.00 1.00 nan nan skill_6[11] 0.09 0.29 0.00 0.00 0.00 2600.06 1.00 skill_6[12] 1.00 0.00 1.00 1.00 1.00 nan nan skill_6[13] 0.99 0.09 1.00 1.00 1.00 4078.84 1.00 skill_6[14] 0.99 0.11 1.00 1.00 1.00 4109.17 1.00 skill_6[15] 1.00 0.00 1.00 1.00 1.00 nan nan skill_6[16] 1.00 0.00 1.00 1.00 1.00 nan nan skill_6[17] 1.00 0.00 1.00 1.00 1.00 nan nan skill_6[18] 0.99 0.12 1.00 1.00 1.00 4128.28 1.00 skill_6[19] 1.00 0.00 1.00 1.00 1.00 nan nan skill_6[20] 1.00 0.00 1.00 1.00 1.00 nan nan skill_6[21] 1.00 0.02 1.00 1.00 1.00 nan 1.00
# Trace plots for a visual convergence check of model_00.
ds = az.from_numpyro(mcmc_00)
az.plot_trace(ds);
# Log joint density of model_00 evaluated at the posterior-mean parameters.
log_density_model_00, model_00_trace = log_density(
    model_00,
    (jnp.array(responses), skills_needed),
    dict(prob_mistake=0.1, prob_guess=0.2),
    {key: value.mean(0) for key, value in mcmc_00.get_samples().items()},
)
# Negative log joint density of every posterior sample.
pe_model_00 = mcmc_00.get_extra_fields()["hmc_state.potential_energy"]
# Log pointwise predictive density via the log_ppd helper above.
exp_log_density_00 = log_ppd(
    model_00, mcmc_00.get_samples(), jnp.array(responses), skills_needed
)
# Posterior mean probability of each of the 7 skills for the 22 participants.
theta_model_00 = np.zeros((22, 7))
for i, param in enumerate(["skill_" + str(s) for s in range(7)]):
    theta_model_00[:, i] = np.mean(mcmc_00.get_samples()[param], axis=0)
neg_log_proba_model_00 = neg_log_proba_score(theta_model_00, self_assessed.values)
def model_02(
    graded_responses, skills_needed: List[List[int]], prob_mistake=0.1,
):
    """Skills model that learns a per-question guess probability.

    Unlike model_00, prob_guess is not fixed but given a Beta(2.5, 7.5)
    prior (mean 0.25) and inferred separately for every question.

    :param graded_responses: (n_questions, n_participants) 0/1 array of correct answers
    :param skills_needed: per question, the indices of the skills it requires
    :param prob_mistake: probability of answering wrong despite having all skills
    """
    n_questions, n_participants = graded_responses.shape
    n_skills = max(map(max, skills_needed)) + 1
    # One learned guess probability per question.
    with numpyro.plate("questions_plate", n_questions):
        prob_guess = numpyro.sample("prob_guess", dist.Beta(2.5, 7.5))
    participants_plate = numpyro.plate("participants_plate", n_participants)
    with participants_plate:
        skills = []
        for s in range(n_skills):
            # Uninformative 50/50 prior on each participant having skill s.
            skills.append(numpyro.sample("skill_{}".format(s), dist.Bernoulli(0.5)))
    for q in range(n_questions):
        # 1 iff the participant has every skill the question needs.
        has_skills = reduce(operator.mul, [skills[i] for i in skills_needed[q]])
        prob_correct = (
            has_skills * (1 - prob_mistake) + (1 - has_skills) * prob_guess[q]
        )
        isCorrect = numpyro.sample(
            "isCorrect_{}".format(q),
            dist.Bernoulli(prob_correct).to_event(1),
            obs=graded_responses[q],
        )
# Run inference for model_02: NUTS for the continuous prob_guess sites,
# Gibbs steps (via DiscreteHMCGibbs) for the discrete skill sites.
nuts_kernel = NUTS(model_02)
kernel = DiscreteHMCGibbs(nuts_kernel, modified=True)
mcmc_02 = MCMC(kernel, num_warmup=200, num_samples=1000, num_chains=4)
mcmc_02.run(
    rng_key,
    jnp.array(responses),
    skills_needed,
    # Same extra diagnostics as for model_00.
    extra_fields=(
        "z",
        "hmc_state.potential_energy",
        "hmc_state.z",
        "rng_key",
        "hmc_state.rng_key",
    ),
)
mcmc_02.print_summary()
/home/benda/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/infer/mcmc.py:269: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(4)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`. warnings.warn( sample: 100%|██████████| 1200/1200 [02:19<00:00, 8.60it/s, 7 steps of size 4.34e-01. acc. prob=0.88] sample: 100%|██████████| 1200/1200 [02:19<00:00, 8.59it/s, 7 steps of size 4.57e-01. acc. prob=0.86] sample: 100%|██████████| 1200/1200 [02:21<00:00, 8.47it/s, 7 steps of size 4.90e-01. acc. prob=0.86] sample: 100%|██████████| 1200/1200 [02:23<00:00, 8.35it/s, 7 steps of size 4.47e-01. acc. prob=0.86]
mean std median 5.0% 95.0% n_eff r_hat prob_guess[0] 0.26 0.12 0.25 0.05 0.43 3138.97 1.00 prob_guess[1] 0.28 0.12 0.27 0.08 0.46 3689.75 1.00 prob_guess[2] 0.32 0.13 0.31 0.10 0.51 5050.20 1.00 prob_guess[3] 0.33 0.13 0.32 0.11 0.53 3007.40 1.00 prob_guess[4] 0.29 0.12 0.28 0.09 0.48 2769.02 1.00 prob_guess[5] 0.20 0.11 0.19 0.04 0.38 5336.04 1.00 prob_guess[6] 0.39 0.13 0.39 0.18 0.60 4493.49 1.00 prob_guess[7] 0.38 0.13 0.37 0.16 0.59 5009.81 1.00 prob_guess[8] 0.39 0.13 0.38 0.17 0.59 3560.31 1.00 prob_guess[9] 0.32 0.13 0.31 0.10 0.52 3133.14 1.00 prob_guess[10] 0.30 0.13 0.29 0.09 0.49 3694.41 1.00 prob_guess[11] 0.19 0.11 0.18 0.03 0.35 4524.02 1.00 prob_guess[12] 0.23 0.12 0.22 0.04 0.41 3143.92 1.00 prob_guess[13] 0.20 0.11 0.19 0.03 0.36 5236.74 1.00 prob_guess[14] 0.22 0.11 0.21 0.04 0.39 3761.43 1.00 prob_guess[15] 0.35 0.12 0.34 0.13 0.54 3402.84 1.00 prob_guess[16] 0.31 0.12 0.30 0.10 0.49 4800.58 1.00 prob_guess[17] 0.24 0.12 0.23 0.05 0.43 4647.01 1.00 prob_guess[18] 0.61 0.11 0.61 0.44 0.78 3861.77 1.00 prob_guess[19] 0.56 0.11 0.56 0.38 0.73 4028.99 1.00 prob_guess[20] 0.51 0.11 0.51 0.33 0.70 3595.62 1.00 prob_guess[21] 0.11 0.07 0.10 0.01 0.21 4919.36 1.00 prob_guess[22] 0.44 0.11 0.44 0.27 0.63 3543.74 1.00 prob_guess[23] 0.28 0.13 0.27 0.07 0.47 3130.95 1.00 prob_guess[24] 0.26 0.12 0.25 0.07 0.46 2904.73 1.00 prob_guess[25] 0.34 0.13 0.33 0.11 0.52 4288.83 1.00 prob_guess[26] 0.52 0.15 0.53 0.29 0.76 958.60 1.00 prob_guess[27] 0.52 0.14 0.54 0.29 0.76 1093.10 1.00 prob_guess[28] 0.52 0.15 0.53 0.30 0.77 1114.17 1.00 prob_guess[29] 0.20 0.10 0.18 0.04 0.34 1932.30 1.00 prob_guess[30] 0.23 0.11 0.22 0.05 0.40 3297.66 1.00 prob_guess[31] 0.31 0.13 0.31 0.07 0.51 1391.98 1.00 prob_guess[32] 0.45 0.15 0.46 0.21 0.70 1048.96 1.01 prob_guess[33] 0.47 0.15 0.48 0.21 0.71 1007.86 1.00 prob_guess[34] 0.33 0.11 0.32 0.15 0.52 3329.14 1.00 prob_guess[35] 0.41 0.12 0.41 0.23 0.61 3670.40 1.00 prob_guess[36] 0.38 0.12 0.38 0.19 0.57 4165.09 1.00 
prob_guess[37] 0.53 0.12 0.53 0.33 0.73 4196.24 1.00 prob_guess[38] 0.46 0.12 0.46 0.26 0.66 3673.99 1.00 prob_guess[39] 0.23 0.10 0.22 0.07 0.40 3564.23 1.00 prob_guess[40] 0.43 0.13 0.43 0.22 0.65 3616.51 1.00 prob_guess[41] 0.43 0.13 0.43 0.23 0.64 3123.85 1.00 prob_guess[42] 0.37 0.13 0.37 0.16 0.58 2582.66 1.00 prob_guess[43] 0.30 0.12 0.29 0.10 0.48 4110.90 1.00 prob_guess[44] 0.30 0.11 0.29 0.12 0.48 4875.23 1.00 prob_guess[45] 0.31 0.11 0.30 0.12 0.49 3179.36 1.00 prob_guess[46] 0.19 0.10 0.17 0.03 0.33 4516.82 1.00 prob_guess[47] 0.31 0.12 0.30 0.13 0.51 5223.94 1.00 skill_0[0] 1.00 0.00 1.00 1.00 1.00 nan nan skill_0[1] 1.00 0.03 1.00 1.00 1.00 nan 1.00 skill_0[2] 0.01 0.10 0.00 0.00 0.00 3876.81 1.00 skill_0[3] 1.00 0.04 1.00 1.00 1.00 nan 1.00 skill_0[4] 1.00 0.00 1.00 1.00 1.00 nan nan skill_0[5] 0.00 0.00 0.00 0.00 0.00 nan nan skill_0[6] 1.00 0.00 1.00 1.00 1.00 nan nan skill_0[7] 1.00 0.07 1.00 1.00 1.00 4041.91 1.00 skill_0[8] 0.92 0.27 1.00 1.00 1.00 3931.25 1.00 skill_0[9] 1.00 0.00 1.00 1.00 1.00 nan nan skill_0[10] 0.99 0.11 1.00 1.00 1.00 4110.15 1.00 skill_0[11] 0.89 0.31 1.00 0.00 1.00 4984.72 1.00 skill_0[12] 1.00 0.03 1.00 1.00 1.00 nan 1.00 skill_0[13] 0.28 0.45 0.00 0.00 1.00 3007.28 1.00 skill_0[14] 0.96 0.19 1.00 1.00 1.00 3995.70 1.00 skill_0[15] 1.00 0.00 1.00 1.00 1.00 nan nan skill_0[16] 1.00 0.00 1.00 1.00 1.00 nan nan skill_0[17] 1.00 0.03 1.00 1.00 1.00 nan 1.00 skill_0[18] 0.99 0.09 1.00 1.00 1.00 3971.27 1.00 skill_0[19] 1.00 0.00 1.00 1.00 1.00 nan nan skill_0[20] 1.00 0.00 1.00 1.00 1.00 nan nan skill_0[21] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_1[0] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[1] 0.98 0.15 1.00 1.00 1.00 3746.90 1.00 skill_1[2] 0.00 0.07 0.00 0.00 0.00 4040.30 1.00 skill_1[3] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_1[4] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[5] 0.00 0.06 0.00 0.00 0.00 4033.86 1.00 skill_1[6] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_1[7] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[8] 1.00 0.00 
1.00 1.00 1.00 nan nan skill_1[9] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[10] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[11] 0.81 0.39 1.00 0.00 1.00 1393.43 1.00 skill_1[12] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_1[13] 1.00 0.05 1.00 1.00 1.00 4029.22 1.00 skill_1[14] 0.71 0.45 1.00 0.00 1.00 3865.86 1.00 skill_1[15] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[16] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[17] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[18] 1.00 0.04 1.00 1.00 1.00 nan 1.00 skill_1[19] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[20] 1.00 0.00 1.00 1.00 1.00 nan nan skill_1[21] 1.00 0.07 1.00 1.00 1.00 4045.54 1.00 skill_2[0] 0.51 0.50 1.00 0.00 1.00 13562.32 1.00 skill_2[1] 0.08 0.28 0.00 0.00 0.00 4902.13 1.00 skill_2[2] 0.01 0.11 0.00 0.00 0.00 3871.70 1.00 skill_2[3] 0.12 0.32 0.00 0.00 1.00 5291.34 1.00 skill_2[4] 0.51 0.50 1.00 0.00 1.00 12431.32 1.00 skill_2[5] 0.00 0.03 0.00 0.00 0.00 nan 1.00 skill_2[6] 0.51 0.50 1.00 0.00 1.00 20899.36 1.00 skill_2[7] 0.99 0.12 1.00 1.00 1.00 4125.13 1.00 skill_2[8] 0.52 0.50 1.00 0.00 1.00 14466.44 1.00 skill_2[9] 0.02 0.14 0.00 0.00 0.00 4122.87 1.00 skill_2[10] 0.09 0.29 0.00 0.00 0.00 4396.90 1.00 skill_2[11] 0.51 0.50 1.00 0.00 1.00 18257.86 1.00 skill_2[12] 0.51 0.50 1.00 0.00 1.00 13388.88 1.00 skill_2[13] 0.89 0.32 1.00 0.00 1.00 5149.99 1.00 skill_2[14] 0.14 0.34 0.00 0.00 1.00 5471.75 1.00 skill_2[15] 0.09 0.28 0.00 0.00 0.00 4937.10 1.00 skill_2[16] 0.18 0.38 0.00 0.00 1.00 5614.03 1.00 skill_2[17] 0.51 0.50 1.00 0.00 1.00 23114.37 1.00 skill_2[18] 0.51 0.50 1.00 0.00 1.00 20113.88 1.00 skill_2[19] 0.99 0.12 1.00 1.00 1.00 4116.94 1.00 skill_2[20] 0.51 0.50 1.00 0.00 1.00 15794.23 1.00 skill_2[21] 0.10 0.30 0.00 0.00 0.00 4538.81 1.00 skill_3[0] 0.52 0.50 1.00 0.00 1.00 1627.99 1.00 skill_3[1] 0.62 0.49 1.00 0.00 1.00 2340.25 1.00 skill_3[2] 0.61 0.49 1.00 0.00 1.00 2249.26 1.00 skill_3[3] 0.73 0.44 1.00 0.00 1.00 2846.19 1.00 skill_3[4] 0.11 0.31 0.00 0.00 1.00 1697.16 1.00 skill_3[5] 0.00 0.02 0.00 0.00 
0.00 nan 1.00 skill_3[6] 0.52 0.50 1.00 0.00 1.00 1757.86 1.00 skill_3[7] 0.11 0.31 0.00 0.00 1.00 1598.81 1.00 skill_3[8] 0.94 0.24 1.00 1.00 1.00 3779.98 1.00 skill_3[9] 0.94 0.23 1.00 1.00 1.00 3426.03 1.00 skill_3[10] 0.94 0.24 1.00 1.00 1.00 3526.44 1.00 skill_3[11] 0.00 0.03 0.00 0.00 0.00 nan 1.00 skill_3[12] 0.52 0.50 1.00 0.00 1.00 1658.32 1.00 skill_3[13] 0.61 0.49 1.00 0.00 1.00 2149.84 1.00 skill_3[14] 0.52 0.50 1.00 0.00 1.00 1790.86 1.00 skill_3[15] 1.00 0.03 1.00 1.00 1.00 nan 1.00 skill_3[16] 0.52 0.50 1.00 0.00 1.00 1769.12 1.00 skill_3[17] 0.52 0.50 1.00 0.00 1.00 1683.65 1.00 skill_3[18] 0.94 0.23 1.00 1.00 1.00 3907.40 1.00 skill_3[19] 0.94 0.23 1.00 1.00 1.00 4002.44 1.00 skill_3[20] 0.95 0.23 1.00 1.00 1.00 3661.50 1.00 skill_3[21] 0.94 0.23 1.00 1.00 1.00 3765.80 1.00 skill_4[0] 0.99 0.08 1.00 1.00 1.00 4060.42 1.00 skill_4[1] 0.09 0.29 0.00 0.00 0.00 4379.23 1.00 skill_4[2] 0.49 0.50 0.00 0.00 1.00 10748.09 1.00 skill_4[3] 0.04 0.20 0.00 0.00 0.00 4252.62 1.00 skill_4[4] 0.99 0.07 1.00 1.00 1.00 4049.81 1.00 skill_4[5] 0.00 0.03 0.00 0.00 0.00 nan 1.00 skill_4[6] 0.91 0.29 1.00 1.00 1.00 4622.34 1.00 skill_4[7] 0.07 0.26 0.00 0.00 0.00 4727.83 1.00 skill_4[8] 0.99 0.07 1.00 1.00 1.00 3776.96 1.00 skill_4[9] 0.99 0.09 1.00 1.00 1.00 4067.26 1.00 skill_4[10] 0.39 0.49 0.00 0.00 1.00 8301.22 1.00 skill_4[11] 0.89 0.31 1.00 0.00 1.00 4541.33 1.00 skill_4[12] 0.38 0.49 0.00 0.00 1.00 7387.72 1.00 skill_4[13] 0.06 0.23 0.00 0.00 0.00 4535.34 1.00 skill_4[14] 0.27 0.44 0.00 0.00 1.00 5097.39 1.00 skill_4[15] 0.31 0.46 0.00 0.00 1.00 7812.16 1.00 skill_4[16] 0.87 0.34 1.00 0.00 1.00 5593.14 1.00 skill_4[17] 0.93 0.26 1.00 1.00 1.00 4655.86 1.00 skill_4[18] 0.31 0.46 0.00 0.00 1.00 7551.88 1.00 skill_4[19] 0.91 0.29 1.00 1.00 1.00 4885.18 1.00 skill_4[20] 0.99 0.08 1.00 1.00 1.00 4057.65 1.00 skill_4[21] 0.91 0.28 1.00 1.00 1.00 4967.61 1.00 skill_5[0] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_5[1] 0.00 0.00 0.00 0.00 0.00 nan nan skill_5[2] 0.29 0.45 
0.00 0.00 1.00 4111.53 1.00 skill_5[3] 0.83 0.38 1.00 0.00 1.00 5250.81 1.00 skill_5[4] 0.90 0.30 1.00 1.00 1.00 4413.22 1.00 skill_5[5] 0.00 0.00 0.00 0.00 0.00 nan nan skill_5[6] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_5[7] 0.99 0.07 1.00 1.00 1.00 4051.40 1.00 skill_5[8] 1.00 0.05 1.00 1.00 1.00 4026.94 1.00 skill_5[9] 0.83 0.37 1.00 0.00 1.00 5908.88 1.00 skill_5[10] 0.43 0.50 0.00 0.00 1.00 6477.28 1.00 skill_5[11] 0.00 0.02 0.00 0.00 0.00 nan 1.00 skill_5[12] 0.99 0.11 1.00 1.00 1.00 4104.54 1.00 skill_5[13] 0.30 0.46 0.00 0.00 1.00 4174.26 1.00 skill_5[14] 0.99 0.10 1.00 1.00 1.00 4099.45 1.00 skill_5[15] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_5[16] 0.99 0.10 1.00 1.00 1.00 4096.47 1.00 skill_5[17] 0.28 0.45 0.00 0.00 1.00 4807.04 1.00 skill_5[18] 0.29 0.45 0.00 0.00 1.00 4008.14 1.00 skill_5[19] 1.00 0.00 1.00 1.00 1.00 nan nan skill_5[20] 1.00 0.00 1.00 1.00 1.00 nan nan skill_5[21] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_6[0] 1.00 0.00 1.00 1.00 1.00 nan nan skill_6[1] 0.85 0.36 1.00 0.00 1.00 4567.60 1.00 skill_6[2] 0.50 0.50 0.00 0.00 1.00 -9642.91 1.00 skill_6[3] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_6[4] 1.00 0.00 1.00 1.00 1.00 nan nan skill_6[5] 0.50 0.50 0.00 0.00 1.00 -4032.19 1.00 skill_6[6] 1.00 0.04 1.00 1.00 1.00 4023.54 1.00 skill_6[7] 0.99 0.08 1.00 1.00 1.00 3523.17 1.00 skill_6[8] 1.00 0.00 1.00 1.00 1.00 nan nan skill_6[9] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_6[10] 1.00 0.00 1.00 1.00 1.00 nan nan skill_6[11] 0.11 0.31 0.00 0.00 1.00 1995.29 1.00 skill_6[12] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_6[13] 0.84 0.36 1.00 0.00 1.00 4683.67 1.00 skill_6[14] 0.82 0.39 1.00 0.00 1.00 3364.55 1.00 skill_6[15] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_6[16] 1.00 0.00 1.00 1.00 1.00 nan nan skill_6[17] 1.00 0.00 1.00 1.00 1.00 nan nan skill_6[18] 0.87 0.33 1.00 0.00 1.00 4571.07 1.00 skill_6[19] 1.00 0.00 1.00 1.00 1.00 nan nan skill_6[20] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill_6[21] 0.99 0.12 1.00 1.00 1.00 3642.66 1.00
# Trace plots for a visual convergence check of model_02.
ds = az.from_numpyro(mcmc_02)
az.plot_trace(ds);
# Log joint density of model_02 evaluated at the posterior-mean parameters.
log_density_model_02, model_02_trace = log_density(
    model_02,
    (jnp.array(responses), skills_needed),
    dict(prob_mistake=0.1),
    {key: value.mean(0) for key, value in mcmc_02.get_samples().items()},
)
# Negative log joint density of every posterior sample.
pe_model_02 = mcmc_02.get_extra_fields()["hmc_state.potential_energy"]
# Log pointwise predictive density via the log_ppd helper above.
exp_log_density_02 = log_ppd(
    model_02, mcmc_02.get_samples(), jnp.array(responses), skills_needed
)
# Posterior mean probability of each of the 7 skills for the 22 participants.
theta_model_02 = np.zeros((22, 7))
for i, param in enumerate(["skill_" + str(s) for s in range(7)]):
    theta_model_02[:, i] = np.mean(mcmc_02.get_samples()[param], axis=0)
neg_log_proba_model_02 = neg_log_proba_score(theta_model_02, self_assessed.values)
# Potential energy is the negative log joint density, so -pe_model_* gives
# the log joint density of each posterior sample.
print(
    "Expected log joint density of model_00: {:.2f} +/- {:.2f}".format(
        np.mean(-pe_model_00), np.std(-pe_model_00)
    )
)
print(
    "Expected log joint density of model_02: {:.2f} +/- {:.2f}".format(
        np.mean(-pe_model_02), np.std(-pe_model_02)
    )
)
Expected log joint density of model_00: -618.62 +/- 4.10 Expected log joint density of model_02: -652.77 +/- 7.04
# Figure 2.29 (left) analogue: mean negative log joint density per model,
# with the sample standard deviation as error bars.
plot_bars(
    np.array([np.mean(pe_model_00), np.mean(pe_model_02)])[None, :],
    ["Original", "Learned"],
    ["Overall"],
    tick_step=50.0,
    ylabel="negative Expected log density",
    yerr=[[np.std(pe_model_00)], [np.std(pe_model_02)]],
)
# Point estimates of the log joint density at the posterior means,
# computed earlier with numpyro's `log_density`.
print(
    "Expected log joint density of model_00 from `log_density`: {:.2f}".format(
        log_density_model_00
    )
)
print(
    "Expected log joint density of model_02 from `log_density`: {:.2f}".format(
        log_density_model_02
    )
)
Expected log joint density of model_00 from `log_density`: -569.03 Expected log joint density of model_02 from `log_density`: -501.29
# Negative log density at the posterior means.
# NOTE(review): the bar heights come from `log_density` point estimates
# while yerr reuses the std of the potential-energy samples — related but
# not identical quantities; confirm the error bars are intended here.
plot_bars(
    -np.array([log_density_model_00, log_density_model_02])[None, :],
    ["Original", "Learned"],
    ["Overall"],
    tick_step=50.0,
    ylabel="negative log density",
    yerr=[[np.std(pe_model_00)], [np.std(pe_model_02)]],
)
# Log pointwise predictive density summed over observations, per model.
pd.DataFrame(
    np.array([np.sum(exp_log_density_00), np.sum(exp_log_density_02)])[None, :],
    columns=["Original", "Learned"],
    index=["Overall"],
).plot(
    kind="bar", color=["b", "r"], ylabel="Log pointwise predictive density",
)
<AxesSubplot:ylabel='Log pointwise predictive density'>
# Overall mean negative log probability of the self-assessed ground truth
# (Figure 2.29, left).
plot_bars(
    np.array([neg_log_proba_model_00.mean(), neg_log_proba_model_02.mean()])[None, :],
    ["Original", "Learned"],
    ["Overall"],
)
# Per-skill mean negative log probability of the ground truth
# (Figure 2.29, right): averaged over participants (axis 0).
plot_bars(
    np.concatenate(
        [
            neg_log_proba_model_00.mean(0)[:, None],
            neg_log_proba_model_02.mean(0)[:, None],
        ],
        axis=1,
    ),
    ["Original", "Learned"],
    ["Core", "OOP", "Life Cycle", "Web Apps Skills", "Desktop apps", "SQL", "C#"],
)
def model_03(
    graded_responses, skills_needed: np.ndarray, prob_mistake=0.1, prob_guess=0.2
):
    """Vectorised version of model_00: one "skill" site instead of a loop.

    :param graded_responses: (n_questions, n_participants) 0/1 array of correct answers
    :param skills_needed: (n_questions, n_skills) boolean mask of required skills
    :param prob_mistake: probability of answering wrong despite having all skills
    :param prob_guess: probability of guessing right without the skills
    """
    assert graded_responses.shape[0] == skills_needed.shape[0]
    n_questions, n_participants = graded_responses.shape
    n_skills = skills_needed.shape[1]
    questions_plate = numpyro.plate("questions_plate", n_questions)
    # skills.shape == (n_participants, n_skills)
    with numpyro.plate("participants_plate", n_participants, dim=-2):
        with numpyro.plate("skills_plate", n_skills):
            skills = numpyro.sample("skill", dist.Bernoulli(0.5))
    with questions_plate:
        # shape: people x questions x skills; a skill "counts" if the
        # participant has it OR the question does not need it.
        # astype(bool) is needed for the log density
        relevant_skills = skills[:, None, :].astype(bool) | (~skills_needed)
        # shape: people x questions
        has_skill = jnp.all(relevant_skills, -1)
        prob_correct = has_skill * (1 - prob_mistake) + (1 - has_skill) * prob_guess
        # Observations are (participants, questions), hence the transpose.
        is_correct = numpyro.sample(
            "isCorrect", dist.Bernoulli(prob_correct), obs=graded_responses.T
        )
# Render the plate diagram of the vectorised model, including the
# distribution of each sample site.
numpyro.render_model(
    model_03,
    (jnp.array(responses), jnp.array(skills_key.astype(bool))),
    dict(prob_mistake=0.1),
    render_distributions=True,
)
# Run inference for the vectorised model_03; note the boolean skills mask
# is passed instead of the list-of-lists used by model_00/model_02.
nuts_kernel = NUTS(model_03)
kernel = DiscreteHMCGibbs(nuts_kernel, modified=True)
mcmc_03 = MCMC(kernel, num_warmup=200, num_samples=1000, num_chains=4)
mcmc_03.run(
    rng_key,
    jnp.array(responses),
    jnp.array(skills_key.astype(bool)),
    # Same extra diagnostics as the earlier runs.
    extra_fields=(
        "z",
        "hmc_state.potential_energy",
        "hmc_state.z",
        "rng_key",
        "hmc_state.rng_key",
    ),
)
mcmc_03.print_summary()
/home/benda/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/infer/mcmc.py:269: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(4)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`. warnings.warn( sample: 100%|██████████| 1200/1200 [00:06<00:00, 175.62it/s, 1 steps of size 1.18e+37. acc. prob=1.00] sample: 100%|██████████| 1200/1200 [00:02<00:00, 543.05it/s, 1 steps of size 1.18e+37. acc. prob=1.00] sample: 100%|██████████| 1200/1200 [00:02<00:00, 591.14it/s, 1 steps of size 1.18e+37. acc. prob=1.00] sample: 100%|██████████| 1200/1200 [00:02<00:00, 574.17it/s, 1 steps of size 1.18e+37. acc. prob=1.00]
mean std median 5.0% 95.0% n_eff r_hat skill[0,0] 1.00 0.00 1.00 1.00 1.00 nan nan skill[0,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[0,2] 0.98 0.13 1.00 1.00 1.00 4007.27 1.00 skill[0,3] 0.99 0.09 1.00 1.00 1.00 3523.86 1.00 skill[0,4] 1.00 0.00 1.00 1.00 1.00 nan nan skill[0,5] 1.00 0.00 1.00 1.00 1.00 nan nan skill[0,6] 1.00 0.00 1.00 1.00 1.00 nan nan skill[1,0] 1.00 0.00 1.00 1.00 1.00 nan nan skill[1,1] 1.00 0.05 1.00 1.00 1.00 4027.48 1.00 skill[1,2] 0.59 0.49 1.00 0.00 1.00 21255.70 1.00 skill[1,3] 0.99 0.09 1.00 1.00 1.00 3855.32 1.00 skill[1,4] 0.16 0.36 0.00 0.00 1.00 6398.46 1.00 skill[1,5] 0.00 0.00 0.00 0.00 0.00 nan nan skill[1,6] 0.99 0.10 1.00 1.00 1.00 4091.70 1.00 skill[2,0] 0.01 0.12 0.00 0.00 0.00 4119.50 1.00 skill[2,1] 0.00 0.06 0.00 0.00 0.00 4034.74 1.00 skill[2,2] 0.04 0.20 0.00 0.00 0.00 4399.36 1.00 skill[2,3] 0.99 0.09 1.00 1.00 1.00 3802.22 1.00 skill[2,4] 0.59 0.49 1.00 0.00 1.00 29538.31 1.00 skill[2,5] 0.78 0.42 1.00 0.00 1.00 7515.12 1.00 skill[2,6] 0.49 0.50 0.00 0.00 1.00 -7988.59 1.00 skill[3,0] 1.00 0.00 1.00 1.00 1.00 nan nan skill[3,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[3,2] 0.59 0.49 1.00 0.00 1.00 22782.84 1.00 skill[3,3] 0.99 0.10 1.00 1.00 1.00 3829.47 1.00 skill[3,4] 0.14 0.35 0.00 0.00 1.00 5478.77 1.00 skill[3,5] 0.99 0.09 1.00 1.00 1.00 4080.44 1.00 skill[3,6] 1.00 0.00 1.00 1.00 1.00 nan nan skill[4,0] 1.00 0.00 1.00 1.00 1.00 nan nan skill[4,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[4,2] 0.98 0.13 1.00 1.00 1.00 4160.78 1.00 skill[4,3] 0.78 0.41 1.00 0.00 1.00 6547.16 1.00 skill[4,4] 1.00 0.00 1.00 1.00 1.00 nan nan skill[4,5] 0.99 0.10 1.00 1.00 1.00 3859.70 1.00 skill[4,6] 1.00 0.00 1.00 1.00 1.00 nan nan skill[5,0] 0.00 0.00 0.00 0.00 0.00 nan nan skill[5,1] 0.00 0.06 0.00 0.00 0.00 4038.41 1.00 skill[5,2] 0.00 0.00 0.00 0.00 0.00 nan nan skill[5,3] 0.00 0.00 0.00 0.00 0.00 nan nan skill[5,4] 0.00 0.03 0.00 0.00 0.00 nan 1.00 skill[5,5] 0.00 0.00 0.00 0.00 0.00 nan nan skill[5,6] 0.50 0.50 0.00 0.00 1.00 
-4834.32 1.00 skill[6,0] 1.00 0.00 1.00 1.00 1.00 nan nan skill[6,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[6,2] 0.98 0.15 1.00 1.00 1.00 4195.02 1.00 skill[6,3] 0.99 0.08 1.00 1.00 1.00 4055.22 1.00 skill[6,4] 1.00 0.06 1.00 1.00 1.00 3579.01 1.00 skill[6,5] 1.00 0.00 1.00 1.00 1.00 nan nan skill[6,6] 1.00 0.00 1.00 1.00 1.00 nan nan skill[7,0] 1.00 0.03 1.00 1.00 1.00 nan 1.00 skill[7,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[7,2] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[7,3] 0.79 0.41 1.00 0.00 1.00 6567.00 1.00 skill[7,4] 0.15 0.36 0.00 0.00 1.00 4806.60 1.00 skill[7,5] 1.00 0.00 1.00 1.00 1.00 nan nan skill[7,6] 1.00 0.00 1.00 1.00 1.00 nan nan skill[8,0] 0.99 0.11 1.00 1.00 1.00 4117.96 1.00 skill[8,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[8,2] 0.98 0.13 1.00 1.00 1.00 3948.65 1.00 skill[8,3] 1.00 0.00 1.00 1.00 1.00 nan nan skill[8,4] 1.00 0.00 1.00 1.00 1.00 nan nan skill[8,5] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[8,6] 1.00 0.00 1.00 1.00 1.00 nan nan skill[9,0] 1.00 0.00 1.00 1.00 1.00 nan nan skill[9,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[9,2] 0.04 0.20 0.00 0.00 0.00 4393.49 1.00 skill[9,3] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[9,4] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[9,5] 0.99 0.09 1.00 1.00 1.00 3623.58 1.00 skill[9,6] 1.00 0.00 1.00 1.00 1.00 nan nan skill[10,0] 1.00 0.00 1.00 1.00 1.00 nan nan skill[10,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[10,2] 0.58 0.49 1.00 0.00 1.00 29410.68 1.00 skill[10,3] 1.00 0.00 1.00 1.00 1.00 nan nan skill[10,4] 0.87 0.34 1.00 0.00 1.00 5665.24 1.00 skill[10,5] 0.78 0.41 1.00 0.00 1.00 6876.10 1.00 skill[10,6] 1.00 0.00 1.00 1.00 1.00 nan nan skill[11,0] 0.97 0.16 1.00 1.00 1.00 4212.01 1.00 skill[11,1] 0.94 0.24 1.00 1.00 1.00 3236.26 1.00 skill[11,2] 0.98 0.13 1.00 1.00 1.00 3675.31 1.00 skill[11,3] 0.00 0.04 0.00 0.00 0.00 4023.54 1.00 skill[11,4] 0.99 0.07 1.00 1.00 1.00 4050.40 1.00 skill[11,5] 0.00 0.04 0.00 0.00 0.00 nan 1.00 skill[11,6] 0.10 0.29 0.00 0.00 0.00 2113.95 1.00 skill[12,0] 1.00 0.00 1.00 
1.00 1.00 nan nan skill[12,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[12,2] 0.98 0.14 1.00 1.00 1.00 4179.95 1.00 skill[12,3] 0.99 0.09 1.00 1.00 1.00 3746.76 1.00 skill[12,4] 0.87 0.34 1.00 0.00 1.00 5249.61 1.00 skill[12,5] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[12,6] 1.00 0.00 1.00 1.00 1.00 nan nan skill[13,0] 0.66 0.47 1.00 0.00 1.00 12244.52 1.00 skill[13,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[13,2] 0.98 0.14 1.00 1.00 1.00 4165.87 1.00 skill[13,3] 0.99 0.09 1.00 1.00 1.00 4064.02 1.00 skill[13,4] 0.15 0.36 0.00 0.00 1.00 5670.10 1.00 skill[13,5] 0.78 0.41 1.00 0.00 1.00 7449.36 1.00 skill[13,6] 0.99 0.09 1.00 1.00 1.00 3841.18 1.00 skill[14,0] 1.00 0.04 1.00 1.00 1.00 4021.12 1.00 skill[14,1] 0.96 0.19 1.00 1.00 1.00 3972.18 1.00 skill[14,2] 0.59 0.49 1.00 0.00 1.00 22337.70 1.00 skill[14,3] 0.99 0.09 1.00 1.00 1.00 4069.47 1.00 skill[14,4] 0.86 0.34 1.00 0.00 1.00 5230.53 1.00 skill[14,5] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[14,6] 0.99 0.12 1.00 1.00 1.00 3793.82 1.00 skill[15,0] 1.00 0.00 1.00 1.00 1.00 nan nan skill[15,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[15,2] 0.59 0.49 1.00 0.00 1.00 27452.34 1.00 skill[15,3] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[15,4] 0.86 0.35 1.00 0.00 1.00 5793.36 1.00 skill[15,5] 1.00 0.00 1.00 1.00 1.00 nan nan skill[15,6] 1.00 0.00 1.00 1.00 1.00 nan nan skill[16,0] 1.00 0.00 1.00 1.00 1.00 nan nan skill[16,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[16,2] 0.59 0.49 1.00 0.00 1.00 28115.25 1.00 skill[16,3] 0.99 0.09 1.00 1.00 1.00 3960.73 1.00 skill[16,4] 1.00 0.06 1.00 1.00 1.00 4035.50 1.00 skill[16,5] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[16,6] 1.00 0.00 1.00 1.00 1.00 nan nan skill[17,0] 1.00 0.00 1.00 1.00 1.00 nan nan skill[17,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[17,2] 0.98 0.12 1.00 1.00 1.00 3978.82 1.00 skill[17,3] 0.99 0.09 1.00 1.00 1.00 3907.57 1.00 skill[17,4] 0.99 0.08 1.00 1.00 1.00 3851.72 1.00 skill[17,5] 0.78 0.41 1.00 0.00 1.00 6662.39 1.00 skill[17,6] 1.00 0.00 1.00 1.00 1.00 nan nan 
skill[18,0] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[18,1] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[18,2] 0.98 0.14 1.00 1.00 1.00 4027.09 1.00 skill[18,3] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[18,4] 0.87 0.34 1.00 0.00 1.00 5572.52 1.00 skill[18,5] 0.79 0.41 1.00 0.00 1.00 6977.37 1.00 skill[18,6] 0.99 0.12 1.00 1.00 1.00 3849.26 1.00 skill[19,0] 1.00 0.00 1.00 1.00 1.00 nan nan skill[19,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[19,2] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[19,3] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[19,4] 1.00 0.05 1.00 1.00 1.00 4027.48 1.00 skill[19,5] 1.00 0.00 1.00 1.00 1.00 nan nan skill[19,6] 1.00 0.00 1.00 1.00 1.00 nan nan skill[20,0] 1.00 0.00 1.00 1.00 1.00 nan nan skill[20,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[20,2] 0.98 0.14 1.00 1.00 1.00 3947.82 1.00 skill[20,3] 1.00 0.00 1.00 1.00 1.00 nan nan skill[20,4] 1.00 0.00 1.00 1.00 1.00 nan nan skill[20,5] 1.00 0.00 1.00 1.00 1.00 nan nan skill[20,6] 1.00 0.00 1.00 1.00 1.00 nan nan skill[21,0] 1.00 0.00 1.00 1.00 1.00 nan nan skill[21,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[21,2] 0.59 0.49 1.00 0.00 1.00 25455.78 1.00 skill[21,3] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[21,4] 0.99 0.08 1.00 1.00 1.00 4057.64 1.00 skill[21,5] 1.00 0.00 1.00 1.00 1.00 nan nan skill[21,6] 1.00 0.02 1.00 1.00 1.00 nan 1.00
# Convert the model_03 run to ArviZ and eyeball the traces.
ds = az.from_numpyro(mcmc_03)
az.plot_trace(ds);
# Log joint density of model_03 evaluated at the posterior-mean value of every
# sample site — a point estimate of the joint density, not an expectation.
log_density_model_03, model_03_trace = log_density(
    model_03,
    (jnp.array(responses), jnp.array(skills_key.astype(bool))),
    {"prob_mistake": 0.1, "prob_guess": 0.2},
    {key: value.mean(0) for key, value in mcmc_03.get_samples().items()},
)
# Potential energy per posterior draw: -log joint density at that draw.
pe_model_03 = mcmc_03.get_extra_fields()["hmc_state.potential_energy"]
# Log pointwise predictive density of the observations under the posterior.
# (The manual logsumexp-over-draws computation that used to be kept here as
# commented-out code did the same thing and was removed; `log_ppd` is the
# single source of truth.)
exp_log_density_03 = log_ppd(
    model_03,
    mcmc_03.get_samples(),
    jnp.array(responses),
    jnp.array(skills_key.astype(bool)),
)
# Negative log probability of the ground-truth self assessments under the
# posterior-mean skill probabilities.
neg_log_proba_model_03 = neg_log_proba_score(
    mcmc_03.get_samples()["skill"].mean(0), self_assessed.values
)
def model_04(
    graded_responses, skills_needed: np.ndarray, prob_mistake=0.1,
):
    """Skills model with a per-question guess probability learned from data.

    :param graded_responses: (n_questions, n_participants) array of 0/1 grades.
    :param skills_needed: (n_questions, n_skills) boolean mask of the skills
        each question requires.
    :param prob_mistake: probability of answering wrongly despite having every
        required skill.
    """
    assert graded_responses.shape[0] == skills_needed.shape[0]
    n_questions, n_participants = graded_responses.shape
    n_skills = skills_needed.shape[1]
    questions_plate = numpyro.plate("questions_plate", n_questions)
    with questions_plate:
        # Unlike model_03 (where prob_guess was a fixed argument, e.g. 0.2),
        # each question gets its own latent guess probability.
        prob_guess = numpyro.sample("prob_guess", dist.Beta(2.5, 7.5))
    # skills.shape == (n_participants, n_skills)
    with numpyro.plate("participants_plate", n_participants, dim=-2):
        with numpyro.plate("skills_plate", n_skills):
            skills = numpyro.sample("skill", dist.Bernoulli(0.5))
    with questions_plate:
        # shape: people x questions x skills
        # astype(bool) is needed for the log density
        relevant_skills = skills[:, None, :].astype(bool) | (~skills_needed)
        # shape: people x questions
        has_skill = jnp.all(relevant_skills, -1)
        # Correct with prob 1 - prob_mistake when all needed skills are held,
        # otherwise with the (learned) per-question guess probability.
        prob_correct = has_skill * (1 - prob_mistake) + (1 - has_skill) * prob_guess
        is_correct = numpyro.sample(
            "isCorrect", dist.Bernoulli(prob_correct), obs=graded_responses.T
        )
# Visualize model_04's graphical structure (prob_guess now appears as a
# latent site rather than a fixed argument).
numpyro.render_model(
    model_04,
    (jnp.array(responses), jnp.array(skills_key.astype(bool))),
    dict(prob_mistake=0.1),
    render_distributions=True,
)
# Same inference setup as model_03: NUTS for the continuous prob_guess sites,
# modified Gibbs steps for the discrete skill sites.
nuts_kernel = NUTS(model_04)
kernel = DiscreteHMCGibbs(nuts_kernel, modified=True)
mcmc_04 = MCMC(kernel, num_warmup=200, num_samples=1000, num_chains=4)
mcmc_04.run(
    rng_key,
    jnp.array(responses),
    jnp.array(skills_key.astype(bool)),
    # Record potential energy and RNG state alongside the draws.
    extra_fields=(
        "z",
        "hmc_state.potential_energy",
        "hmc_state.z",
        "rng_key",
        "hmc_state.rng_key",
    ),
)
mcmc_04.print_summary()
/home/benda/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/infer/mcmc.py:269: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(4)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`. warnings.warn( sample: 100%|██████████| 1200/1200 [00:11<00:00, 104.02it/s, 7 steps of size 4.58e-01. acc. prob=0.86] sample: 100%|██████████| 1200/1200 [00:02<00:00, 476.34it/s, 7 steps of size 4.85e-01. acc. prob=0.87] sample: 100%|██████████| 1200/1200 [00:02<00:00, 468.03it/s, 15 steps of size 4.40e-01. acc. prob=0.87] sample: 100%|██████████| 1200/1200 [00:02<00:00, 485.38it/s, 7 steps of size 4.67e-01. acc. prob=0.87]
mean std median 5.0% 95.0% n_eff r_hat prob_guess[0] 0.26 0.12 0.25 0.06 0.45 3559.33 1.00 prob_guess[1] 0.28 0.12 0.26 0.08 0.47 4148.10 1.00 prob_guess[2] 0.32 0.12 0.31 0.11 0.51 5338.21 1.00 prob_guess[3] 0.33 0.13 0.32 0.12 0.54 3213.07 1.00 prob_guess[4] 0.29 0.12 0.28 0.09 0.48 3870.42 1.00 prob_guess[5] 0.20 0.11 0.19 0.03 0.36 4129.39 1.00 prob_guess[6] 0.39 0.13 0.39 0.17 0.61 2540.90 1.00 prob_guess[7] 0.38 0.13 0.37 0.17 0.60 4369.53 1.00 prob_guess[8] 0.39 0.13 0.38 0.18 0.61 4272.25 1.00 prob_guess[9] 0.32 0.13 0.31 0.10 0.53 3692.41 1.00 prob_guess[10] 0.30 0.13 0.28 0.09 0.49 3363.10 1.00 prob_guess[11] 0.19 0.10 0.18 0.04 0.35 3999.78 1.00 prob_guess[12] 0.23 0.12 0.22 0.04 0.40 3950.84 1.00 prob_guess[13] 0.20 0.11 0.19 0.03 0.37 4379.95 1.00 prob_guess[14] 0.23 0.12 0.21 0.04 0.40 5252.86 1.00 prob_guess[15] 0.35 0.12 0.34 0.15 0.55 3593.05 1.00 prob_guess[16] 0.31 0.12 0.30 0.11 0.51 4635.89 1.00 prob_guess[17] 0.24 0.12 0.23 0.06 0.44 5154.93 1.00 prob_guess[18] 0.60 0.11 0.61 0.43 0.78 4539.17 1.00 prob_guess[19] 0.56 0.11 0.57 0.38 0.74 4090.17 1.00 prob_guess[20] 0.51 0.11 0.51 0.32 0.69 3279.92 1.00 prob_guess[21] 0.11 0.07 0.10 0.01 0.21 5084.10 1.00 prob_guess[22] 0.44 0.11 0.44 0.25 0.62 3797.12 1.00 prob_guess[23] 0.28 0.12 0.27 0.10 0.48 4153.58 1.00 prob_guess[24] 0.27 0.13 0.25 0.07 0.47 3105.01 1.00 prob_guess[25] 0.33 0.13 0.33 0.14 0.54 4329.69 1.00 prob_guess[26] 0.52 0.15 0.52 0.28 0.75 966.11 1.00 prob_guess[27] 0.51 0.15 0.53 0.27 0.75 949.81 1.00 prob_guess[28] 0.51 0.15 0.52 0.26 0.75 877.79 1.00 prob_guess[29] 0.20 0.10 0.19 0.04 0.35 1588.94 1.00 prob_guess[30] 0.22 0.11 0.21 0.04 0.39 3467.77 1.00 prob_guess[31] 0.31 0.14 0.31 0.07 0.51 1180.77 1.00 prob_guess[32] 0.45 0.15 0.45 0.19 0.68 947.27 1.01 prob_guess[33] 0.46 0.15 0.47 0.18 0.69 937.23 1.00 prob_guess[34] 0.33 0.11 0.32 0.14 0.51 4058.15 1.00 prob_guess[35] 0.41 0.12 0.41 0.21 0.59 3525.86 1.00 prob_guess[36] 0.38 0.12 0.38 0.18 0.57 2941.52 1.00 prob_guess[37] 
0.53 0.12 0.53 0.35 0.72 3940.80 1.00 prob_guess[38] 0.46 0.12 0.46 0.26 0.65 3547.73 1.00 prob_guess[39] 0.23 0.10 0.21 0.06 0.39 4288.94 1.00 prob_guess[40] 0.43 0.13 0.44 0.23 0.65 3238.92 1.00 prob_guess[41] 0.43 0.13 0.43 0.22 0.64 2791.93 1.00 prob_guess[42] 0.37 0.13 0.37 0.17 0.58 2782.34 1.00 prob_guess[43] 0.30 0.12 0.29 0.10 0.49 3361.57 1.00 prob_guess[44] 0.30 0.11 0.29 0.11 0.47 4594.32 1.00 prob_guess[45] 0.31 0.12 0.30 0.12 0.50 3500.36 1.00 prob_guess[46] 0.19 0.10 0.17 0.04 0.34 4600.93 1.00 prob_guess[47] 0.31 0.12 0.30 0.11 0.49 3728.88 1.00 skill[0,0] 1.00 0.00 1.00 1.00 1.00 nan nan skill[0,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[0,2] 0.51 0.50 1.00 0.00 1.00 19355.74 1.00 skill[0,3] 0.55 0.50 1.00 0.00 1.00 1375.74 1.00 skill[0,4] 0.99 0.07 1.00 1.00 1.00 4047.97 1.00 skill[0,5] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[0,6] 1.00 0.00 1.00 1.00 1.00 nan nan skill[1,0] 1.00 0.03 1.00 1.00 1.00 nan 1.00 skill[1,1] 0.97 0.16 1.00 1.00 1.00 4164.21 1.00 skill[1,2] 0.09 0.29 0.00 0.00 0.00 4570.44 1.00 skill[1,3] 0.64 0.48 1.00 0.00 1.00 1961.81 1.00 skill[1,4] 0.10 0.30 0.00 0.00 0.00 5108.02 1.00 skill[1,5] 0.00 0.00 0.00 0.00 0.00 nan nan skill[1,6] 0.84 0.36 1.00 0.00 1.00 4275.58 1.00 skill[2,0] 0.01 0.10 0.00 0.00 0.00 4059.28 1.00 skill[2,1] 0.00 0.06 0.00 0.00 0.00 4031.35 1.00 skill[2,2] 0.01 0.10 0.00 0.00 0.00 3922.60 1.00 skill[2,3] 0.63 0.48 1.00 0.00 1.00 2274.31 1.00 skill[2,4] 0.49 0.50 0.00 0.00 1.00 9826.82 1.00 skill[2,5] 0.30 0.46 0.00 0.00 1.00 3482.30 1.00 skill[2,6] 0.49 0.50 0.00 0.00 1.00 -17494.37 1.00 skill[3,0] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[3,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[3,2] 0.12 0.33 0.00 0.00 1.00 5176.59 1.00 skill[3,3] 0.75 0.43 1.00 0.00 1.00 2179.67 1.00 skill[3,4] 0.04 0.19 0.00 0.00 0.00 4350.17 1.00 skill[3,5] 0.83 0.38 1.00 0.00 1.00 4492.92 1.00 skill[3,6] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[4,0] 1.00 0.00 1.00 1.00 1.00 nan nan skill[4,1] 1.00 0.00 1.00 1.00 1.00 nan nan 
skill[4,2] 0.51 0.50 1.00 0.00 1.00 18394.96 1.00 skill[4,3] 0.12 0.32 0.00 0.00 1.00 1499.65 1.00 skill[4,4] 0.99 0.08 1.00 1.00 1.00 3601.24 1.00 skill[4,5] 0.90 0.30 1.00 0.00 1.00 4683.15 1.00 skill[4,6] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[5,0] 0.00 0.00 0.00 0.00 0.00 nan nan skill[5,1] 0.01 0.08 0.00 0.00 0.00 4052.38 1.00 skill[5,2] 0.00 0.02 0.00 0.00 0.00 nan 1.00 skill[5,3] 0.00 0.02 0.00 0.00 0.00 nan 1.00 skill[5,4] 0.00 0.04 0.00 0.00 0.00 nan 1.00 skill[5,5] 0.00 0.00 0.00 0.00 0.00 nan nan skill[5,6] 0.50 0.50 0.00 0.00 1.00 -5206.23 1.00 skill[6,0] 1.00 0.00 1.00 1.00 1.00 nan nan skill[6,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[6,2] 0.52 0.50 1.00 0.00 1.00 12780.83 1.00 skill[6,3] 0.55 0.50 1.00 0.00 1.00 1388.51 1.00 skill[6,4] 0.91 0.28 1.00 1.00 1.00 4624.31 1.00 skill[6,5] 1.00 0.03 1.00 1.00 1.00 nan 1.00 skill[6,6] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[7,0] 0.99 0.08 1.00 1.00 1.00 4055.12 1.00 skill[7,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[7,2] 0.99 0.11 1.00 1.00 1.00 4114.24 1.00 skill[7,3] 0.11 0.31 0.00 0.00 1.00 1350.89 1.00 skill[7,4] 0.08 0.27 0.00 0.00 0.00 4814.45 1.00 skill[7,5] 0.99 0.08 1.00 1.00 1.00 4061.79 1.00 skill[7,6] 0.99 0.08 1.00 1.00 1.00 4057.65 1.00 skill[8,0] 0.92 0.28 1.00 1.00 1.00 3902.11 1.00 skill[8,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[8,2] 0.51 0.50 1.00 0.00 1.00 15878.29 1.00 skill[8,3] 0.95 0.22 1.00 1.00 1.00 3332.05 1.00 skill[8,4] 0.99 0.07 1.00 1.00 1.00 4047.05 1.00 skill[8,5] 0.99 0.07 1.00 1.00 1.00 4048.60 1.00 skill[8,6] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[9,0] 1.00 0.00 1.00 1.00 1.00 nan nan skill[9,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[9,2] 0.02 0.15 0.00 0.00 0.00 4198.74 1.00 skill[9,3] 0.95 0.22 1.00 1.00 1.00 3641.48 1.00 skill[9,4] 1.00 0.07 1.00 1.00 1.00 4038.19 1.00 skill[9,5] 0.84 0.37 1.00 0.00 1.00 6019.87 1.00 skill[9,6] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[10,0] 0.99 0.10 1.00 1.00 1.00 3910.16 1.00 skill[10,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[10,2] 
0.10 0.30 0.00 0.00 0.00 4774.40 1.00 skill[10,3] 0.95 0.23 1.00 1.00 1.00 3300.64 1.00 skill[10,4] 0.39 0.49 0.00 0.00 1.00 9775.40 1.00 skill[10,5] 0.42 0.49 0.00 0.00 1.00 5421.53 1.00 skill[10,6] 1.00 0.00 1.00 1.00 1.00 nan nan skill[11,0] 0.90 0.30 1.00 0.00 1.00 4916.07 1.00 skill[11,1] 0.80 0.40 1.00 0.00 1.00 1502.96 1.00 skill[11,2] 0.51 0.50 1.00 0.00 1.00 13021.89 1.00 skill[11,3] 0.00 0.02 0.00 0.00 0.00 nan 1.00 skill[11,4] 0.89 0.31 1.00 0.00 1.00 4794.16 1.00 skill[11,5] 0.00 0.04 0.00 0.00 0.00 nan 1.00 skill[11,6] 0.12 0.33 0.00 0.00 1.00 1809.58 1.00 skill[12,0] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[12,1] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[12,2] 0.52 0.50 1.00 0.00 1.00 15921.68 1.00 skill[12,3] 0.54 0.50 1.00 0.00 1.00 1328.37 1.00 skill[12,4] 0.38 0.49 0.00 0.00 1.00 9044.46 1.00 skill[12,5] 0.99 0.12 1.00 1.00 1.00 3963.33 1.00 skill[12,6] 1.00 0.03 1.00 1.00 1.00 nan 1.00 skill[13,0] 0.28 0.45 0.00 0.00 1.00 3375.10 1.00 skill[13,1] 1.00 0.05 1.00 1.00 1.00 4027.48 1.00 skill[13,2] 0.89 0.31 1.00 0.00 1.00 5216.81 1.00 skill[13,3] 0.64 0.48 1.00 0.00 1.00 2127.09 1.00 skill[13,4] 0.05 0.22 0.00 0.00 0.00 4485.42 1.00 skill[13,5] 0.29 0.45 0.00 0.00 1.00 4010.83 1.00 skill[13,6] 0.84 0.36 1.00 0.00 1.00 4854.66 1.00 skill[14,0] 0.97 0.18 1.00 1.00 1.00 3480.36 1.00 skill[14,1] 0.70 0.46 1.00 0.00 1.00 3942.63 1.00 skill[14,2] 0.14 0.35 0.00 0.00 1.00 4951.22 1.00 skill[14,3] 0.54 0.50 1.00 0.00 1.00 1440.23 1.00 skill[14,4] 0.26 0.44 0.00 0.00 1.00 5934.68 1.00 skill[14,5] 0.99 0.11 1.00 1.00 1.00 4111.41 1.00 skill[14,6] 0.83 0.38 1.00 0.00 1.00 3874.87 1.00 skill[15,0] 1.00 0.00 1.00 1.00 1.00 nan nan skill[15,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[15,2] 0.09 0.29 0.00 0.00 0.00 5063.69 1.00 skill[15,3] 1.00 0.04 1.00 1.00 1.00 nan 1.00 skill[15,4] 0.31 0.46 0.00 0.00 1.00 7789.10 1.00 skill[15,5] 1.00 0.00 1.00 1.00 1.00 nan nan skill[15,6] 1.00 0.00 1.00 1.00 1.00 nan nan skill[16,0] 1.00 0.00 1.00 1.00 1.00 nan nan skill[16,1] 
1.00 0.00 1.00 1.00 1.00 nan nan skill[16,2] 0.17 0.37 0.00 0.00 1.00 5913.06 1.00 skill[16,3] 0.54 0.50 1.00 0.00 1.00 1330.97 1.00 skill[16,4] 0.86 0.35 1.00 0.00 1.00 4722.25 1.00 skill[16,5] 0.99 0.11 1.00 1.00 1.00 4107.76 1.00 skill[16,6] 1.00 0.00 1.00 1.00 1.00 nan nan skill[17,0] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[17,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[17,2] 0.52 0.50 1.00 0.00 1.00 15182.21 1.00 skill[17,3] 0.54 0.50 1.00 0.00 1.00 1440.92 1.00 skill[17,4] 0.93 0.25 1.00 1.00 1.00 4656.16 1.00 skill[17,5] 0.29 0.45 0.00 0.00 1.00 4920.49 1.00 skill[17,6] 1.00 0.00 1.00 1.00 1.00 nan nan skill[18,0] 0.99 0.10 1.00 1.00 1.00 4035.30 1.00 skill[18,1] 1.00 0.04 1.00 1.00 1.00 4023.54 1.00 skill[18,2] 0.51 0.50 1.00 0.00 1.00 18334.93 1.00 skill[18,3] 0.95 0.23 1.00 1.00 1.00 3651.53 1.00 skill[18,4] 0.30 0.46 0.00 0.00 1.00 7176.80 1.00 skill[18,5] 0.29 0.45 0.00 0.00 1.00 3983.76 1.00 skill[18,6] 0.86 0.35 1.00 0.00 1.00 4525.54 1.00 skill[19,0] 1.00 0.00 1.00 1.00 1.00 nan nan skill[19,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[19,2] 0.99 0.11 1.00 1.00 1.00 4110.54 1.00 skill[19,3] 0.95 0.22 1.00 1.00 1.00 3768.78 1.00 skill[19,4] 0.91 0.28 1.00 1.00 1.00 4955.48 1.00 skill[19,5] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[19,6] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[20,0] 1.00 0.00 1.00 1.00 1.00 nan nan skill[20,1] 1.00 0.00 1.00 1.00 1.00 nan nan skill[20,2] 0.51 0.50 1.00 0.00 1.00 15673.64 1.00 skill[20,3] 0.95 0.22 1.00 1.00 1.00 3068.93 1.00 skill[20,4] 1.00 0.05 1.00 1.00 1.00 nan 1.00 skill[20,5] 1.00 0.00 1.00 1.00 1.00 nan nan skill[20,6] 1.00 0.02 1.00 1.00 1.00 nan 1.00 skill[21,0] 1.00 0.04 1.00 1.00 1.00 4017.72 1.00 skill[21,1] 1.00 0.05 1.00 1.00 1.00 4026.94 1.00 skill[21,2] 0.09 0.29 0.00 0.00 0.00 4675.55 1.00 skill[21,3] 0.95 0.21 1.00 1.00 1.00 3005.92 1.00 skill[21,4] 0.91 0.28 1.00 1.00 1.00 4942.05 1.00 skill[21,5] 1.00 0.00 1.00 1.00 1.00 nan nan skill[21,6] 0.99 0.11 1.00 1.00 1.00 3700.23 1.00
# Convert the model_04 run to ArviZ and eyeball the traces.
ds = az.from_numpyro(mcmc_04)
az.plot_trace(ds);
# Log joint density of model_04 at the posterior-mean value of every sample
# site (prob_guess is now sampled, so only prob_mistake is passed as a kwarg).
log_density_model_04, model_04_trace = log_density(
    model_04,
    (jnp.array(responses), jnp.array(skills_key.astype(bool))),
    {"prob_mistake": 0.1},
    {key: value.mean(0) for key, value in mcmc_04.get_samples().items()},
)
# Potential energy per posterior draw: -log joint density at that draw.
pe_model_04 = mcmc_04.get_extra_fields()["hmc_state.potential_energy"]
# Log pointwise predictive density of the observations under the posterior.
# (Removed the commented-out manual logsumexp computation — it duplicated
# what `log_ppd` already does.)
exp_log_density_04 = log_ppd(
    model_04,
    mcmc_04.get_samples(),
    jnp.array(responses),
    jnp.array(skills_key.astype(bool)),
)
# Negative log probability of the ground-truth self assessments under the
# posterior-mean skill probabilities.
neg_log_proba_model_04 = neg_log_proba_score(
    mcmc_04.get_samples()["skill"].mean(0), self_assessed.values
)
# Expected log joint density (mean +/- std of -potential_energy over draws).
# Fixed labels: these values come from pe_model_03 / pe_model_04 — the strings
# previously said "model_00" / "model_02", a copy-paste leftover.
print(
    "Expected log joint density of model_03: {:.2f} +/- {:.2f}".format(
        np.mean(-pe_model_03), np.std(-pe_model_03)
    )
)
print(
    "Expected log joint density of model_04: {:.2f} +/- {:.2f}".format(
        np.mean(-pe_model_04), np.std(-pe_model_04)
    )
)
Expected log joint density of model_00: -618.59 +/- 4.09 Expected log joint density of model_02: -652.80 +/- 6.83
# Figure 2.29 (left) analogue: overall negative expected log joint density for
# the original vs learned-guess model, with per-draw std as error bars.
plot_bars(
    np.array([np.mean(pe_model_03), np.mean(pe_model_04)])[None, :],
    ["Original", "Learned"],
    ["Overall"],
    tick_step=50.0,
    ylabel="negative Expected log density",
    yerr=[[np.std(pe_model_03)], [np.std(pe_model_04)]],
)
# Point-estimate log joint density at the posterior means (from `log_density`),
# as opposed to the expectation over draws printed above.
print(
    "Expected log joint density of model_03 from `log_density`: {:.2f}".format(
        log_density_model_03
    )
)
print(
    "Expected log joint density of model_04 from `log_density`: {:.2f}".format(
        log_density_model_04
    )
)
Expected log joint density of model_03 from `log_density`: -679.96 Expected log joint density of model_04 from `log_density`: -682.77
# Bar heights are the point-estimate -log_density values.
# NOTE(review): the error bars reuse the std of the per-draw potential energy
# (a different quantity than the plotted point estimates) — confirm this
# mixing of metrics is intended.
plot_bars(
    -np.array([log_density_model_03, log_density_model_04])[None, :],
    ["Original", "Learned"],
    ["Overall"],
    tick_step=50.0,
    ylabel="negative log density",
    yerr=[[np.std(pe_model_03)], [np.std(pe_model_04)]],
)
# Log pointwise predictive density summed over observations, both models.
pd.DataFrame(
    np.array([np.sum(exp_log_density_03), np.sum(exp_log_density_04)])[None, :],
    columns=["Original", "Learned"],
    index=["Overall"],
).plot(
    kind="bar", color=["b", "r"], ylabel="Log pointwise predictive density",
)
<AxesSubplot:ylabel='Log pointwise predictive density'>
# Overall negative log probability of the ground truth (Figure 2.29, left).
plot_bars(
    np.array([neg_log_proba_model_03.mean(), neg_log_proba_model_04.mean()])[None, :],
    ["Original", "Learned"],
    ["Overall"],
    ylabel="Negative Log Probability",
)
# Per-skill negative log probability (Figure 2.29, right): the improvement
# from learning guess probabilities varies by skill.
plot_bars(
    np.concatenate(
        [
            neg_log_proba_model_03.mean(0)[:, None],
            neg_log_proba_model_04.mean(0)[:, None],
        ],
        axis=1,
    ),
    ["Original", "Learned"],
    ["Core", "OOP", "Life Cycle", "Web Apps Skills", "Desktop apps", "SQL", "C#"],
    ylabel="Negative Log Probability",
)
def get_proba(
    posterior_samples: Dict[str, np.ndarray],
    params_sites: List[str],
    negative_log_proba: bool = False,
) -> np.ndarray:
    """Summarize posterior samples as per-site mean probabilities.

    :param posterior_samples: mapping from site name to sample array of shape
        ``(n_draws, n_entries)``; axis 0 (the draws) is averaged out.
    :param params_sites: names of the sites to summarize, one output row each.
    :param negative_log_proba: if True, return ``-log`` of the mean
        probabilities instead of the probabilities themselves.
    :return: array of shape ``(len(params_sites), n_entries)``.
    """
    if not params_sites:
        # Empty request: return an empty result instead of raising IndexError
        # on the params_sites[0] lookup below.
        return np.zeros((0, 0))
    proba = np.zeros((len(params_sites), posterior_samples[params_sites[0]].shape[-1]))
    for i, param in enumerate(params_sites):
        proba[i, :] = np.mean(posterior_samples[param], axis=0)
    if negative_log_proba:
        # Clamp exact zeros to machine eps so the log stays finite.
        proba[proba == 0.0] = np.finfo(float).eps
        proba = -np.log(proba)
    return proba
# Sanity check: ground-truth self assessments as integer 0/1 values.
self_assessed.astype(int).values.dtype
dtype('int64')
# Ground truth transposed is (n_skills, n_participants) = (7, 22).
self_assessed.astype(int).values.T.shape
(7, 22)
# Posterior-mean P(skill) per participant, transposed to (participants, skills).
proba = get_proba(
    mcmc_00.get_samples(),
    ["skill_" + str(i) for i in range(7)],
    negative_log_proba=False,
).T
proba.shape
(22, 7)
# Probability of the ground truth under the inferred skill posteriors for
# model_00: Bernoulli pmf gives p when truth==1 and 1-p when truth==0.
proba_00 = get_proba(
    mcmc_00.get_samples(),
    ["skill_" + str(i) for i in range(7)],
    negative_log_proba=False,
).T
rv_00 = scipy.stats.bernoulli(proba_00)
proba_model_00 = rv_00.pmf(self_assessed.astype(int).values)
# Clamp exact zeros so the log below stays finite.
proba_model_00[proba_model_00 == 0.0] = np.finfo(float).eps
neg_log_proba_model_00 = -np.log(proba_model_00)
# Overall mean negative log probability for the original model.
(-np.log(proba_model_00)).mean()
3.15628763824813
# Same ground-truth negative log probability computation for model_02
# (the model with learned guess probabilities).
proba_02 = get_proba(
    mcmc_02.get_samples(),
    ["skill_" + str(i) for i in range(7)],
    negative_log_proba=False,
).T
rv_02 = scipy.stats.bernoulli(proba_02)
proba_model_02 = rv_02.pmf(self_assessed.astype(int).values)
# Clamp exact zeros so the log below stays finite.
proba_model_02[proba_model_02 == 0.0] = np.finfo(float).eps
neg_log_proba_model_02 = -np.log(proba_model_02)
# Overall mean negative log probability for the learned model.
(-np.log(proba_model_02)).mean()
0.9652352227384892
# Overall negative log probability, original vs learned (Figure 2.29, left).
pd.DataFrame(
    np.array([neg_log_proba_model_00.mean(), neg_log_proba_model_02.mean()])[None, :],
    columns=["Original", "Learned"],
    index=["Overall"],
).plot(kind="bar", color=["b", "r"])
<AxesSubplot:>
# Per-skill negative log probability, original vs learned (Figure 2.29, right).
pd.DataFrame(
    np.concatenate(
        [
            neg_log_proba_model_00.mean(0)[:, None],
            neg_log_proba_model_02.mean(0)[:, None],
        ],
        axis=1,
    ),
    columns=["Original", "Learned"],
    index=["Core", "OOP", "Life Cycle", "Web Apps Skills", "Desktop apps", "SQL", "C#"],
).plot(kind="bar", color=["b", "r"])
<AxesSubplot:>
# Cross-check: the same ground-truth negative log probability computed
# directly with numpyro's Bernoulli.log_prob.
-dist.Bernoulli(proba_02).log_prob(self_assessed.astype(int).values)
DeviceArray([[1.19209290e-07, 1.19209290e-07, 6.72364593e-01, 6.60196126e-01, 5.03595257e+00, 5.00148453e-04, 1.19209290e-07], [7.50286621e-04, 2.22456250e-02, 2.47991896e+00, 4.84508336e-01, 9.73372310e-02, 1.17549435e-38, 1.61343142e-01], [1.03028929e-02, 4.25905688e-03, 1.23256491e-02, 4.89798278e-01, 6.70896530e-01, 1.24132860e+00, 6.84187353e-01], [1.25081057e-03, 2.50013138e-04, 1.24146998e-01, 3.12999904e-01, 4.31684963e-02, 1.75157762e+00, 8.29412174e+00], [1.19209290e-07, 1.19209290e-07, 7.23091066e-01, 1.13448694e-01, 5.24953175e+00, 2.33821225e+00, 1.19209290e-07], [1.17549435e-38, 3.25529277e-03, 7.50281382e-04, 2.50031269e-04, 1.00050040e-03, 1.17549435e-38, 6.90650225e-01], [1.19209290e-07, 2.50013138e-04, 7.12839842e-01, 7.32407868e-01, 2.41631341e+00, 7.60085535e+00, 6.21462059e+00], [4.25904663e-03, 1.19209290e-07, 4.26869774e+00, 1.15972176e-01, 7.36465454e-02, 5.20300388e+00, 5.76659571e-03], [8.31099078e-02, 1.19209290e-07, 7.24637866e-01, 2.80098796e+00, 5.26380679e-03, 2.50312779e-03, 1.19209290e-07], [1.19209290e-07, 1.19209290e-07, 2.09683049e-02, 2.88240361e+00, 4.89285326e+00, 1.80623010e-01, 2.50013138e-04], [1.25787705e-02, 1.19209290e-07, 9.59603861e-02, 2.77659655e+00, 4.88982648e-01, 5.63874841e-01, 1.19209290e-07], [2.20727491e+00, 1.67932403e+00, 7.17952311e-01, 7.50281382e-04, 1.13168687e-01, 5.00125054e-04, 1.15691476e-01], [7.50286621e-04, 5.00148453e-04, 7.16927707e-01, 7.26187050e-01, 4.78842556e-01, 4.44390345e+00, 8.29412174e+00], [1.26940060e+00, 3.00453021e-03, 2.17375231e+00, 9.49330509e-01, 5.68349361e-02, 1.21234095e+00, 1.85789919e+00], [3.90007682e-02, 1.23873687e+00, 1.48210019e-01, 7.36054778e-01, 3.14368337e-01, 4.50986242e+00, 1.70787811e+00], [1.19209290e-07, 1.19209290e-07, 8.99247080e-02, 6.90776777e+00, 1.16796243e+00, 2.50013138e-04, 5.00148453e-04], [1.19209290e-07, 1.19209290e-07, 1.94191724e-01, 7.39714861e-01, 2.03065133e+00, 4.50986242e+00, 1.19209290e-07], [7.50286621e-04, 1.19209290e-07, 7.18978047e-01, 
7.39191055e-01, 2.63805795e+00, 3.35472733e-01, 1.19209290e-07], [8.28421768e-03, 1.50113669e-03, 7.17952311e-01, 2.85163212e+00, 1.18580544e+00, 3.39677334e-01, 1.37539402e-01], [1.19209290e-07, 1.19209290e-07, 1.35919284e-02, 5.84239215e-02, 2.40517139e+00, 1.19209290e-07, 1.19209290e-07], [1.19209290e-07, 1.19209290e-07, 7.19491184e-01, 5.55127300e-02, 6.26963703e-03, 1.19209290e-07, 2.50013138e-04], [5.00148453e-04, 4.76133032e-03, 2.32278776e+00, 6.02809191e-02, 2.42758179e+00, 5.00148453e-04, 4.28671503e+00]], dtype=float32)
# NOTE(review): this overwrites the earlier arrays with -log(posterior-mean p)
# averaged per skill — it does NOT use the ground-truth indicator (log(p) vs
# log(1-p)), so it is a different metric than the pmf-based one above; confirm
# which is intended before comparing plots.
neg_log_proba_model_00 = get_proba(
    mcmc_00.get_samples(),
    ["skill_" + str(i) for i in range(7)],
    negative_log_proba=True,
).mean(1)
neg_log_proba_model_02 = get_proba(
    mcmc_02.get_samples(),
    ["skill_" + str(i) for i in range(7)],
    negative_log_proba=True,
).mean(1)
# Overall -log(mean p) with per-skill spread as error bars.
pd.DataFrame(
    np.array([neg_log_proba_model_00.mean(), neg_log_proba_model_02.mean()])[None, :],
    columns=["Original", "Learned"],
    index=["Overall"],
).plot(
    kind="bar",
    color=["b", "r"],
    yerr=[[neg_log_proba_model_00.std()], [neg_log_proba_model_02.std()]],
)
<AxesSubplot:>
# Per-skill std of the -log(mean p) entries, stacked as (2, 7) so it can be
# passed as yerr to the grouped bar plot below.
neg_log_proba_model_00_std = get_proba(
    mcmc_00.get_samples(),
    ["skill_" + str(i) for i in range(7)],
    negative_log_proba=True,
).std(1)
neg_log_proba_model_02_std = get_proba(
    mcmc_02.get_samples(),
    ["skill_" + str(i) for i in range(7)],
    negative_log_proba=True,
).std(1)
std = np.concatenate(
    [neg_log_proba_model_00_std[:, None], neg_log_proba_model_02_std[:, None]], axis=1
).T
# Per-skill -log(mean p), original vs learned, without error bars.
pd.DataFrame(
    np.concatenate(
        [neg_log_proba_model_00[:, None], neg_log_proba_model_02[:, None]], axis=1
    ),
    columns=["Original", "Learned"],
    index=["Core", "OOP", "Life Cycle", "Web Apps Skills", "Desktop apps", "SQL", "C#"],
).plot(kind="bar", color=["b", "r"])
<AxesSubplot:>
# Same per-skill plot, now with the (2, 7) std array as error bars.
pd.DataFrame(
    np.concatenate(
        [neg_log_proba_model_00[:, None], neg_log_proba_model_02[:, None]], axis=1
    ),
    columns=["Original", "Learned"],
    index=["Core", "OOP", "Life Cycle", "Web Apps Skills", "Desktop apps", "SQL", "C#"],
).plot(kind="bar", color=["b", "r"], yerr=std)
<AxesSubplot:>
# Per-skill mean of -log(posterior-mean p) for model_00.
# NOTE(review): skills whose posterior mean contains exact zeros yield
# log(0) = -inf (the RuntimeWarning in the output) — no eps clamping here,
# unlike get_proba.
[
    -np.log(mcmc_00.get_samples()[s].mean(0)).mean()
    for s in ["skill_" + str(i) for i in range(7)]
]
<ipython-input-78-9f20df540015>:2: RuntimeWarning: divide by zero encountered in log -np.log(mcmc_00.get_samples()[s].mean(0)).mean()
[inf, 0.4987566, inf, inf, 0.7172846, inf, 0.17295544]
# Posterior-mean skill probabilities for model_00, one row per skill site;
# exact zeros are clamped to machine eps before taking logs.
posterior_00 = mcmc_00.get_samples()
res_00 = np.zeros((7, 22))
for row in range(7):
    res_00[row, :] = np.mean(posterior_00["skill_" + str(row)], axis=0)
res_00[res_00 == 0.0] = np.finfo(float).eps
neg_log_proba_model_00 = (-np.log(res_00)).mean(1)
# Same per-skill posterior-mean probabilities for model_02, with the same
# eps clamping before the log.
posterior_02 = mcmc_02.get_samples()
res_02 = np.zeros((7, 22))
for row in range(7):
    res_02[row, :] = np.mean(posterior_02["skill_" + str(row)], axis=0)
res_02[res_02 == 0.0] = np.finfo(float).eps
neg_log_proba_model_02 = (-np.log(res_02)).mean(1)
# Grand means of the two per-entry -log matrices, side by side.
np.array([(-np.log(res_00)).mean(), (-np.log(res_02)).mean()])[None, :]
array([[1.56985962, 1.5088825 ]])
# The machine epsilon used above to keep log finite.
np.finfo(float).eps
2.220446049250313e-16
# Overall grand-mean -log probability, original vs learned.
pd.DataFrame(
    np.array([(-np.log(res_00)).mean(), (-np.log(res_02)).mean()])[None, :],
    columns=["Original", "Learned"],
    index=["Overall"],
).plot(kind="bar", color=["b", "r"])
<AxesSubplot:>
# Per-skill version of the eps-clamped -log probability comparison.
pd.DataFrame(
    np.concatenate(
        [neg_log_proba_model_00[:, None], neg_log_proba_model_02[:, None]], axis=1
    ),
    columns=["Original", "Learned"],
    index=["Core", "OOP", "Life Cycle", "Web Apps Skills", "Desktop apps", "SQL", "C#"],
).plot(kind="bar", color=["b", "r"])
<AxesSubplot:>
# Same unclamped probe for model_02 — zero posterior means again produce
# log(0) = -inf (see the RuntimeWarning in the output).
[
    -np.log(mcmc_02.get_samples()[s].mean(0)).mean()
    for s in ["skill_" + str(i) for i in range(7)]
]
<ipython-input-85-ceafa89bc0dc>:2: RuntimeWarning: divide by zero encountered in log -np.log(mcmc_02.get_samples()[s].mean(0)).mean()
[inf, 0.5350453, 1.688612, 1.184382, 1.1330117, inf, 0.19561504]
# NOTE(review): this variant takes -log of the grand-mean probability per
# skill (mean over draws, then over participants, then log). It ignores the
# ground-truth indicator, and by Jensen's inequality -log(mean p) differs
# from mean(-log p) — confirm which summary the figure should use.
neg_log_proba_model_00 = np.array(
    [
        -np.log(mcmc_00.get_samples()[s].mean(0).mean())
        for s in ["skill_" + str(i) for i in range(7)]
    ]
)
neg_log_proba_model_02 = np.array(
    [
        -np.log(mcmc_02.get_samples()[s].mean(0).mean())
        for s in ["skill_" + str(i) for i in range(7)]
    ]
)
# NOTE(review): exact duplicate of the previous cell — recomputes the same
# two arrays with no change; likely a re-executed notebook cell.
neg_log_proba_model_00 = np.array(
    [
        -np.log(mcmc_00.get_samples()[s].mean(0).mean())
        for s in ["skill_" + str(i) for i in range(7)]
    ]
)
neg_log_proba_model_02 = np.array(
    [
        -np.log(mcmc_02.get_samples()[s].mean(0).mean())
        for s in ["skill_" + str(i) for i in range(7)]
    ]
)
# Final per-skill comparison plot of the grand-mean variant.
# NOTE(review): this index says "Web Apps" where every earlier plot used
# "Web Apps Skills" — confirm which label is intended.
pd.DataFrame(
    np.concatenate(
        [neg_log_proba_model_00[:, None], neg_log_proba_model_02[:, None]], axis=1
    ),
    columns=["Original", "Learned"],
    index=["Core", "OOP", "Life Cycle", "Web Apps", "Desktop apps", "SQL", "C#"],
).plot(kind="bar", color=["b", "r"])
<AxesSubplot:>