In [1]:
import matplotlib.pyplot as plt
import matplotlib.mlab as mlab
import numpy as np
import scipy.stats
import seaborn as sns

SMALL_SIZE = 5.5
MEDIUM_SIZE = 7
BIGGER_SIZE = 11

plt.switch_backend('agg')
plt.rc('font', size=MEDIUM_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=MEDIUM_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)     # fontsize of the x and y labels
plt.rc('xtick', labelsize=MEDIUM_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=MEDIUM_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=MEDIUM_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title
plt.rc('axes', linewidth=0.5)            # set the value globally
plt.rc('lines', linewidth=0.7)           # line thickness
plt.rc('xtick.major', width=0.5)            # set the value globally
plt.rc('ytick.major', width=0.5)            # set the value globally
plt.rc('ytick.major', size=2)            # set the value globally
plt.rc('xtick.major', size=0)            # set the value globally
plt.rc('text', usetex=True)
plt.rc('text.latex',
       preamble=[r'\usepackage{amsmath}',
                 r'\usepackage[cm]{sfmath}'])
plt.rc('font', **{'family': 'sans-serif', 'sans-serif': ['cm']})
plt.rc('axes', titlepad=3)

% matplotlib inline
In [2]:
# Grid on which the joint density p(x, y | M) is evaluated. x and y use the
# same range, so both panels can share their axes.
x_min = -5
x_max = 5
num_points = 1000
xs = np.linspace(x_min, x_max, num_points)
ys = xs
X, Y = np.meshgrid(xs, ys)

num_models = 2
y_value = 1          # observed value of y (drawn as a horizontal line per panel)
x_plot_const = 2     # vertical scale of the 1-D densities drawn along the x-axis
y_plot_const = 3     # horizontal scale of the 1-D densities drawn along the y-axis

# Model M_i:  x ~ N(mean0, std0s[i]**2),  y | x ~ N(x, stds[i]**2).
mean0 = 0.0
std0s = np.linspace(0.5, 1, num_models)   # prior std of x, one per model
stds = np.linspace(1, 1.2, num_models)    # observation-noise std, one per model

fig, axs = plt.subplots(1, num_models, figsize=(5.8, 2.8), dpi=300, sharex=True, sharey=True)

# Stack the meshgrid into (num_points, num_points, 2) points for the 2-D pdf.
grid_points = np.dstack((X, Y))

for std0, std, ax in zip(std0s, stds, axs):
    # Joint Gaussian p(x, y | M): x ~ N(mean0, std0^2) and y = x + independent
    # noise of std `std`, so Var[x] = std0^2, Var[y] = std0^2 + std^2 and
    # Cov[x, y] = std0^2. The covariance matrix is built from VARIANCES here;
    # matplotlib.mlab.bivariate_normal (removed in matplotlib 3.1) expected
    # standard deviations for sigmax/sigmay, so passing variances to it
    # distorted the contours relative to the marginals plotted below.
    joint_cov = [[std0**2, std0**2],
                 [std0**2, std0**2 + std**2]]
    Z = scipy.stats.multivariate_normal(mean=[mean0, mean0], cov=joint_cov).pdf(grid_points)

    # Conjugate-Gaussian posterior p(x | y = y_value, M).
    var_post = (std0**(-2) + std**(-2))**(-1)
    std_post = np.sqrt(var_post)
    mean_post = var_post * (mean0 / std0**2 + y_value / std**2)

    ax.contour(X, Y, Z)
    # 1-D densities are scaled by x_plot_const / y_plot_const and shifted to
    # x_min so they sit along the bottom (x) and left (y) edges of the panel.
    # Prior p(x | M): solid curve along the x-axis.
    ax.plot(xs, scipy.stats.norm.pdf(xs, loc=mean0, scale=std0) * x_plot_const + x_min, color='black')
    # Evidence p(y | M) = N(mean0, std0^2 + std^2): every model's curve is drawn
    # on the LEFT panel only, so the two evidences can be compared side by side.
    axs[0].plot(scipy.stats.norm.pdf(ys, loc=mean0, scale=np.sqrt(std0**2 + std**2)) * y_plot_const + x_min, ys, color='black')
    # Posterior p(x | y, M): dashed curve along the x-axis.
    ax.plot(xs, scipy.stats.norm.pdf(xs, loc=mean_post, scale=std_post) * x_plot_const + x_min, color='black', linestyle='dashed')
    ax.axhline(y=y_value, color='black')

    ax.set_yticks([])
    ax.set_xticks([])
    ax.set_xlabel('$x$')

# Spines: with seaborn's defaults the left panel keeps its bottom and left
# spines; the right panel additionally drops its left spine.
sns.despine(ax=axs[0])
sns.despine(ax=axs[1], top=True, left=True, right=True)

# Curve labels, positioned in data coordinates. Raw strings keep the LaTeX
# backslashes literal ('\m' in a plain string is an invalid escape sequence
# and triggers a DeprecationWarning/SyntaxWarning); the rendered text is unchanged.
axs[0].text(-3.7, 0, r'$p(y | \mathcal M_1)$', verticalalignment='center')
axs[0].text(-4.5, -2.5, r'$p(y | \mathcal M_2)$')
axs[0].text(2.2, -4, r'$p(x | y, \mathcal M_1)$', horizontalalignment='center')
axs[0].text(-1.3, -4, r'$p(x | \mathcal M_1)$', horizontalalignment='center')

axs[1].text(2.2, -4.6, r'$p(x | y, \mathcal M_2)$')
axs[1].text(-3.5, -4.6, r'$p(x | \mathcal M_2)$')

axs[0].set_ylabel('$y$')
fig.tight_layout()
# Export both a vector (PDF) and a raster (PNG) version of the figure.
fig.savefig('occam.pdf', bbox_inches='tight')
fig.savefig('occam.png', bbox_inches='tight')