In [1]:
#from IPython.core.display import HTML
#import urllib
#response = urllib.urlopen("""
#    https://gist.githubusercontent.com/sgttwld/c060b18a9d6ce7c3a10e3c6dce2c0d3a/raw
#""")
#HTML("<style type='text/css'>"+css+"</style>")

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats
sns.set()
%config InlineBackend.figure_format = 'retina'
import _lib.pr_func as pr
import _lib.BA_class as BA

def get_ll(px,D):
    """Log-likelihood of the dataset D under the distribution px.

    px: pr_func-like object whose .val array is indexable by a datapoint d
    D:  iterable of datapoints (integer indices into px.val)
    Returns sum_{d in D} log px.val[d] (0 for an empty D, as before).
    """
    return sum(np.log(px.val[d]) for d in D)

def validate_ll(pmu,D,best_ll):
    """Gap between the best achievable log-likelihood and the model's.

    Builds the model marginal p(x) = sum_mu p(mu) p(x|mu) from the current
    pmu (uses the file-level likelihood px_mu) and returns best_ll - LL_D[p],
    so smaller is better and 0 means the model matches the true one on D.
    """
    model_px = pr.sum(['mu'], pmu * px_mu)
    return best_ll - get_ll(model_px, D)

def validate_px(pmu,px_real):
    """Euclidean distance between the model's p(x) and the true p(x).

    The model marginal is built from pmu and the file-level likelihood px_mu.
    """
    model_px = pr.sum(['mu'], pmu * px_mu)
    diff = model_px.val - px_real.val
    return np.linalg.norm(diff)


# Bayesian inference for mixture distributions¶

Sebastian Gottwald, 2018

## Problem statement¶

Let $X$ be an observable that is distributed according to a mixture of components with known distributions $p(x|\mu)$. The goal is to find the parameters $\mu$ from which the data was generated.

The mixture can be written as $$p(x) = \sum_\mu \, p(\mu) \, p(x|\mu)$$ and the goal would be to find $p(\mu)$. However, if the number of mixture components $M$ is known we can split up $p(\mu)$ into $\{p(\mu|z)\}_{z=1}^M$:

$$\tag{\ast} p(x) = \sum_{z=1}^M p(z) \sum_\mu p(\mu|z) \, p(x|\mu)$$

For simplicity, we assume equal mixture coefficients, i.e. $\rho_z:= p(z)=1/M$ for all $z$.

## Exemplary data¶

Let $M=4$, $p(x|\mu) = \mathcal{N}(x\,|\,\mu,\sigma^2)$ with $\sigma^2=25$, and $p_{real}(\mu) = \frac{1}{4}(\delta_{\mu,10} + \delta_{\mu,20} + \delta_{\mu,40} + \delta_{\mu,50})$

In [229]:
# grid sizes: N values of x, M mixture components, K candidate means mu
N,M,K = 60,4,60

sigma = 25.0       # width parameter of the discretized likelihood p(x|mu)
datasize = 1000    # number of datapoints to generate

# register the named tensor dimensions used by the pr_func library
pr.set_dims([('x',N),('z',M),('mu',K)])
rho = np.ones(M)/M              # equal mixture weights rho_z = 1/M (see markdown)
pz = pr.func('f(z)',val=rho)    # p(z) wrapped as a pr_func

def f(K,mean,sigma):
k = np.arange(0,K)
return np.exp(-(k-mean)**2/sigma)

def g(N,K,sigma,offset):
    """Likelihood matrix llh with llh[i,k] = exp(-((i-offset)-k)**2 / sigma).

    Row i is the unnormalized Gaussian bump centered at i-offset over the
    mu-grid 0..K-1.  Vectorized with broadcasting instead of the previous
    per-row Python loop (same values, one numpy expression).
    """
    rows = np.arange(N) - offset      # shifted x-grid, shape (N,)
    cols = np.arange(K)               # mu-grid, shape (K,)
    return np.exp(-(rows[:, None] - cols[None, :]) ** 2 / sigma)

def sample(px_mat):
    """Draw one index i in 0..len(px_mat)-1 with probability px_mat[i]."""
    return np.random.choice(len(px_mat), p=px_mat)

# likelihood p(x|mu): discretized Gaussian bump per mu-column, normalized over x
px_mu = pr.func('f(x,mu)', val=g(N,K,sigma=sigma,offset=0.0)).normalize(['x'])

# real p(mu): a sum of M very narrow bumps (sigma=.1, near-deltas) at the
# chosen means, normalized -- matches the p_real(mu) in the markdown above
if M == 2:
    pmu_real = pr.func('f(mu)', val=f(K,mean=10,sigma=.1)+f(K,mean=30,sigma=.1)).normalize()
elif M == 3:
    pmu_real = pr.func('f(mu)', val=f(K,mean=10,sigma=.1)+f(K,mean=30,sigma=.1)+f(K,mean=40,sigma=.1)).normalize()
elif M==4:
    means = [10,20,40,50]
    mat = 0
    for mean in means:
        mat += f(K,mean=mean,sigma=.1)
    pmu_real = pr.func('f(mu)', val=mat).normalize()
elif M==6:
    means = [15,25,35,65,75,85]
    mat = 0
    for mean in means:
        mat += f(K,mean=mean,sigma=.1)
    pmu_real = pr.func('f(mu)', val=mat).normalize()

# true data marginal p(x) = sum_mu p(x|mu) p_real(mu)
px_real = pr.sum(['mu'],px_mu*pmu_real)

# create data: per datapoint sample mu ~ p_real(mu), then x ~ p(x|mu)
D = [sample(px_mu.val[:,sample(pmu_real.val)]) for d in range(0,datasize)]

# best achievable log-likelihood: the data scored under the true model
best_ll = get_ll(px_real,D)

# empirical frequency of each value x in the data D
# (np.bincount replaces the previous O(N*|D|) double loop; identical counts
#  since every d lies on the grid 0..N-1 by construction)
freq = np.bincount(np.asarray(D, dtype=int), minlength=N).astype(float)

# normalized empirical distribution of the data
PY = freq/np.sum(freq)

# visualize the setup: likelihood matrix, true p(mu), and the empirical
# frequencies of the sampled data
with plt.rc_context({"figure.figsize": (18,5)}):
    fig,ax = plt.subplots(1,3)
    ax[0].set_title('$p(x|\mu)$')
    ax[0].pcolor(px_mu.val)
    ax[0].set_xlabel('$\mu$')
    ax[0].set_ylabel('$x$')
    ax[1].set_title('$p_{real}(\mu)$')
    ax[1].bar(range(K),pmu_real.val)
    ax[1].set_xlabel('$\mu$')
    ax[2].set_title('frequencies of $d=x$ in $D$, $|D|={}$'.format(datasize))
    ax[2].bar(range(0,len(PY)),PY)
    ax[2].set_xlabel('$x$')
    plt.show()


## 1. Approximating $p(\mu)=\sum_x p(\mu|x) p(x)$¶

Here, we generate $p(\mu)$ from $\{p(\mu|d)\}_{d\in D}$ iteratively, as follows:

1. initialize $t=0$ and $p^{(0)}(\mu) = 1/K$ for all $\mu$ (uniform)
2. For each $d\in D$ calculate $p^{(t)}(\mu|d) = \frac{p(d|\mu)\, p^{(t)}(\mu)}{\sum_{\mu} p(d|\mu) \, p^{(t)}(\mu)}$
3. Effectively approximate $p(\mu) = \sum_{x} p(\mu|x) p(x)$ by $\sum_{x} p(\mu|x)\, f_D(x)$, where $f_D(x)$ is the frequency of $x$ in $D$: $$p^{(t+1)}(\mu) = \frac{1}{|D|} \sum_{d\in D} p^{(t)}(\mu|d)$$
4. Set $t=t+1$ and continue with 2.
In [230]:
datasize = 1000
batchsize = 100
epochs = 10

# init
pmu = pr.func('f(mu)', val='unif').normalize()
F_ll,F_px = [],[]

# inference to find marginal
for i in range(0,epochs):
for n in range(int(datapoints/batchsize)):
D0 = D[n*batchsize:(n+1)*batchsize]
pmu_ds = [(px_mu.eval('x',d)*pmu).normalize().val for d in D0]
pmu_marg = 0
for pmu_d in pmu_ds:
pmu_marg += pmu_d
pmu.val = pmu_marg/len(pmu_ds)
F_ll.append(validate_ll(pmu,D,best_ll))
F_px.append(validate_px(pmu,px_real))

with plt.rc_context({"figure.figsize": (18,4)}):
fig,ax = plt.subplots(1,3)
fig.suptitle('datasize: {}, batchsize: {}, epochs: {}'.format(datapoints,batchsize,epochs))
ax[0].bar(range(0,len(pmu.val)),pmu.val)
ax[0].set_xlabel('$\mu$')
ax[0].set_title('$p(\mu)$')
ax[1].scatter(range(len(F_ll)),F_ll,marker='.')
#ax[1].set_ylim([0,1000])
ax[1].set_title('$LL_D[p_{real}]-LL_D[p]$')
ax[2].scatter(range(len(F_px)),F_px,marker='.')
#ax[2].set_ylim([0,.05])
ax[2].set_title('$||p(X)-p_{real}(X)||$')
plt.show()

fitness_ll = validate_ll(pmu,D,best_ll)
fitness_px = validate_px(pmu,px_real)

print fitness_ll,fitness_px

9.98677666392 0.0160731370158


## 2. EM algorithm¶

Here, we assume $p(\mu|z) = \delta_{\mu_z,\mu}$, where $\{\mu_z\}_{z=1}^M$ are the unknown parameters we want to find, so that $(\ast)$ becomes $$p(x) = \sum_{z=1}^M\rho_z \sum_\mu \delta_{\mu_z,\mu}\, p(x|\mu) = \sum_{z=1}^M \rho_z \, p(x|\mu_z)$$

The EM algorithm for Gaussian mixtures is as follows:

1. Initialize $t=0$ and $\{\mu_z^{(0)}\}_{z=1}^M$ randomly

2. (E-step) For each $d\in D$ calculate $$a_z(d) := p(z|d,\mu_1^{(t)},\dots,\mu_M^{(t)}) = \frac{\rho_z\, p(d|\mu^{(t)}_z)}{\sum_{j=1}^M \rho_j \, p(d|\mu^{(t)}_j)} \quad \forall z\in \{1,\dots, M\}$$

3. (M-step) Determine $\{\mu^{(t+1)}_z\}_{z=1}^M$ from $\{a_z\}_{z=1}^M$ as follows: $$\mu_z^{(t+1)} = \frac{\sum_{d\in D} a_z(d) \cdot d}{\sum_{d'\in D} a_z(d')} \quad \forall z\in \{1,\dots, M\}$$

4. Set $t=t+1$ and continue with 2.

Note: One can do batches by running one EM-cycle per batch $D_i\subset D$ for $i=1,\dots,B$, where $B$ is the number of batches.

In [244]:
# hyperparameters for the EM run (batchsize == datapoints -> full-batch EM)
datapoints = 1000
batchsize = 1000
epochs = 30

# re-declare the named tensor dimensions for this cell
pr.set_dims([('x',N),('z',M),('mu',K)])

def EM_step(mus,pxgmu,D0):
    """One EM iteration for the delta-mixture model p(x) = sum_z rho_z p(x|mu_z).

    mus:    current component means (length M, rounded to grid indices)
    pxgmu:  likelihood matrix, pxgmu[x, mu] = p(x|mu)
    D0:     batch of datapoints (grid indices)
    Returns the updated (continuous-valued) means mu_z.
    Reads the file-level globals M (number of components) and rho (weights).
    """
    mus = [int(mu) for mu in mus]
    # likelihood column p(.|mu_z) for each current component mean
    pxgmuz = [pxgmu[:,mu] for mu in mus]
    # E-step: responsibilities a_z(d) = rho_z p(d|mu_z) / sum_j rho_j p(d|mu_j)
    # (the dead local `Z = 0` and the index-based loop were removed)
    pzgd = []
    for d in D0:
        v = np.array([pxgmuz[z][d]*rho[z] for z in range(M)])
        pzgd.append(v/np.sum(v))
    # M-step: mu_z = sum_d a_z(d)*d / sum_d a_z(d)
    pzgd = np.array(pzgd)
    mus = np.einsum('ij,i,j->j',pzgd,D0, 1.0/np.einsum('ij->j',pzgd))
    return mus

def get_pmu_from_mus(mus,rho):
    """Embed point estimates {mu_z} as a discrete p(mu) on the grid 0..K-1.

    Sets pmu[int(mus[z])] = rho[z]; everything else is 0.  Direct O(M)
    assignment replaces the previous O(K*M) delta-matching double loop
    (identical result: for colliding mus the largest z wins in both).
    Reads the file-level globals K (grid size) and M (number of components).
    """
    pmu = np.zeros(K)
    for z in range(M):
        # NOTE(review): colliding mus overwrite rather than add, so pmu may
        # not sum to 1 in that case -- confirm this is intended
        pmu[int(mus[z])] = rho[z]
    return pmu

# init mu1,...,muM: random initial component means on the mu-grid
mus = np.random.choice(range(K),M)
F_ll, F_px = [],[]

# EM steps: one EM cycle per batch (see note in the markdown above)
for i in range(0,epochs):
    for n in range(int(datapoints/batchsize)):
        D0 = D[n*batchsize:(n+1)*batchsize]
        mus = EM_step(mus,px_mu.val,D0)

        # validate: embed the point estimates mus as a discrete p(mu)
        pmu = pr.func('f(mu)',val=get_pmu_from_mus(mus,pz.val))
        F_ll.append(validate_ll(pmu,D,best_ll))
        F_px.append(validate_px(pmu,px_real))

with plt.rc_context({"figure.figsize": (18,4)}):
    s = 'EM-algorithm ($p(\mu)=\sum_z\, \\rho_z\, \delta_{\mu_z,\mu}$)'
    fig,ax = plt.subplots(1,3)
    fig.suptitle(s + ', datasize: {}, batchsize: {}, epochs: {}'.format(datapoints,batchsize,epochs))
    ax[0].bar(range(0,len(pmu.val)),pmu.val)
    ax[0].set_xlabel('$\mu$')
    ax[1].scatter(range(len(F_ll)),F_ll,marker='.')
    ax[1].set_ylim([0,1000])
    ax[2].set_ylim([0,.05])
    ax[2].scatter(range(len(F_px)),F_px,marker='.')
    ax[0].set_title('$p(\mu)$')
    ax[1].set_title('$LL_D[p_{real}]-LL_D[p]$')
    ax[2].set_title('$||p(X)-p_{real}(X)||$')
    plt.show()

fitness_ll = validate_ll(pmu,D,best_ll)
fitness_px = validate_px(pmu,px_real)

print fitness_ll,fitness_px

13.5511405589 0.0207590704267


## 3. Bayesian inference with conjugate priors and sampling¶

We are inferring $p(\mu|z) = p(\mu_z)$ in $(\ast)$ by making use of conjugate priors and the ability to sample from parametrized and discrete distributions. Since the likelihood $p(x|\mu)=\mathcal N(x|\mu,\phi^{-1})$ is Gaussian (in our example, $\phi=\frac{1}{20}$), we pick a Gaussian prior over $\mu_j$ for $j=1,\dots,M$:

$$p(\mu_j) = \mathcal N\big( m_j , \tfrac{1}{\phi \, \alpha_j}\big)(\mu_j)$$

and the goal is to determine $\{m_j\}_{j=1}^M$ as well as $\{\alpha_j\}_{j=1}^M$.

The algorithm works as follows:

• Initialize $t=0$ and $m^{(0)}_j$, $\alpha^{(0)}_j$ for all $j=1,\dots, M$

• For $D_{batch}\subset D$:

1. Sample $\mu_j$ from $p^{(t)}(\mu_j) = \mathcal N(m_j^{(t)}, \tfrac{1}{\phi \, \alpha^{(t)}_j})(\mu_j)$ for all $j=1,\dots, M$

2. Initialize $\{D_z\}_{z=1}^M$ as a collection of empty sets so that $D_z$ collects the $d$s that are proposed to correspond to $z$.

3. For all $d\in D_{batch}$ sample $z_d\sim p(z|d,\mu_1,\dots, \mu_M) = \frac{\rho_z \, p(d|\mu_z)}{\sum_j \rho_j\, p(d|\mu_j)}$ and add $d$ to $D_{z_d}$.

4. From $\{D_z\}_{z=1}^M$ determine occupation numbers $\{n_z\}_{z=1}^M$ and sums $\{S_z\}_{z=1}^M$ by

$$n_z = |D_z| , \quad S_z = \sum_{d\in D_z} d \qquad \forall z=1,\dots,M$$

5. Update $p^{(t+1)}(\mu_j)\leftarrow p(\mu_j|D_j)$ for all $j=1,\dots,M$, where the posterior for $\mu_j$ is

$$p(\mu_j|D_j) = \mathcal N\big(m_j^{(t+1)}, \tfrac{1}{\phi\, \alpha_j^{(t+1)}}\big)(\mu_j)$$ where $$m_j^{(t+1)} = \frac{\alpha_j^{(t)} m_j^{(t)} + S_j}{\alpha_j^{(t)} + n_j} \ , \qquad \alpha_j^{(t+1)} = \alpha_j^{(t)} + n_j$$

6. Set $t=t+1$ and continue with $1$.

Question: How is this related to Gibbs-Sampling, which would be

1. Initialize $\mu,z$
2. Sample $x \sim p(x|\mu,z)$
3. Sample $z \sim p(z|x,\mu)$
4. Sample $\mu \sim p(\mu|x,z)$
5. Go back to 2.

Possible Answer: Let $\mathbf x = (x_1,\dots,x_{|D_{batch}|}):= D_{batch}$ and do

1. Initialize $\mu_1,\dots,\mu_M$
2. Sample $\mathbf z=(z_1,...,z_{|D_{batch}|})$ via $z_d \sim p(z|d,\mu_1,\dots,\mu_M)$ for all $d\in D_{batch}$.
3. Sample $\mu_1,\dots,\mu_M$ via $p(\mu_j|\mathbf z,\mathbf x) = p(\mu_j|D_j)$ (see above)
In [245]:
# hyperparameters for the conjugate-prior sampling run (small batches)
datapoints = 1000
batchsize = 10
epochs = 10

# re-declare the named tensor dimensions for this cell
pr.set_dims([('x',N),('z',M),('mu',K)])

def f_px_mu(mu,x):
return scipy.stats.norm.pdf(x,loc=mu,scale=np.sqrt(20))

def delta(i,j):
    """Kronecker delta: 1 if i == j, else 0."""
    return int(i == j)

def normal(t,m,sigmasqu):
    """Gaussian pdf with mean m and variance sigmasqu, evaluated on the grid t
    and normalized so the values sum to 1 (a discretized normal)."""
    pdf = scipy.stats.norm.pdf(t, loc=m, scale=np.sqrt(sigmasqu))
    return pdf / np.sum(pdf)

phi = 1.0/sigma    # precision of the likelihood (see markdown: p(x|mu)=N(mu,1/phi))

# initialize prior parameters: means m_j uniform over the middle of the grid,
# small pseudo-counts alpha_j so the initial priors are broad
m = np.random.uniform(K/4.0,K*3.0/4.0,M)
alpha = np.ones(M)/(K*phi)
F_ll,F_px = [],[]

for l in range(epochs):
    for n in range(int(datapoints/batchsize)):
        D0 = D[n*batchsize:(n+1)*batchsize]
        Dparts = [[] for i in range(M)]
        for d in D0:
            #1 sample mu_j ~ N(m_j, 1/(alpha_j*phi)) for all j
            # BUG FIX: np.random.normal takes the standard deviation as its
            # scale argument; the old code passed the VARIANCE 1/(alpha*phi).
            # NOTE(review): the algorithm text samples the mus once per batch,
            # not once per datapoint as done here -- confirm which is intended.
            mus = np.random.normal(m,np.sqrt(1.0/(alpha*phi)))
            #2 # calculate p(z|d,mu1,mu2,...)
            pzgd = rho*f_px_mu(mus,d)/np.sum(rho*f_px_mu(mus,d))
            z = np.random.choice(range(M),p=pzgd) # proposed z for the given d
            #MAP opt: z = np.argmax(pzgd)
            Dparts[z].append(d)
        ## calc stats from Dparts: occupation numbers n_z and sums S_z
        occ = np.array([len(Dpart) for Dpart in Dparts])
        sum_delta = np.array([np.sum(Dpart) for Dpart in Dparts])
        ## update pars after batch (conjugate posterior update, step 5):
        m = (alpha*m + sum_delta)/(alpha+occ)
        alpha += occ
        # for validation: discretize the Gaussian priors onto the mu-grid
        pmu_z_mat = np.array([normal(np.arange(0,K),m[z],1.0/(phi*alpha[z])) for z in range(0,M)])
        pmu_z = pr.func('f(mu,z)',val=pmu_z_mat)
        F_ll.append(validate_ll(pr.sum(['z'],pmu_z*pz),D,best_ll))
        F_px.append(validate_px(pr.sum(['z'],pmu_z*pz),px_real))

# plot

with plt.rc_context({"figure.figsize": (18,5)}):
fig,ax = plt.subplots(1,3)
pre = 'Bayesian inference with conjugate priors ($p(\mu|z) = p(\mu_z)$), '
fig.suptitle(pre + 'datasize: {}, batchsize: {}, epochs: {}'.format(datapoints,batchsize,epochs))
for z in range(0,M):
ax[0].bar(t,pmu_z.val[z],label='z={}'.format(z))
ax[0].set_xlabel('$\mu$')
ax[0].legend()
ax[1].scatter(range(len(F_ll)),F_ll,marker='.')
ax[1].set_ylim([0,1000])
ax[2].set_ylim([0,.05])
ax[2].scatter(range(len(F_px)),F_px,marker='.')
ax[0].set_title('$p(\mu_z)$')
ax[1].set_title('$LL_D[p_{real}]-LL_D[p]$')
ax[2].set_title('$||p(X)-p_{real}(X)||$')

plt.show()

fitness_ll = validate_ll(pr.sum(['z'],pmu_z*pz),D,best_ll)
fitness_px = validate_px(pr.sum(['z'],pmu_z*pz),px_real)

print fitness_ll,fitness_px

32.5933555726 0.0272345006644


## 4. "2-step" Bayesian inference (no conjugate priors, but discrete)¶

Here, we infer $p(\mu|z)$ in $(\ast)$ without conjugate priors, directly from the two-step inference process $x \rightarrow z \rightarrow \mu$:

\begin{eqnarray} \tag{$i$} p(z|x) & = & \frac{p(z) \,p(x|z)}{\sum_{z'} p(z') \, p(x|z')} = \frac{p(z) \, \sum_\mu p(\mu|z) \, p(x|\mu)}{\sum_{z'} p(z') \, \sum_{\mu'} p(\mu'|z') \, p(x|\mu')} \\[10px] \tag{$ii$} p(\mu|z,x) & = & \frac{p(x,\mu|z)}{p(x|z)} = \frac{p(\mu|z) p(x|\mu)}{\sum_{\mu'} p(\mu'|z) p(x|\mu')} \end{eqnarray}

where we used that by assumption $p(x|\mu,z) = p(x|\mu)$.

### Prior update¶

There are two obvious possibilities to update the prior $p(\mu|z)$:

(1) Posterior: $p(\mu|z) \leftarrow p(\mu|z,d)$ (assumes $d$ belongs to the particular $z$, enforces specialization)

(2) Marginal: $p(\mu|z) \leftarrow \sum_x p(\mu|z,x) p(x|z)$ (only works with batches or full $D$)

### Assigning the component¶

Here, we compare two different ways to determine the component $z$ a given datapoint $d$ belongs to:

Sampling: $z\sim p(z|d)$

MAP: $z = \mathrm{argmax}_{z'}\,p(z'|d)$

### (a) Simplest version: posterior prior update, no batches¶

1. Initialize $t=0$ and $p^{(0)}(\mu|z)$
2. Selector step $(i)$: Given $d\in D$, calculate $p(z|d)=\frac{\rho_z \sum_\mu p^{(t)}(\mu|z) p(d|\mu)}{\sum_z \dots}$ and sample $z\sim p(z|d)$ or pick $z = \mathrm{argmax}_{z'}p(z'|d)$.
3. Actor step $(ii)$: Determine $p(\mu|d,z) = \frac{p^{(t)}(\mu|z) \, p(d|\mu)}{\sum_\mu \dots}$ and set $p^{(t+1)}(\mu|z) = p(\mu|d,z)$
4. Set $t=t+1$ and continue with 2.
In [234]:
def postupdate_nobatch_nobetas(D,px_mu,datapoints,epochs,opt='sample'):
    """Two-step Bayesian inference (section 4a): posterior prior update, no batches.

    For each datapoint d: compute the selector p(z|d), pick a component z
    (sampled or MAP, depending on opt), then replace the prior p(mu|z) by the
    posterior p(mu|z,d).  Plots the learned p(mu|z) and the fitness traces.

    D:          dataset (list of grid indices)
    px_mu:      likelihood pr_func p(x|mu)
    datapoints: number of datapoints of D to use
    epochs:     only shown in the figure title
                # NOTE(review): there is no epoch loop here -- confirm intended
    opt:        'sample' (z ~ p(z|d)) or 'MAP' (z = argmax p(z|d))

    Uses the file-level globals pr, M, K, sample, validate_ll, validate_px,
    best_ll, px_real.  Prints the final fitness values.
    """
    # uniform mixture weights p(z)
    pz = pr.func('f(z)',val='unif').normalize()

    # init: uniform prior p(mu|z) for every component
    pmu_z = pr.func('f(z,mu)',val='unif').normalize(['mu'])
    F_ll, F_px = [],[]

    # inference: one selector/actor step per datapoint
    for d in D[:datapoints]:
        # selector step (i): p(z|d) ~ p(z) sum_mu p(mu|z) p(d|mu)
        pz_d = (pz*pr.sum(['mu'],pmu_z*px_mu.eval('x',d))).normalize()
        if opt == 'sample':
            opt_str = 'component-sampling ($z\sim p(z|d)$)'
            z = sample(pz_d.val)
        elif opt == 'MAP':
            opt_str = 'component-MAP ($z = argmax_{z} \,p(z|d)$)'
            z = np.argmax(pz_d.val)
        # actor step (ii): p(mu|z) <- p(mu|z,d), i.e. multiply by the
        # likelihood of d and renormalize over mu
        pmu_z.val[z,:] = px_mu.eval('x',d).val*pmu_z.val[z,:]
        pmu_z = pmu_z.normalize(['mu'])
        # evaluate against the true model after every update
        F_ll.append(validate_ll(pr.sum(['z'],pmu_z*pz),D,best_ll))
        F_px.append(validate_px(pr.sum(['z'],pmu_z*pz),px_real))

    # plot the learned component priors and the fitness traces
    t = range(K)
    with plt.rc_context({"figure.figsize": (18,5)}):

        fig,ax = plt.subplots(1,3)
        fig.suptitle(opt_str +', datasize: {}, epochs: {}'.format(datapoints,epochs))
        for z in range(M):
            ax[0].bar(t,pmu_z.val[z,:],label='z={}'.format(z))
        ax[0].set_xlabel('$\mu$')
        ax[0].legend()
        ax[1].scatter(range(len(F_ll)),F_ll,marker='.')
        ax[1].set_ylim([0,1000])
        ax[2].set_ylim([0,.05])
        ax[2].scatter(range(len(F_px)),F_px,marker='.')
        ax[0].set_title('$p(\mu_z)$')
        ax[1].set_title('$LL_D[p_{real}]-LL_D[p]$')
        ax[2].set_title('$||p(X)-p_{real}(X)||$')
        plt.show()

    print F_ll[-1],F_px[-1]

In [236]:
# run the no-batch posterior-update inference with both component-assignment
# strategies (sampling vs. MAP)
datapoints = 1000
epochs = 1

postupdate_nobatch_nobetas(D,px_mu,datapoints,epochs,opt='sample')
postupdate_nobatch_nobetas(D,px_mu,datapoints,epochs,opt='MAP')

-0.0639424400006 0.000329131483885

2.50033294131 0.0152364648336


### (b) With batches¶

Let $D_{batch}\subset D$ and let $D_z\subset D_{batch}$ denote a proposal of datapoints that belong to component $z$, obtained in a given iteration by sampling $z_d\sim p(z|d)$ and adding $d$ to $D_{z_d}$ for each $d\in D_{batch}$. With this notation, the two possible prior updates, $(1)$ and $(2)$, are as follows:

(1) Update prior by posterior: This generalizes the non-batch version $(a)$ directly by replacing the update rule $p(\mu|z)\leftarrow p(\mu|z,d)$ by

$$\tag{1} p(\mu|z)\leftarrow p(\mu|z,D_z) = \frac{p(\mu|z) \, p(D_z|\mu)}{\sum_{\mu'} p(\mu'|z) \, p(D_z|\mu')} = \frac{p(\mu|z) \, \prod_{d\in D_z} p(d|\mu)}{\sum_{\mu'} p(\mu'|z) \, \prod_{d\in D_z} p(d|\mu')} \qquad \forall z=1,\dots, M$$

(2) Update prior by marginal: Here, we approximate the expectation over $p(x|z)$ in the marginal $p(\mu|z) = \sum_x p(\mu|z,x) p(x|z)$ by the expectation w.r.t. the frequency

$$f_{D_z}(x) = \frac{1}{|D_z|} \, \big|\big\{d\in D_z|d=x\big\}\big|$$

so that the prior update becomes

$$\tag{2} p(\mu|z) \leftarrow \sum_x p(\mu|z,x) p(x|z) \approx \sum_x p(\mu|z,x) f_{D_z}(x) = \frac{1}{|D_z|} \sum_{d\in D_z} p(\mu|z,d) \qquad \forall z=1,\dots, M$$

Here is the algorithm:

• Initialize $t=0$ and $p^{(0)}(\mu|z)$

• For $D_{batch}\subset D$:

1. Initialize $\{D_z\}_{z=1}^M$ as a collection of empty sets so that $D_z$ collects the $d$s that are proposed to correspond to $z$.

2. (Selector step) For all $d\in D_{batch}$ sample $z_d\sim p(z|d)=\frac{\rho_z \sum_\mu p^{(t)}(\mu|z) p(d|\mu)}{\sum_z \dots}$ or pick $z = \mathrm{argmax}_{z'}p(z'|d)$ and add $d$ to $D_{z_d}$.

3. Depending on the chosen prior update, we do either (1) or (2) from above:

(1) (Actor step with posterior prior update) For all $z=1,\dots,M$, determine

$$p(\mu|D_z,z) = \frac{p^{(t)}(\mu|z) \, p(D_z|\mu)}{\sum_{\mu'} p^{(t)}(\mu'|z) \, p(D_z|\mu')} = \frac{p^{(t)}(\mu|z) \, \prod_{d\in D_z} p(d|\mu)}{\sum_{\mu'} p^{(t)}(\mu'|z) \, \prod_{d\in D_z} p(d|\mu')}$$

and set $p^{(t+1)}(\mu|z) = p(\mu|D_z,z)$.

(2) (Actor step with marginal prior update) For all $z=1,\dots,M$, determine $$p(\mu|z,d) = \frac{p^{(t)}(\mu|z) p(d|\mu)}{\sum_{\mu'} p^{(t)}(\mu'|z) p(d|\mu')} \qquad \forall d\in D_{z}$$ and set $p^{(t+1)}(\mu|z) = \frac{1}{|D_z|} \sum_{d\in D_z} p(\mu|d,z)$.

4. Set $t=t+1$ and continue with 1 and the next batch.

In [201]:
def batch_nobetas(D,px_mu,datapoints,batchsize,epochs,pu='post',opt='sample'):
    """Two-step Bayesian inference (section 4b) with batches.

    Per batch: assign each datapoint d to a component z (selector step, sampled
    or MAP), collect the parts D_z, then update each prior p(mu|z) either by
    the posterior p(mu|z,D_z) (pu='post') or by the averaged marginal
    (1/|D_z|) sum_d p(mu|z,d) (pu='marg').  Plots results and fitness traces.

    D:          dataset (list of grid indices)
    px_mu:      likelihood pr_func p(x|mu)
    datapoints: number of datapoints to use per epoch
    batchsize:  size of each batch D0
    epochs:     number of passes over the data
    pu:         prior update rule, 'post' or 'marg' (see markdown (1)/(2))
    opt:        'sample' (z ~ p(z|d)) or 'MAP' (z = argmax p(z|d))

    Uses the file-level globals pr, N, M, K, sample, validate_ll, validate_px,
    best_ll, px_real.  Prints the final fitness values.
    """

    pr.set_dims([('x',N),('z',M),('mu',K)])

    # uniform mixture weights p(z)
    pz = pr.func('f(z)',val='unif').normalize()

    # init: uniform prior p(mu|z) for every component
    pmu_z = pr.func('f(z,mu)',val='unif').normalize(['mu'])
    F_ll,F_px = [],[]

    # inference
    for r in range(epochs):
        for n in range(int(datapoints/batchsize)):
            D0 = D[n*batchsize:(n+1)*batchsize]
            Dparts = [[] for j in range(M)]
            # selector step: propose a component for every d in the batch
            for d in D0:
                pz_d = (pz*pr.sum(['mu'],pmu_z*px_mu.eval('x',d))).normalize()
                if opt == 'sample':
                    opt_str = 'component-sampling ($z\sim p(z|d)$)'
                    z = sample(pz_d.val)
                elif opt == 'MAP':
                    opt_str = 'component-MAP ($z = argmax_{z} \,p(z|d)$)'
                    z = np.argmax(pz_d.val)
                Dparts[z].append(d)
            # actor step: update each component prior from its part D_z
            for z in range(M):
                Dpart = Dparts[z]
                if len(Dpart) > 0:
                    if pu == 'post':
                        # posterior update (1): multiply by prod_d p(d|mu),
                        # computed in log space with a max-shift (log-sum-exp
                        # style) to avoid underflow for larger parts
                        amu = np.einsum('ij->j',np.array([np.log(px_mu.eval('x',d).val) for d in Dpart]))
                        b = np.max(amu)
                        pmu_z.val[z,:] = pmu_z.val[z,:]*np.exp(amu-b)/np.sum(pmu_z.val[z,:]*np.exp(amu-b))
                    elif pu == 'marg':
                        # marginal update (2): average the per-d posteriors
                        pmu_zd_temp = [px_mu.eval('x',d).val*pmu_z.val[z,:] for d in Dpart] # p(d|mu)*p(mu|z)
                        pmu_zd = np.array([pmu_zdi/np.sum(pmu_zdi) for pmu_zdi in pmu_zd_temp]) #p(d|mu)*p(mu|z)/sum(...)
                        pmu_z.val[z,:] = np.einsum('ik->k',pmu_zd)/len(Dpart) #  1/|Dpart| sum_{d in Dpart} p(mu|z,d)
            pmu_z = pmu_z.normalize(['mu'])
            # evaluate against the true model after every batch
            F_ll.append(validate_ll(pr.sum(['z'],pmu_z*pz),D,best_ll))
            F_px.append(validate_px(pr.sum(['z'],pmu_z*pz),px_real))

    # plot the learned component priors and the fitness traces
    t = range(K)
    with plt.rc_context({"figure.figsize": (18,5)}):

        fig,ax = plt.subplots(1,3)
        if pu =='post':
            pre = 'Posterior prior update, '
        elif pu =='marg':
            pre = 'Marginal prior update, '
        fig.suptitle(pre+opt_str+', datasize: {}, batchsize: {}, epochs: {}'.format(datapoints,batchsize,epochs))
        for z in range(M):
            ax[0].bar(t,pmu_z.val[z,:],label='z={}'.format(z))
        ax[0].set_xlabel('$\mu$')
        ax[0].legend()
        ax[1].scatter(range(len(F_ll)),F_ll,marker='.')
        ax[1].set_ylim([0,1000])
        ax[2].set_ylim([0,.05])
        ax[2].scatter(range(len(F_px)),F_px,marker='.')
        ax[0].set_title('$p(\mu_z)$')
        ax[1].set_title('$LL_D[p_{real}]-LL_D[p]$')
        ax[2].set_title('$||p(X)-p_{real}(X)||$')
        plt.show()

    print F_ll[-1],F_px[-1]

In [231]:
# run the batched inference for all four combinations of prior update
# (posterior vs. marginal) and component assignment (sampling vs. MAP)
datapoints = 1000
batchsize = 5
epochs = 1

batch_nobetas(D,px_mu,datapoints,batchsize,epochs,pu='post',opt='sample')
batch_nobetas(D,px_mu,datapoints,batchsize,epochs,pu='post',opt='MAP')

batch_nobetas(D,px_mu,datapoints,batchsize,epochs,pu='marg',opt='sample')
batch_nobetas(D,px_mu,datapoints,batchsize,epochs,pu='marg',opt='MAP')

32.7874737953 0.0292475910055

29.9672045679 0.0295199792057

-0.517241016706 0.00194791610591

2.88660627829 0.0153687913501