In [1]:
#from IPython.core.display import HTML
#import urllib
#response = urllib.urlopen("""
#    https://gist.githubusercontent.com/sgttwld/c060b18a9d6ce7c3a10e3c6dce2c0d3a/raw
#""")
#css = str(response.read().decode("utf-8"))
#HTML("<style type='text/css'>"+css+"</style>")
In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats
sns.set()
%config InlineBackend.figure_format = 'retina'
import _lib.pr_func as pr
import _lib.BA_class as BA


def get_ll(px,D):
    ll = 0
    for d in D:
        ll += np.log(px.val[d])
    return ll

def validate_ll(pmu,D,best_ll):
    px = pr.sum(['mu'],pmu*px_mu)
    ll = get_ll(px,D)
    return best_ll - ll

def validate_px(pmu,px_real):
    px = pr.sum(['mu'],pmu*px_mu)
    return np.linalg.norm(px.val-px_real.val)

Bayesian inference for mixture distributions

Sebastian Gottwald, 2018

Problem statement

Let $X$ be an observable that is distributed according to a mixture of components with known distributions $p(x|\mu)$. The goal is to find the parameters $\mu$ from which the data was generated.

The mixture can be written as $$ p(x) = \sum_\mu \, p(\mu) \, p(x|\mu) $$ and the goal would be to find $p(\mu)$. However, if the number of mixture components $M$ is known we can split up $p(\mu)$ into $\{p(\mu|z)\}_{z=1}^M$:

\begin{equation}\tag{$\ast$} p(x) = \sum_{z=1}^M p(z) \sum_\mu p(\mu|z) \, p(x|\mu) \end{equation}

For simplicity, we assume equal mixture coefficients, i.e. $\rho_z:= p(z)=1/M$ for all $z$.

Examplary data

Let $M=4$, $p(x|\mu) = \mathcal{N}(x\,|\,\mu,\sigma^2)$ with $\sigma^2=25$, and $p_{real}(\mu) = \frac{1}{4}(\delta_{\mu,10} + \delta_{\mu,20} + \delta_{\mu,40} + \delta_{\mu,50})$

In [229]:
N,M,K = 60,4,60

sigma = 25.0
datasize = 1000

pr.set_dims([('x',N),('z',M),('mu',K)])
rho = np.ones(M)/M
pz = pr.func('f(z)',val=rho)

def f(K,mean,sigma):
    k = np.arange(0,K)
    return np.exp(-(k-mean)**2/sigma)

def g(N,K,sigma,offset):
    llh = np.zeros([N,K])
    for i in range(0,N):
        llh[i,:] = f(K,i-offset,sigma)
    return llh

def sample(px_mat):
    return np.random.choice(range(0,len(px_mat)),p=px_mat)

# likelihood
px_mu = pr.func('f(x,mu)', val=g(N,K,sigma=sigma,offset=0.0)).normalize(['x'])

# real p(mu)
if M == 2:
    pmu_real = pr.func('f(mu)', val=f(K,mean=10,sigma=.1)+f(K,mean=30,sigma=.1)).normalize()
elif M == 3:
    pmu_real = pr.func('f(mu)', val=f(K,mean=10,sigma=.1)+f(K,mean=30,sigma=.1)+f(K,mean=40,sigma=.1)).normalize()
elif M==4:
    means = [10,20,40,50]
    mat = 0
    for mean in means:
        mat += f(K,mean=mean,sigma=.1)
    pmu_real = pr.func('f(mu)', val=mat).normalize()
elif M==6:
    means = [15,25,35,65,75,85]
    mat = 0
    for mean in means:
        mat += f(K,mean=mean,sigma=.1)
    pmu_real = pr.func('f(mu)', val=mat).normalize()

px_real = pr.sum(['mu'],px_mu*pmu_real)

# create data
D = [sample(px_mu.val[:,sample(pmu_real.val)]) for d in range(0,datasize)]

best_ll = get_ll(px_real,D)

freq = np.zeros(N)
for d in D:
    for i in range(0,N):
        if d == i: freq[i] += 1

PY = freq/np.sum(freq)

with plt.rc_context({"figure.figsize": (18,5)}):
    fig,ax = plt.subplots(1,3)
    ax[0].set_title('$p(x|\mu)$')
    ax[0].pcolor(px_mu.val)
    ax[0].set_xlabel('$\mu$')
    ax[0].set_ylabel('$x$')
    ax[1].set_title('$p_{real}(\mu)$')
    ax[1].bar(range(K),pmu_real.val)
    ax[1].set_xlabel('$\mu$')
    ax[2].set_title('frequencies of $d=x$ in $D$, $|D|={}$'.format(datasize))
    ax[2].bar(range(0,len(PY)),PY)
    ax[2].set_xlabel('$x$')
    plt.show()

1. Approximating $p(\mu)=\sum_x p(\mu|x) p(x)$

Here, we generate $p(\mu)$ from $\{p(\mu|d)\}_{d\in D}$ iteratively, as follows:

  1. initialize $t=0$ and $p^{(0)}(\mu) = 1/K$ for all $\mu$ (uniform)
  2. For each $d\in D$ calculate $p^{(t)}(\mu|d) = \frac{p(d|\mu)\, p^{(t)}(\mu)}{\sum_{\mu} p(d|\mu) \, p^{(t)}(\mu)}$
  3. Effectively approximate $p(\mu) = \sum_{x} p(\mu|x) p(x)$ by $\sum_{x} p(\mu|x)\, f_D(x)$, where $f_D(x)$ is the frequency of $x$ in $D$: $$p^{(t+1)}(\mu) = \frac{1}{|D|} \sum_{d\in D} p^{(t)}(\mu|d)$$
  4. Set $t=t+1$ and continue with 2.
In [230]:
datasize = 1000
batchsize = 100
epochs = 10

# init
pmu = pr.func('f(mu)', val='unif').normalize()
F_ll,F_px = [],[]

# inference to find marginal
for i in range(0,epochs):
    for n in range(int(datapoints/batchsize)):
        D0 = D[n*batchsize:(n+1)*batchsize]
        pmu_ds = [(px_mu.eval('x',d)*pmu).normalize().val for d in D0]
        pmu_marg = 0
        for pmu_d in pmu_ds:
            pmu_marg += pmu_d
        pmu.val = pmu_marg/len(pmu_ds)
        F_ll.append(validate_ll(pmu,D,best_ll))
        F_px.append(validate_px(pmu,px_real))


with plt.rc_context({"figure.figsize": (18,4)}):
    fig,ax = plt.subplots(1,3)
    fig.suptitle('datasize: {}, batchsize: {}, epochs: {}'.format(datapoints,batchsize,epochs))
    ax[0].bar(range(0,len(pmu.val)),pmu.val)
    ax[0].set_xlabel('$\mu$')
    ax[0].set_title('$p(\mu)$')
    ax[1].scatter(range(len(F_ll)),F_ll,marker='.')
    #ax[1].set_ylim([0,1000])
    ax[1].set_title('$LL_D[p_{real}]-LL_D[p]$')
    ax[2].scatter(range(len(F_px)),F_px,marker='.')
    #ax[2].set_ylim([0,.05])
    ax[2].set_title('$||p(X)-p_{real}(X)||$')
    fig.subplots_adjust(top=0.82)
    plt.show()

fitness_ll = validate_ll(pmu,D,best_ll)
fitness_px = validate_px(pmu,px_real)

print fitness_ll,fitness_px
9.98677666392 0.0160731370158

2. EM algorithm

Here, we assume $p(\mu|z) = \delta_{\mu_z,\mu}$, where $\{\mu_z\}_{z=1}^M$ are the unknown parameters we want to find, so that $(\ast)$ becomes $$ p(x) = \sum_{z=1}^M\rho_z \sum_\mu \delta_{\mu_z,\mu}\, p(x|\mu) = \sum_{z=1}^M \rho_z \, p(x|\mu_z) $$

The EM algorithm for Gaussian mixtures is as follows:

  1. Initialize $t=0$ and $\{\mu_z^{(0)}\}_{z=1}^M$ randomly

  2. (E-step) For each $d\in D$ calculate $$ a_z(d) := p(z|d,\mu_1^{(t)},\dots,\mu_M^{(t)}) = \frac{\rho_z\, p(d|\mu^{(t)}_z)}{\sum_{j=1}^M \rho_j \, p(d|\mu^{(t)}_j)} \quad \forall z\in \{1,\dots, M\} $$

  3. (M-step) Determine $\{\mu^{(t+1)}_z\}_{z=1}^M$ from $\{a_z\}_{z=1}^M$ as follows: $$ \mu_z^{(t+1)} = \frac{\sum_{d\in D} a_z(d) \cdot d}{\sum_{d'\in D} a_z(d')} \quad \forall z\in \{1,\dots, M\} $$

  4. Set $t=t+1$ and continue with 2.

Note: One can do batches by running one EM-cycle per batch $D_i\subset D$ for $i=1,\dots,B$, where $B$ is the number of batches.

In [244]:
datapoints = 1000
batchsize = 1000
epochs = 30

# fold
pr.set_dims([('x',N),('z',M),('mu',K)])

def EM_step(mus,pxgmu,D0):
    mus = [int(mu) for mu in mus]
    pxgmuz = [pxgmu[:,mu] for mu in mus]
    # E-step
    pzgd = []
    for i in range(0,len(D0)):
        d = D0[i]
        Z = 0
        v = np.zeros(M)
        for z in range(0,M):
            v[z] = pxgmuz[z][d]*rho[z]
        pzgd.append(v/np.sum(v))
    # M-step
    pzgd = np.array(pzgd)
    mus = np.einsum('ij,i,j->j',pzgd,D0, 1.0/np.einsum('ij->j',pzgd))
    return mus

def get_pmu_from_mus(mus,rho):
    pmu = np.zeros(K)
    for j in range(0,K):
        for z in range(0,M):
            if j == int(mus[z]):
                pmu[j] = rho[z]
    return pmu

# init mu1,...
mus = np.random.choice(range(K),M)
F_ll, F_px = [],[]

# EM steps:
for i in range(0,epochs):
    for n in range(int(datapoints/batchsize)):
        D0 = D[n*batchsize:(n+1)*batchsize]
        mus = EM_step(mus,px_mu.val,D0)

        # validate
        pmu = pr.func('f(mu)',val=get_pmu_from_mus(mus,pz.val))
        F_ll.append(validate_ll(pmu,D,best_ll))
        F_px.append(validate_px(pmu,px_real))


with plt.rc_context({"figure.figsize": (18,4)}):
    s = 'EM-algorithm ($p(\mu)=\sum_z\, \\rho_z\, \delta_{\mu_z,\mu}$)'
    fig,ax = plt.subplots(1,3)
    fig.suptitle(s + ', datasize: {}, batchsize: {}, epochs: {}'.format(datapoints,batchsize,epochs))
    ax[0].bar(range(0,len(pmu.val)),pmu.val)
    ax[0].set_xlabel('$\mu$')
    ax[1].scatter(range(len(F_ll)),F_ll,marker='.')
    ax[1].set_ylim([0,1000])
    ax[2].set_ylim([0,.05])
    ax[2].scatter(range(len(F_px)),F_px,marker='.')
    ax[0].set_title('$p(\mu)$')
    ax[1].set_title('$LL_D[p_{real}]-LL_D[p]$')
    ax[2].set_title('$||p(X)-p_{real}(X)||$')
    fig.subplots_adjust(top=0.81)
    plt.show()

fitness_ll = validate_ll(pmu,D,best_ll)
fitness_px = validate_px(pmu,px_real)

print fitness_ll,fitness_px
13.5511405589 0.0207590704267

3. Bayesian inference with conjugate priors and sampling

We are inferring $p(\mu|z) = p(\mu_z)$ in $(\ast)$ by making use of conjugate priors and the ability to sample from parametrized and discrete distributions. Since the likelihood $p(x|\mu)=\mathcal N(x|\mu,\phi^{-1})$ is Gaussian (in our example, $\phi=\frac{1}{20}$), we pick a Gaussian prior over $\mu_j$ for $j=1,\dots,M$:

$$ p(\mu_j) = \mathcal N\big( m_j , \tfrac{1}{\phi \, \alpha_j}\big)(\mu_j) $$

and the goal is to determine $\{m_j\}_{j=1}^M$ as well as $\{\alpha_j\}_{j=1}^M$.

The algorithm works as follows:

  • Initialize $t=0$ and $m^{(0)}_j$, $\alpha^{(0)}_j$ for all $j=1,\dots, M$

  • For $D_{batch}\subset D$:

    1. Sample $\mu_j$ from $p^{(t)}(\mu_j) = \mathcal N(m_j^{(t)}, \tfrac{1}{\phi \, \alpha^{(t)}_j})(\mu_j)$ for all $j=1,\dots, M$

    2. Initialize $\{D_z\}_{z=0}^M$ as a collection of empty sets so that $D_z$ collects the $d$s that are proposed to correspond to $z$.

    3. For all $d\in D_{batch}$ sample $z_d\sim p(z|d,\mu_1,\dots, \mu_M) = \frac{\rho_z \, p(d|\mu_z)}{\sum_j \rho_j\, p(d|\mu_j)}$ and add $d$ to $D_{z_d}$.

    4. From $\{D_z\}_{z=1}^M$ determine occupation numbers $\{n_z\}_{z=0}^M$ and sums $\{S_z\}_{z=0}^M$ by

      $$ n_z = |D_z| , \quad S_z = \sum_{d\in D_z} d \qquad \forall z=1,\dots,M $$

    5. Update $p^{(t+1)}(\mu_j)\leftarrow p(\mu_j|D_j)$ for all $j=1,\dots,M$, where the posterior for $\mu_j$ is

      $$ p(\mu_j|D_j) = \mathcal N\big(m_j^{(t+1)}, \tfrac{1}{\phi\, \alpha_j^{(t+1)}}\big)(\mu_j) $$ where $$ m_j^{(t+1)} = \frac{\alpha_j^{(t)} m_j^{(t)} + S_j}{\alpha_j^{(t)} + n_j} \ , \qquad \alpha_j^{(t+1)} = \alpha_j^{(t)} + n_j $$

    6. Set $t=t+1$ and continue with $1$.

Question: How is this related to Gibbs-Sampling, which would be

  1. Initialize $\mu,z$
  2. Sample $x \sim p(x|\mu,z)$
  3. Sample $z \sim p(z|x,\mu)$
  4. Sample $\mu \sim p(\mu|x,z)$
  5. Go back to 2.

Possible Answer: Let $\mathbf x = (x_1,\dots,x_{|D_{batch}|}):= D_{batch}$ and do

  1. Initialize $\mu_1,\dots,\mu_M$
  2. Sample $\mathbf z=(z_1,...,z_{|D_{batch}|})$ via $z_d \sim p(z|d,\mu_1,\dots,\mu_M)$ for all $d\in D_{batch}$.
  3. Sample $\mu_1,\dots,\mu_M$ via $p(\mu_j|\mathbf z,\mathbf x) = p(\mu_j|D_j)$ (see above)
In [245]:
datapoints = 1000
batchsize = 10
epochs = 10


pr.set_dims([('x',N),('z',M),('mu',K)])


def f_px_mu(mu,x):
    return scipy.stats.norm.pdf(x,loc=mu,scale=np.sqrt(20))

def delta(i,j):
    if i==j: return 1
    else: return 0

def normal(t,m,sigmasqu):
    y = scipy.stats.norm.pdf(t,loc=m,scale=np.sqrt(sigmasqu))
    return y/np.sum(y)

phi = 1.0/sigma

# initialize
m = np.random.uniform(K/4.0,K*3.0/4.0,M)
alpha = np.ones(M)/(K*phi)
F_ll,F_px = [],[]

for l in range(epochs):
    for n in range(int(datapoints/batchsize)):
        D0 = D[n*batchsize:(n+1)*batchsize]
        Dparts = [[] for i in range(M)]
        for d in D0:
            mus = np.random.normal(m,1.0/(alpha*phi))
            #2 # calculate p(z|d,mu1,mu2,...)
            pzgd = rho*f_px_mu(mus,d)/np.sum(rho*f_px_mu(mus,d))
            z = np.random.choice(range(M),p=pzgd) # proposed z for the given d 
            #MAP opt: z = np.argmax(pzgd)
            Dparts[z].append(d)
        ## calc stats from Dparts
        occ = np.array([len(Dpart) for Dpart in Dparts])
        sum_delta = np.array([np.sum(Dpart) for Dpart in Dparts])
        ## update pars after batch:
        m = (alpha*m + sum_delta)/(alpha+occ)
        alpha += occ
        # for validation:
        pmu_z_mat = np.array([normal(np.arange(0,K),m[z],1.0/(phi*alpha[z])) for z in range(0,M)])
        pmu_z = pr.func('f(mu,z)',val=pmu_z_mat)
        F_ll.append(validate_ll(pr.sum(['z'],pmu_z*pz),D,best_ll))
        F_px.append(validate_px(pr.sum(['z'],pmu_z*pz),px_real))

# plot

with plt.rc_context({"figure.figsize": (18,5)}):
    fig,ax = plt.subplots(1,3)
    pre = 'Bayesian inference with conjugate priors ($p(\mu|z) = p(\mu_z)$), '
    fig.suptitle(pre + 'datasize: {}, batchsize: {}, epochs: {}'.format(datapoints,batchsize,epochs))
    for z in range(0,M):
        ax[0].bar(t,pmu_z.val[z],label='z={}'.format(z))
    ax[0].set_xlabel('$\mu$')
    ax[0].legend()
    ax[1].scatter(range(len(F_ll)),F_ll,marker='.')
    ax[1].set_ylim([0,1000])
    ax[2].set_ylim([0,.05])
    ax[2].scatter(range(len(F_px)),F_px,marker='.')
    ax[0].set_title('$p(\mu_z)$')
    ax[1].set_title('$LL_D[p_{real}]-LL_D[p]$')
    ax[2].set_title('$||p(X)-p_{real}(X)||$')
    fig.subplots_adjust(top=0.82)
 
    plt.show()

fitness_ll = validate_ll(pr.sum(['z'],pmu_z*pz),D,best_ll)
fitness_px = validate_px(pr.sum(['z'],pmu_z*pz),px_real)

print fitness_ll,fitness_px
32.5933555726 0.0272345006644

4. "2-step" Bayesian inference (no conjugate priors, but discrete)

Here, we infer $p(\mu|z)$ in $(\ast)$ without conjugate priors, directly from the two-step inference process $x \rightarrow z \rightarrow \mu$:

\begin{eqnarray} \tag{$i$} p(z|x) & = & \frac{p(z) \,p(x|z)}{\sum_{z'} p(z') \, p(x|z')} = \frac{p(z) \, \sum_\mu p(\mu|z) \, p(x|\mu)}{\sum_{z'} p(z') \, \sum_{\mu'} p(\mu'|z) \, p(x|\mu')} \\[10px] \tag{$ii$} p(\mu|z,x) & = & \frac{p(x,\mu|z)}{p(x|z)} = \frac{p(\mu|z) p(x|\mu)}{\sum_{\mu'} p(\mu'|z) p(x|\mu')} \end{eqnarray}

where we used that by assumption $p(x|\mu,z) = p(x|\mu)$.

Prior update

There are two obvious possiblities to update the prior $p(\mu|z)$:

(1) Posterior: $p(\mu|z) \leftarrow p(\mu|z,d)$ (assumes $d$ belongs to the particular $z$, enforces specialization)

(2) Marginal: $p(\mu|z) \leftarrow \sum_x p(\mu|z,x) p(x|z)$ (only works with batches or full $D$)

Assigning the component

Here, we compare two different ways to determine the component $z$ a given datapoint $d$ belongs to:

Sampling: $z\sim p(z|d)$

MAP: $z = \max_{z'}p(z'|d)$

(a) Most simplest version: posterior prior update, no batches

  1. Initialize $t=0$ and $p^{(0)}(\mu|z)$
  2. Selector step $(i)$: Given $d\in D$, calculate $p(z|d)=\frac{\rho_z \sum_\mu p^{(t)}(\mu|z) p(d|\mu)}{\sum_z \dots}$ and sample $z\sim p(z|d)$ or pick $z = \mathrm{argmax}_{z'}p(z'|d)$.
  3. Actor step $(ii)$: Determine $p(\mu|d,z) = \frac{p^{(t)}(\mu|z) \, p(d|\mu)}{\sum_\mu \dots}$ and set $p^{(t+1)}(\mu|z) = p(\mu|d,z)$
  4. Set $t=t+1$ and continue with 2.
In [234]:
def postupdate_nobatch_nobetas(D,px_mu,datapoints,epochs,opt='sample'):
    pz = pr.func('f(z)',val='unif').normalize()

    #init
    pmu_z = pr.func('f(z,mu)',val='unif').normalize(['mu'])
    F_ll, F_px = [],[]

    # inference
    for d in D[:datapoints]:
        pz_d = (pz*pr.sum(['mu'],pmu_z*px_mu.eval('x',d))).normalize()
        if opt == 'sample':
            opt_str = 'component-sampling ($z\sim p(z|d)$)'
            z = sample(pz_d.val)
        elif opt == 'MAP':
            opt_str = 'component-MAP ($z = argmax_{z} \,p(z|d)$)'
            z = np.argmax(pz_d.val)
        pmu_z.val[z,:] = px_mu.eval('x',d).val*pmu_z.val[z,:]
        pmu_z = pmu_z.normalize(['mu'])
        # evaluate
        F_ll.append(validate_ll(pr.sum(['z'],pmu_z*pz),D,best_ll))
        F_px.append(validate_px(pr.sum(['z'],pmu_z*pz),px_real))

    t = range(K)
    with plt.rc_context({"figure.figsize": (18,5)}):

        fig,ax = plt.subplots(1,3)
        fig.suptitle(opt_str +', datasize: {}, epochs: {}'.format(datapoints,epochs))
        for z in range(M):
            ax[0].bar(t,pmu_z.val[z,:],label='z={}'.format(z))
        ax[0].set_xlabel('$\mu$')
        ax[0].legend()
        ax[1].scatter(range(len(F_ll)),F_ll,marker='.')
        ax[1].set_ylim([0,1000])
        ax[2].set_ylim([0,.05])
        ax[2].scatter(range(len(F_px)),F_px,marker='.')
        ax[0].set_title('$p(\mu_z)$')
        ax[1].set_title('$LL_D[p_{real}]-LL_D[p]$')
        ax[2].set_title('$||p(X)-p_{real}(X)||$')
        fig.subplots_adjust(top=0.82)
        plt.show()

    print F_ll[-1],F_px[-1]
In [236]:
datapoints = 1000
epochs = 1

postupdate_nobatch_nobetas(D,px_mu,datapoints,epochs,opt='sample')
postupdate_nobatch_nobetas(D,px_mu,datapoints,epochs,opt='MAP')
-0.0639424400006 0.000329131483885
2.50033294131 0.0152364648336

(b) With batches

Let $D_{batch}\subset D$ and let $D_z\subset D_{batch}$ denote a proposal of datapoints that belong to component $z$, obtained in a given iteration by sampling $z_d\sim p(z|d)$ and adding $d$ to $D_{z_d}$ for each $d\in D_{batch}$. With this notation, the two possible prior updates, $(1)$ and $(2)$, are as follows:

(1) Update prior by posterior: This generalizes the non-batch version $(a)$ directly by replacing the update rule $p(\mu|z)\leftarrow p(\mu|z,d)$ by

$$ \tag{$1$} p(\mu|z)\leftarrow p(\mu|z,D_z) = \frac{p(\mu|z) \, p(D_z|\mu)}{\sum_{\mu'} p(\mu'|z) \, p(D_z|\mu')} = \frac{p(\mu|z) \, \prod_{d\in D_z} p(d|\mu)}{\sum_{\mu'} p(\mu'|z) \, \prod_{d\in D_z} p(d|\mu')} \qquad \forall z=1,\dots, M $$

(2) Update prior by marginal: Here, we approximate the expectation over $p(x|z)$ in the marginal $p(\mu|z) = \sum_x p(\mu|z,x) p(x|z)$ by the expectation w.r.t. the frequency

$$ f_{D_z}(x) = \frac{1}{|D_z|} \, \big|\big\{d\in D_z|d=x\big\}\big| $$

so that the prior update becomes

$$ \tag{$2$} p(\mu|z) \leftarrow \sum_x p(\mu|z,x) p(x|z) \approx \sum_x p(\mu|z,x) f_{D_z}(x) = \frac{1}{|D_z|} \sum_{d\in D_z} p(\mu|z,d) \qquad \forall z=1,\dots, M $$

Here is the algorithm:

  • Initialize $t=0$ and $p^{(0)}(\mu|z)$

  • For $D_{batch}\subset D$:

    1. Initialize $\{D_z\}_{z=0}^M$ as a collection of empty sets so that $D_z$ collects the $d$s that are proposed to correspond to $z$.

    2. (Selector step) For all $d\in D_{batch}$ sample $z_d\sim p(z|d)=\frac{\rho_z \sum_\mu p^{(t)}(\mu|z) p(d|\mu)}{\sum_z \dots}$ or pick $z = \mathrm{argmax}_{z'}p(z'|d)$ and add $d$ to $D_{z_d}$.

    3. Depending on the chosen prior update, we do either (1) or (2) from above:

      (1) (Actor step with posterior prior update) For all $z=0,...,M$, determine

      $$ p(\mu|D_z,z) = \frac{p^{(t)}(\mu|z) \, p(D_z|\mu)}{\sum_{\mu'} p^{(t)}(\mu'|z) \, p(D_z|\mu')} = \frac{p^{(t)}(\mu|z) \, \prod_{d\in D_z} p(d|\mu)}{\sum_{\mu'} p^{(t)}(\mu'|z) \, \prod_{d\in D_z} p(d|\mu')} $$

      and set $p^{(t+1)}(\mu|z) = p(\mu|D_z,z)$.

      (2) (Actor step with marginal prior update) For all $z=0,...,M$, determine $$p(\mu|z,d) = \frac{p^{(t)}(\mu|z) p(d|\mu)}{\sum_{\mu'} p^{(t)}(\mu'|z) p(d|\mu')} \qquad \forall d\in D_{z}$$ and set $p^{(t+1)}(\mu|z) = \frac{1}{|D_z|} \sum_{d\in D_z} p(\mu|d,z)$.

    4. Set $t=t+1$ and continue with 1 and the next batch.

In [201]:
def batch_nobetas(D,px_mu,datapoints,batchsize,epochs,pu='post',opt='sample'):

    pr.set_dims([('x',N),('z',M),('mu',K)])

    pz = pr.func('f(z)',val='unif').normalize()


    #init
    pmu_z = pr.func('f(z,mu)',val='unif').normalize(['mu'])
    F_ll,F_px = [],[]

    # inference
    for r in range(epochs):
        for n in range(int(datapoints/batchsize)):
            D0 = D[n*batchsize:(n+1)*batchsize]
            Dparts = [[] for j in range(M)]
            for d in D0:
                pz_d = (pz*pr.sum(['mu'],pmu_z*px_mu.eval('x',d))).normalize()   
                if opt == 'sample':
                    opt_str = 'component-sampling ($z\sim p(z|d)$)'
                    z = sample(pz_d.val)
                elif opt == 'MAP':
                    opt_str = 'component-MAP ($z = argmax_{z} \,p(z|d)$)'
                    z = np.argmax(pz_d.val)
                Dparts[z].append(d)
            for z in range(M):
                Dpart = Dparts[z]
                if len(Dpart) > 0:
                    if pu == 'post':
                        amu = np.einsum('ij->j',np.array([np.log(px_mu.eval('x',d).val) for d in Dpart]))
                        b = np.max(amu)
                        pmu_z.val[z,:] = pmu_z.val[z,:]*np.exp(amu-b)/np.sum(pmu_z.val[z,:]*np.exp(amu-b))
                    elif pu == 'marg':
                        pmu_zd_temp = [px_mu.eval('x',d).val*pmu_z.val[z,:] for d in Dpart] # p(d|mu)*p(mu|z)
                        pmu_zd = np.array([pmu_zdi/np.sum(pmu_zdi) for pmu_zdi in pmu_zd_temp]) #p(d|mu)*p(mu|z)/sum(...)
                        pmu_z.val[z,:] = np.einsum('ik->k',pmu_zd)/len(Dpart) #  1/|Dpart| sum_{d in Dpart} p(mu|z,d)
            pmu_z = pmu_z.normalize(['mu'])
            F_ll.append(validate_ll(pr.sum(['z'],pmu_z*pz),D,best_ll))
            F_px.append(validate_px(pr.sum(['z'],pmu_z*pz),px_real))


    t = range(K)
    with plt.rc_context({"figure.figsize": (18,5)}):

        fig,ax = plt.subplots(1,3)
        if pu =='post':
            pre = 'Posterior prior update, '
        elif pu =='marg':
            pre = 'Marginal prior update, '
        fig.suptitle(pre+opt_str+', datasize: {}, batchsize: {}, epochs: {}'.format(datapoints,batchsize,epochs))
        for z in range(M):
            ax[0].bar(t,pmu_z.val[z,:],label='z={}'.format(z))
        ax[0].set_xlabel('$\mu$')
        ax[0].legend()
        ax[1].scatter(range(len(F_ll)),F_ll,marker='.')
        ax[1].set_ylim([0,1000])
        ax[2].set_ylim([0,.05])
        ax[2].scatter(range(len(F_px)),F_px,marker='.')
        ax[0].set_title('$p(\mu_z)$')
        ax[1].set_title('$LL_D[p_{real}]-LL_D[p]$')
        ax[2].set_title('$||p(X)-p_{real}(X)||$')
        fig.subplots_adjust(top=0.82)
        plt.show()

    print F_ll[-1],F_px[-1]
In [231]:
datapoints = 1000
batchsize = 5
epochs = 1

batch_nobetas(D,px_mu,datapoints,batchsize,epochs,pu='post',opt='sample')
batch_nobetas(D,px_mu,datapoints,batchsize,epochs,pu='post',opt='MAP')

batch_nobetas(D,px_mu,datapoints,batchsize,epochs,pu='marg',opt='sample')
batch_nobetas(D,px_mu,datapoints,batchsize,epochs,pu='marg',opt='MAP')
32.7874737953 0.0292475910055
29.9672045679 0.0295199792057
-0.517241016706 0.00194791610591
2.88660627829 0.0153687913501