# Plotting window: the same interval [x_min, x_max] is used for both axes.
x_min, x_max = -5, 5
num_points = 1000
# 1-D sample grid; the y axis reuses the x grid.
xs = np.linspace(start=x_min, stop=x_max, num=num_points)
ys = xs
# 2-D evaluation grid for the joint density: X varies along columns, Y along rows.
X, Y = np.meshgrid(xs, ys, indexing='xy')
# Number of competing models (one subplot each).
num_models = 2
# Observed value of y that the posterior over x is conditioned on.
y_value = 1
# Scale factors used to draw 1-D densities inside the plotting window:
# x_plot_const for curves over x, y_plot_const for the sideways curve over y.
x_plot_const, y_plot_const = 2, 3
# Prior mean of x, shared by all models.
mean0 = 0.0
# Per-model standard deviations: std0s for the prior on x, stds for y given x.
std0s = np.linspace(0.5, 1, num=num_models)
stds = np.linspace(1, 1.2, num=num_models)
# One panel per model; shared axes so the models can be compared directly.
fig, axs = plt.subplots(1, num_models, figsize=(5.8, 2.8), dpi=300, sharex=True, sharey=True)
# Stack the (X, Y) grid into points with a trailing axis of length 2 so the
# 2-D pdf can be evaluated over the whole grid in one call.
grid_points = np.dstack((X, Y))
for std0, std, ax in zip(std0s, stds, axs):
    # Each model: x ~ N(mean0, std0**2) and y | x ~ N(x, std**2). The joint
    # (x, y) is Gaussian with Var[x] = std0**2, Var[y] = std0**2 + std**2 and
    # Cov[x, y] = std0**2.
    # The covariance matrix is built explicitly. The previous
    # mlab.bivariate_normal call passed variances where standard deviations
    # were expected (so the contours were wrong), and that helper was removed
    # from Matplotlib 3.1.
    var_x = std0**2
    var_y = std0**2 + std**2
    joint = scipy.stats.multivariate_normal(mean=[mean0, mean0], cov=[[var_x, var_x], [var_x, var_y]])
    Z = joint.pdf(grid_points)
    # Conjugate Gaussian posterior over x after observing y = y_value.
    var_post = (std0**(-2) + std**(-2))**(-1)
    std_post = np.sqrt(var_post)
    mean_post = var_post * (mean0 / std0**2 + y_value / std**2)
    ax.contour(X, Y, Z)
    # Prior p(x | M), scaled by x_plot_const and offset by x_min so its
    # baseline sits at the bottom of the plotting window.
    ax.plot(xs, scipy.stats.norm.pdf(xs, loc=mean0, scale=std0) * x_plot_const + x_min, color='black')
    # Marginal p(y | M) = N(mean0, std0**2 + std**2), drawn sideways from the
    # left edge. Every model's curve goes on the first panel.
    axs[0].plot(scipy.stats.norm.pdf(ys, loc=mean0, scale=np.sqrt(var_y)) * y_plot_const + x_min, ys, color='black')
    # Posterior p(x | y, M), dashed, on the same baseline as the prior.
    ax.plot(xs, scipy.stats.norm.pdf(xs, loc=mean_post, scale=std_post) * x_plot_const + x_min, color='black', linestyle='dashed')
    # Horizontal line marking the observed y.
    ax.axhline(y=y_value, color='black')
    ax.set_yticks([])
    ax.set_xticks([])
    ax.set_xlabel('$x$')
# Remove the top/right spines of the first panel (seaborn defaults); the
# second panel keeps only its bottom spine.
sns.despine(ax=axs[0])
sns.despine(ax=axs[1], top=True, left=True, right=True)
# Label the drawn densities at hand-tuned data coordinates. Raw strings keep
# the LaTeX backslashes (\mathcal) from being read as invalid escape
# sequences, which warn on every run (SyntaxWarning from Python 3.12).
axs[0].text(-3.7, 0, r'$p(y | \mathcal M_1)$', verticalalignment='center')
axs[0].text(-4.5, -2.5, r'$p(y | \mathcal M_2)$')
axs[0].text(2.2, -4, r'$p(x | y, \mathcal M_1)$', horizontalalignment='center')
axs[0].text(-1.3, -4, r'$p(x | \mathcal M_1)$', horizontalalignment='center')
axs[1].text(2.2, -4.6, r'$p(x | y, \mathcal M_2)$')
axs[1].text(-3.5, -4.6, r'$p(x | \mathcal M_2)$')
axs[0].set_ylabel('$y$')
fig.tight_layout()
# Write the same figure as vector (PDF) and raster (PNG) output.
fig.savefig('occam.pdf', bbox_inches='tight')
fig.savefig('occam.png', bbox_inches='tight')