#!/usr/bin/env python
# coding: utf-8
"""Calibrate an exponential-utility risk-aversion coefficient and compare lotteries.

The coefficient ``a_star`` is the positive root of ``f``. The certainty
equivalents ``CE_b`` and ``CE_c`` of two lotteries are then computed with it
and can be compared against ``CE_a``.

Run as a script to plot ``f`` and print the results.
"""

import numpy as np
import scipy.optimize
import scipy.special


def f(a):
    """Indifference condition whose positive root is the risk-aversion coefficient.

    Under exponential utility u(x) = -exp(-a*x), f(a) == 0 exactly when a sure
    payoff of 24 has the same expected utility as a 50/50 lottery over 50 and 0:
    -exp(-24a) = 0.5*(-exp(-50a)) + 0.5*(-1).

    Note that f(0) == 1 - 0.5 - 0.5 == 0 trivially, so the root of interest
    is the positive one.

    Parameters
    ----------
    a : float or array_like
        Candidate risk-aversion coefficient(s); evaluated element-wise.

    Returns
    -------
    float or ndarray
        Value of the indifference condition at ``a``.
    """
    return np.exp(-24 * a) - 0.5 * np.exp(-50 * a) - 0.5


def plot_f(a_min=0.0, a_max=0.01, num=100):
    """Plot ``f`` on [a_min, a_max] with a dotted zero line, then show the figure.

    Parameters
    ----------
    a_min, a_max : float
        Plot range; used both for the sample grid and for the zero line.
    num : int
        Number of sample points.

    Returns
    -------
    (Figure, Axes)
        The matplotlib figure and axes that were drawn on.
    """
    # Imported lazily so the numerical part of this module can be imported
    # (e.g. for tests) without matplotlib or a display.
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(8, 5))
    grid = np.linspace(a_min, a_max, num)
    ax.plot(grid, f(grid))
    ax.hlines(0, a_min, a_max, linestyles=":")
    plt.show()
    return fig, ax


# The bracket deliberately excludes the trivial root a = 0. On it, f changes
# sign: f(0.002) > 0 > f(0.004).
a_star = scipy.optimize.brentq(f, 0.002, 0.004)

# Presumably option (a) is a sure payoff of 20. The certainty equivalent of a
# degenerate lottery is its payoff, so CE([20], [1], a_star) == 20.
CE_a = 20


def CE(x, p, a):
    """Return the certainty equivalent of a discrete lottery under exponential utility.

    For u(x) = -exp(-a*x), solving u(CE) = E[u(X)] gives
    CE = -log(sum_i p_i * exp(-a * x_i)) / a.

    Parameters
    ----------
    x : array_like
        Outcomes of the lottery.
    p : array_like
        Weights (probabilities) of the outcomes, same shape as ``x``.
        They are used as given, not renormalised.
    a : float
        Absolute risk-aversion coefficient. ``a == 0`` returns the expected
        value ``p @ x``, which is the limit of the formula as a -> 0.

    Returns
    -------
    float
        The certainty equivalent.

    Raises
    ------
    ValueError
        If ``x`` and ``p`` do not have the same shape.
    """
    x = np.asarray(x, dtype=float)
    p = np.asarray(p, dtype=float)
    if x.shape != p.shape:
        raise ValueError(
            f"x and p must have the same shape, got {x.shape} and {p.shape}"
        )
    if a == 0:
        # Risk-neutral limit; the closed form would be 0/0 here.
        return float(p @ x)
    # logsumexp with weights b=p computes log(sum(p * exp(-a*x))) stably, so
    # exp() does not overflow/underflow for large |a*x|.
    return float(-scipy.special.logsumexp(-a * x, b=p) / a)


CE_b = CE([50, -10], [0.6, 0.4], a_star)

CE_c = CE([50, 0, 10], [0.4, 0.3, 0.3], a_star)


if __name__ == "__main__":
    plot_f()
    print("a_star =", a_star)
    print("CE of a sure 20 =", CE([20], [1], a_star))
    print("CE_a =", CE_a)
    print("CE_b =", CE_b)
    print("CE_c =", CE_c)