import numpy as np
import scipy.optimize
import matplotlib.pyplot as plt
def f(a):
    """Return exp(-24 a) - 0.5 exp(-50 a) - 0.5, element-wise for array input.

    f(a) = 0 is equivalent to -log(0.5 * exp(-50 a) + 0.5) / a = 24, so a
    nonzero root is the coefficient at which a 50/0 coin flip has an
    exponential-utility certainty equivalent of 24.
    """
    return np.exp(-24 * a) - 0.5 * np.exp(-50 * a) - 0.5


# Plot f over the search window to see where it crosses zero.
a_min, a_max = 0, 0.01
fig, ax = plt.subplots(figsize=(8, 5))
# Use the same bounds for the grid and the zero line so they cannot drift apart.
a = np.linspace(a_min, a_max, 100)
ax.plot(a, f(a))
ax.hlines(0, a_min, a_max, linestyles=':')
plt.show()

# a = 0 is a trivial root (f(0) = 1 - 0.5 - 0.5), so the bracket starts above
# it; f is positive at 0.002 and negative at 0.004, as brentq requires.
a_star = scipy.optimize.brentq(f, 0.002, 0.004)
a_star  # recorded notebook output: 0.0032034184388007483

# Presumably the certainty equivalent of the sure payoff of 20 (lottery "a").
CE_a = 20
def CE(x, p, a):
    """Return the certainty equivalent -log(E[exp(-a * x)]) / a of a lottery.

    This is the certainty equivalent under exponential utility
    u(w) = -exp(-a * w), with the expectation taken as ``p @ exp(-a * x)``.

    Args:
        x: Payoffs of the lottery (array-like, converted with ``np.asarray``).
        p: Probabilities matching ``x`` (array-like); expected to sum to 1.
        a: Risk-aversion coefficient. A scalar ``a == 0`` is treated as the
            risk-neutral limit and yields the expected payoff ``p @ x``.

    Returns:
        The certainty equivalent, in the same units as ``x``.
    """
    x = np.asarray(x)
    p = np.asarray(p)
    # The general formula is 0/0 at a == 0; its limit as a -> 0 (for
    # probabilities summing to 1) is the plain expected value.
    if np.ndim(a) == 0 and a == 0:
        return p @ x
    return -np.log(p @ np.exp(-a * x)) / a
# Sanity check: a sure payoff of 20 has a certainty equivalent of 20
# (up to floating-point rounding).
CE([20], [1], a_star)  # recorded notebook output: 19.999999999999982

# Lottery b: 50 with probability 0.6, -10 with probability 0.4.
CE_b = CE([50, -10], [0.6, 0.4], a_star)
CE_b  # recorded notebook output: 24.600322886059406

# Lottery c: 50 with probability 0.4, 0 with probability 0.3, 10 with probability 0.3.
CE_c = CE([50, 0, 10], [0.4, 0.3, 0.3], a_star)
CE_c  # recorded notebook output: 22.204239619220147