In [1]:
import numpy as np
from scipy.stats import invgamma
from matplotlib import pyplot as plt
import seaborn as sns
%matplotlib inline

ベイズ線形回帰

データ生成

$$ \begin{align} y &= wx + b + \varepsilon \\ &= \bf{w}^\rm{T} \bf{x} + \varepsilon \\ \varepsilon &\sim \mathcal N(0, \sigma_n^2) \end{align} $$ 重みは$w = 2.0$,切片は$b = 1.0$,ノイズ標準偏差は$\sigma_n = 2.0$,データ数$N=100$とする.

In [2]:
def gen_data1d(N=100, w=2.0, b=1.0, sigma_n=2.0):
    np.random.seed(1)
    x_ = np.linspace(-5, 5,N)
    x = np.c_[np.ones(N), x_]
    w = np.array([b, w])[:, np.newaxis]
    eps = np.random.normal(0, sigma_n**2, N)[:, np.newaxis]
    y = np.matmul(x, w) + eps
    return x_[:, np.newaxis], y

x, y = gen_data1d()
true_y = 2.0 * x + 1.0

fontsize = 20
plt.figure(figsize=(8, 8))
plt.plot(x, true_y, color='black', lw=1, label='true model')
plt.scatter(x, y, color='b', s=100, label='data')
plt.xlim(-6, 6)
plt.ylim(-13, 13)
plt.xticks(fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.xlabel('x', fontsize=fontsize)
plt.ylabel('y', fontsize=fontsize)
plt.legend(fontsize=fontsize)
Out[2]:
<matplotlib.legend.Legend at 0x113c84470>

予測モデル

$$ \begin{align} y &= \bf{w}^\rm{T} \bf{x} + \varepsilon \\ \varepsilon &\sim \mathcal N(0, \sigma_n^2) \\ \sigma_n^2 &\sim \mathcal {IG}(\frac{\alpha}{2}, \frac{\beta}{2}) \\ \bf{w} &\sim \mathcal N({\bf 0}, {\bf \Sigma_w}) \\ \end{align} $$ と仮定すると,事前分布は $$ \begin{align} p(\sigma_n^2) &= \mathcal {IG}(\frac{\alpha}{2}, \frac{\beta}{2}) \\ &= \frac{({\frac{\beta}{2})}^{\frac{\alpha}{2}}}{\Gamma (\frac{\alpha}{2})} (\sigma_n^2)^{-\frac{\alpha}{2}-1} \exp \left( -\frac{\frac{\beta}{2}}{\sigma_n^2} \right) ,\\ p({\bf w}) &= \mathcal N(\bf{0}, \Sigma_w) \\ &= \left(\frac{1}{2 \pi}\right)^{\frac{D}{2}} |{\bf \Sigma_w}|^{-\frac{1}{2}} \exp \left( -\frac{1}{2} {\bf w}^{\rm T} {\bf \Sigma_w}^{-1} {\bf w} \right). \end{align} $$ ただし,$D$は$\bf w$の次元で$D=2$.

尤度関数は $$ \begin{align} p(y | {\bf x}, {\bf w}, \sigma_n^2) &= \prod_{i=1}^N \mathcal N({\bf w}^{\rm T} {\bf x}_i, \sigma_n^2) \\ &= \left(\frac{1}{2 \pi}\right)^{\frac{N}{2}} (\sigma_n^2)^{-\frac{N}{2}} \exp \left\{ -\frac{1}{2 \sigma_n^2} \sum_{i=1}^N (y_i - {\bf w}^{\rm T} {\bf x}_i)^2 \right\}. \end{align} $$ 事後分布は, $$ \begin{align} p({\bf w}, \sigma_n^2 | y, {\bf x}) &= \frac{p(y | {\bf x}, {\bf w}, \sigma_n^2) p({\bf w}, \sigma_n^2)}{p(y | {\bf x})} \\ &\propto p(y | {\bf x}, {\bf w}, \sigma_n^2) p({\bf w}, \sigma_n^2) \\ &= p(y | {\bf x}, {\bf w}, \sigma_n^2) p({\bf w}) p(\sigma_n^2) \\ &= \left(\frac{1}{2 \pi}\right)^{\frac{N}{2}} (\sigma_n^2)^{-\frac{N}{2}} \exp \left\{ -\frac{1}{2 \sigma_n^2} \sum_{i=1}^N (y_i - {\bf w}^{\rm T} {\bf x}_i)^2 \right\} \left(\frac{1}{2 \pi}\right)^{\frac{D}{2}} |{\bf \Sigma_w}|^{-\frac{1}{2}} \exp \left( -\frac{1}{2} {\bf w}^{\rm T} {\bf \Sigma_w}^{-1} {\bf w} \right) \frac{({\frac{\beta}{2})}^{\frac{\alpha}{2}}}{\Gamma (\frac{\alpha}{2})} (\sigma_n^2)^{-\frac{\alpha}{2}-1} \exp \left( -\frac{\frac{\beta}{2}}{\sigma_n^2} \right) \\ &\propto (\sigma_n^2)^{-\frac{N}{2}} \exp \left\{ -\frac{1}{2 \sigma_n^2} \sum_{i=1}^N (y_i - {\bf w}^{\rm T} {\bf x}_i)^2 \right\} \exp \left( -\frac{1}{2} {\bf w}^{\rm T} {\bf \Sigma_w}^{-1} {\bf w} \right) (\sigma_n^2)^{-\frac{\alpha}{2}-1} \exp \left( -\frac{\frac{\beta}{2}}{\sigma_n^2} \right) \end{align} $$ となる.

ギブス・サンプリング

事後分布の積分が解析的に行えないので,MCMCによるサンプリングで事後分布に従う乱数を生成する.

$\bf w$の条件付き分布は, $$ \begin{align} p({\bf w} | y, {\bf x}, \sigma_n^2) &= \frac{p({\bf w}, \sigma_n^2 | y, {\bf x})}{ p(\sigma_n^2)} \\ &= \frac{p(y | {\bf x}, {\bf w}, \sigma_n^2) p({\bf w}) p(\sigma_n^2)}{p(y | {\bf x}) p(\sigma_n^2)} \\ &\propto p(y | {\bf x}, {\bf w}, \sigma_n^2) p({\bf w}) \\ &= \left(\frac{1}{2 \pi}\right)^{\frac{N}{2}} (\sigma_n^2)^{-\frac{N}{2}} \exp \left\{ -\frac{1}{2 \sigma_n^2} \sum_{i=1}^N (y_i - {\bf w}^{\rm T} {\bf x}_i)^2 \right\} \left(\frac{1}{2 \pi}\right)^{\frac{D}{2}} |{\bf \Sigma_w}|^{-\frac{1}{2}} \exp \left( -\frac{1}{2} {\bf w}^{\rm T} {\bf \Sigma_w}^{-1} {\bf w} \right) \\ &\propto \exp \left\{ -\frac{1}{2} ({\bf y} - {\bf Xw})^{\rm T} (\sigma_n^2 {\bf I})^{-1} ({\bf y} - {\bf Xw}) \right\} \exp \left( -\frac{1}{2} {\bf w}^{\rm T} {\bf \Sigma_w}^{-1} {\bf w} \right) \\ &\propto \exp \left\{ -\frac{1}{2} ({\bf w} - \bar{\bf w})^{\rm T} {\bf A} ({\bf w} - \bar{\bf w}) \right\} \\ &\sim {\mathcal N}(\bar{\bf w}, {\bf A}^{-1}). \end{align} $$ ただし,${\bf A} = \frac{1}{\sigma_n^2} {\bf X^{\rm T}X} + \Sigma_{\bf w}^{-1}, \;\; \bar{\bf w} = \frac{1}{\sigma_n^2} {\bf A}^{-1} {\bf X^{\rm T}y}$.

$\sigma_n^2$の条件付き分布は, $$ \begin{align} p(\sigma_n^2 | y, {\bf x}, {\bf w}) &= \frac{p({\bf w}, \sigma_n^2 | y, {\bf x})}{ p({\bf w}) } \\ &= \frac{p(y | {\bf x}, {\bf w}, \sigma_n^2) p({\bf w}) p(\sigma_n^2)}{p(y | {\bf x}) p({\bf w})} \\ &\propto p(y | {\bf x}, {\bf w}, \sigma_n^2) p(\sigma_n^2) \\ &= \left(\frac{1}{2 \pi}\right)^{\frac{N}{2}} (\sigma_n^2)^{-\frac{N}{2}} \exp \left\{ -\frac{1}{2 \sigma_n^2} \sum_{i=1}^N (y_i - {\bf w}^{\rm T} {\bf x}_i)^2 \right\} \frac{({\frac{\beta}{2})}^{\frac{\alpha}{2}}}{\Gamma (\frac{\alpha}{2})} (\sigma_n^2)^{-\frac{\alpha}{2}-1} \exp \left( -\frac{\frac{\beta}{2}}{\sigma_n^2} \right) \\ &\propto (\sigma_n^2)^{-\frac{N}{2}} \exp \left\{ -\frac{1}{2 \sigma_n^2} ({\bf y} - {\bf Xw})^{\rm T}({\bf y} - {\bf Xw}) \right\} (\sigma_n^2)^{-\frac{\alpha}{2}-1} \exp \left( -\frac{\frac{\beta}{2}}{\sigma_n^2} \right) \\ &= (\sigma_n^2)^{-\frac{N + \alpha}{2} -1} \exp \left\{ -\frac{1}{\sigma_n^2} \frac{({\bf y} - {\bf Xw})^{\rm T}({\bf y} - {\bf Xw}) + \beta}{2} \right\} \\ &\sim \mathcal {IG}\left(\frac{N + \alpha}{2}, \frac{({\bf y} - {\bf Xw})^{\rm T}({\bf y} - {\bf Xw}) + \beta}{2} \right). \end{align} $$

ギブス・サンプリングでは,$\sigma_n^2$を固定したと考えて$\bf w$の条件付き分布$p({\bf w} | y, {\bf x}, \sigma_n^2)$からのサンプリングと,$\bf w$を固定したと考えて$\sigma_n^2$の条件付き分布$p(\sigma_n^2 | y, {\bf x}, {\bf w})$からのサンプリングを交互に繰り返す.

  1. 初期値${\bf w}^{(0)}, {\sigma_n^2}^{(0)}$を決める.
  2. $t = 0, 1, 2, ..., $に対して,

    (i) ${\bf w}^{(t+1)}$を$p({\bf w} | y, {\bf x}, {\sigma_n^2}^{(t)})$からサンプリング.

    (ii) ${{\sigma_n^2}^{(t+1)}}$を$p({\sigma_n^2} | y, {\bf x}, {\bf w}^{(t)})$からサンプリング.

In [3]:
def GibbsSampling(X, y, N_sample=3000, iter_num=100, burn_in=0.3):
    np.random.seed(0)
    N = X.shape[0]
    D = X.shape[1]
    alpha = 2.
    beta = 2.
    sign2_0 = 10.0
    w_0 = np.zeros((D, 1))
    
    sign2_t = sign2_0
    w_t = w_0
    w = [w_t]
    sign2 = [sign2_t]
    for t in range(N_sample):
        A = np.matmul(X.T, X) / sign2_t + np.eye(D)
        A_inv = np.linalg.inv(A)
        w_bar = np.matmul(np.matmul(A_inv, X.T), y)[:, 0] / sign2_t
        w_t = np.random.multivariate_normal(w_bar, A_inv)[:,np.newaxis]
        sign2_t= invgamma.rvs(a=(N + alpha) / 2., scale=(np.matmul((y - np.matmul(X, w_t)).T, (y - np.matmul(X, w_t)))) + beta / 2.)
        w.append(w_t)
        sign2.append(sign2_t)

    return np.array(w).reshape((N_sample + 1, D))[int(N_sample * burn_in):], np.array(sign2)[int(N_sample * burn_in):]

N = 100
X = np.c_[np.ones(N), x]
w, sign2 = GibbsSampling(X, y)

サンプリング結果

In [4]:
bins = 50
plt.hist(w[:, 0], bins=bins)
plt.title('$w_0$')

freq_w0, val_w0 = np.histogram(w[:, 0], bins=bins)
print(val_w0[np.argmax(freq_w0)])
N_sample = len(w[:, 0])
bottom = np.sort(w[:, 0])[int(N_sample * 0.025)]
top = np.sort(w[:, 0])[int(N_sample * 0.975)]
w0_ci = np.array([bottom, val_w0[np.argmax(freq_w0)], top])
print(bottom, top)
0.833603249623
0.105373189652 1.86834810922
In [5]:
plt.hist(w[:, 1], bins=bins)
plt.title('$w_1$')

freq_w1, val_w1 = np.histogram(w[:, 1], bins=bins)
print(val_w1[np.argmax(freq_w1)])
N_sample = len(w[:, 1])
bottom = np.sort(w[:, 1])[int(N_sample * 0.025)]
top = np.sort(w[:, 1])[int(N_sample * 0.975)]
w1_ci = np.array([bottom, val_w1[np.argmax(freq_w1)], top])
print(bottom, top)
2.02445411147
1.7301797938 2.40666919289
In [6]:
plt.hist(sign2, bins=bins)
plt.title('$\sigma_n^2$')
Out[6]:
<matplotlib.text.Text at 0x114980c50>

予測分布

In [7]:
X = np.c_[np.ones(N), x]    # (100, 2)
W = np.c_[w0_ci, w1_ci]    # (3, 2)
pred_y = np.matmul(W, X.T)

fontsize = 20
plt.figure(figsize=(8, 8))
plt.plot(x, true_y, color='black', lw=1, label='true model')
plt.scatter(x, y, color='b', s=100, label='data')

plt.fill_between(x[:, 0], pred_y[0], pred_y[2], alpha=0.3, label='95% Bayesian CI')
plt.plot(x, pred_y[1], label='prediction', lw=3)

plt.xlim(-6, 6)
plt.ylim(-13, 13)
plt.xticks(fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.xlabel('x', fontsize=fontsize)
plt.ylabel('y', fontsize=fontsize)
plt.legend(fontsize=fontsize)
Out[7]:
<matplotlib.legend.Legend at 0x1142e2cf8>