Dirichlet process mixture models

Bayesian mixture models showed how to infer the posterior of the parameters of a mixture model with a fixed number of components $K$. We can either choose $K$ using model selection, e.g. with AIC, BIC or WAIC, or try to infer this number automatically. Nonparametric mixture models do exactly that.

Here we implement a nonparametric Bayesian mixture model using Gibbs sampling. We use a Chinese restaurant process prior and a stick-breaking construction to sample from a Dirichlet process (see for instance Nils Hjort's Bayesian Nonparametrics, Peter Orbanz' lecture notes, Kevin Murphy's book and, last but not least, Herman Kamper's notes).

We'll implement the Gibbs sampler using the CRP ourselves, since (I think) Stan doesn't allow us to do this, and then use the stick-breaking construction with Stan. A genuinely infinite DP is technically not possible in Stan either, so we use a small hack: a truncated DP.

As usual I do not take warranty for the correctness or completeness of this document.

I'll use R, cause it's the bestest!

In [1]:
options(repr.plot.width=4, repr.plot.height=3)
In [2]:
suppressMessages(library("e1071"))
suppressMessages(library("mvtnorm"))
suppressMessages(library("dplyr"))
suppressMessages(library("ggplot2"))
suppressMessages(library("MCMCpack"))
suppressMessages(library("bayesplot"))
suppressMessages(library("rlang"))
suppressMessages(library("tsne"))
set.seed(23)
In [3]:
suppressMessages(library(rstan))
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

In Bayesian mixture models we used the following hierarchical form to describe a mixture model:

\begin{align*} \boldsymbol \theta_k & \sim \mathcal{G}_0\\ \boldsymbol \pi & \sim \text{Dirichlet}(\boldsymbol \alpha_0)\\ z_i & \sim \text{Discrete}(\boldsymbol \pi)\\ \mathbf{x}_i \mid z_i = k & \sim {P}(\boldsymbol \theta_k) \end{align*}

where $\mathcal{G}_0$ is some base distribution for the model parameters.
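
To make the notation concrete, here is a minimal sketch of this finite generative process in R. The function name and the choice of a Gaussian $P$ with unit covariance are mine, purely for illustration:

sample_finite_mixture <- function(n, K, alpha0 = 1, p = 2)
{
  theta <- mvtnorm::rmvnorm(K, rep(0, p), 10 * diag(p))           # theta_k ~ G_0
  pi    <- as.vector(MCMCpack::rdirichlet(1, rep(alpha0, K)))     # pi ~ Dirichlet(alpha_0)
  z     <- sample(seq(K), n, replace = TRUE, prob = pi)           # z_i ~ Discrete(pi)
  x     <- theta[z, ] + mvtnorm::rmvnorm(n, rep(0, p), diag(p))   # x_i | z_i ~ N(theta_{z_i}, I)
  list(x = x, z = z)
}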

The DP, in contrast, as any BNP model, puts priors on structures that accommodate infinite sizes. The resulting posteriors give a distribution on structures that grow with new observations. A mixture model using a possibly infinite number of components could look like this:

\begin{align*} \mathcal{G} & \sim \mathcal{DP}(\alpha, \mathcal{G}_0)\\ \boldsymbol \theta_i & \sim \mathcal{G}\\ \mathbf{x}_i& \sim {P}(\boldsymbol \theta_i) \end{align*}

where $\mathcal{G}_0$ is the same base measure as above and $\mathcal{G}$ is a sample from the DP, i.e. also a random measure.

The Chinese restaurant process

One way, and possibly the easiest, to implement a DPMM is using a Chinese restaurant process (CRP) which is a distribution over partitions.

Data generating process

The hierarchical model using a CRP is:

\begin{align*} \boldsymbol \theta_k & \sim \mathcal{G}_0 \\ z_i \mid \mathbf{z}_{1:i-1} & \sim \text{CRP} \\ \mathbf{x}_i & \sim P(\boldsymbol \theta_{z_i}) \end{align*}

where $\text{CRP}$ is a prior on possibly infinitely many classes. Specifically, the CRP is defined as:

\begin{align*} P(z_i = k \mid \mathbf{z}_{-i}) = \left\{ \begin{array}{ll} \frac{N_k}{N - 1 + \alpha} & \text{for an existing table } k\\ \frac{\alpha}{N - 1 + \alpha} & \text{for a new table}\\ \end{array} \right. \end{align*}

where $N_k$ is the number of customers at table $k$, $N$ the total number of customers seated so far, and $\alpha$ a concentration hyperparameter.
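
In code this conditional is just one normalized vector. A minimal sketch (the helper name is mine), assuming tables holds the current counts $N_k$ with customer $i$ already removed, so that sum(tables) equals $N - 1$:

crp_probs <- function(tables, alpha)
{
  # one entry per occupied table, plus one entry for opening a new table
  c(tables, alpha) / (sum(tables) + alpha)
}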

For the variables of interest, $\boldsymbol \theta$ and $\boldsymbol z$, the posterior is:

\begin{align*} P(\boldsymbol \theta, \boldsymbol z \mid \mathbf{X}) \propto P(\mathbf{X} \mid \boldsymbol \theta, \boldsymbol z ) P(\boldsymbol \theta) P ( \boldsymbol z ) \end{align*}

Using a Gibbs sampler, we iterate over the following two steps:

1) sample $z_i \sim P(z_i \mid \mathbf{z}_{-i}, \mathbf{X}, \boldsymbol \theta) \propto P(z_i \mid \mathbf{z}_{-i}) P(\mathbf{x}_i \mid \boldsymbol \theta_{z_i}, \mathbf{X}_{-i}, \mathbf{z})$

2) sample $\boldsymbol \theta_k \sim P(\boldsymbol \theta_k \mid \mathbf{z}, \mathbf{X})$

So we alternate sampling assignments of data to classes and sampling the parameters of the data distribution given the class assignments. The major difference compared to the finite case is how we sample $z_i$: in the infinite case we do this using the CRP. The CRP itself is defined by $P(z_i \mid \mathbf{z}_{-i})$, so replacing it by the usual finite-dimensional prior would give us back a finite mixture. Evaluating the likelihoods in the first step is fairly straightforward, as we will see. Updating the model parameters in the second step is done conditionally on each class, and by that also not too hard to do.

Stick-breaking construction

With the CRP we put a prior distribution on the possibly infinite number of class assignments. An alternative approach is to use the stick-breaking construction. The advantage here is that we can use Stan with a truncated DP, so we don't need to implement the sampler ourselves.

Data generating process

If we, instead of putting a CRP prior on the latent labels, put a prior on the possibly infinite sequence of mixing weights $\boldsymbol \pi$ we arrive at the stick-breaking construction. The hierarchical model now looks like this:

\begin{align*} \nu_k &\sim \text{Beta}(1, \alpha) \\ \pi_k & = \nu_k \prod_{j=1}^{k-1} (1 - \nu_j) \\ \boldsymbol \theta_k & \sim G_0 \\ \mathbf{x}_i & \sim \sum_k \pi_k P(\boldsymbol \theta_k) \end{align*}

where $\alpha$ is again a concentration hyperparameter. The distribution of the mixing weights is sometimes denoted as

$$ \boldsymbol \pi \sim \text{GEM}(\alpha) $$
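
A minimal sketch of (truncated) stick-breaking in R, assuming a finite truncation level K (the helper name is mine):

stick_breaking <- function(K, alpha)
{
  nu <- rbeta(K, 1, alpha)          # nu_k ~ Beta(1, alpha)
  nu[K] <- 1                        # close the last stick so the weights sum to one
  nu * c(1, cumprod(1 - nu[-K]))    # pi_k = nu_k * prod_{j<k} (1 - nu_j)
}

Since the $\nu_k$ lie between zero and one, the weights decay stochastically, which is why truncating at a moderate $K$ (as we will do with Stan below) usually captures almost all of the mass.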

Gaussian DPMM

In the following section, we derive a Gaussian Dirichlet process mixture using the CRP with a Gibbs sampler and the stick-breaking construction using Stan.

CRP

In the Gaussian case the hierarchical model using the CRP has the following form:

\begin{align*} \boldsymbol \Sigma_k & \sim \mathcal{IW}\\ \boldsymbol \mu_k & \sim \mathcal{N}(\boldsymbol \mu_0, \boldsymbol \Sigma_0) \\ z_i \mid z_{1:i-1} & \sim \text{CRP} \\ \mathbf{x}_i & \sim \mathcal{N}(\boldsymbol \mu_{z_i}, \boldsymbol \Sigma_{z_i}) \end{align*}

Let's derive the Gibbs sampler for an infinite Gaussian mixture using the CRP.

First we set some constants for the data $\mathbf{X}$. We create a very simple data set to avoid problems with identifiability and label switching; for a treatment of the topic see Michael Betancourt's case study. $n$ is the number of samples, $p$ is the dimensionality of the Gaussians, and $\alpha$ is the concentration parameter of the CRP.

In [4]:
n <- 100
p <- 2
alpha <- .5

Latent class assignments $Z$, the data matrix $X$, the current table index and the number of customers per table:

In [5]:
Z <- integer(n)
X <- matrix(0, n, p)
curr.tab <- 0
tables <- c()

Parameters of the Gaussians:

In [6]:
sigma <- .1
mus <- NULL

Then we create a random assignment of customers to tables with probability $P(z_i \mid z_{1:i-1})$, i.e. we use the CRP to put data into classes. Note that we don't know in advance how many classes come out!
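
As a small aside: under the CRP the expected number of occupied tables after seating $n$ customers is $\sum_{i=1}^{n} \frac{\alpha}{\alpha + i - 1}$, which grows roughly like $\alpha \log n$. With $\alpha = 0.5$ and $n = 100$ this is about $3.3$, so we should expect only a handful of clusters.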

In [7]:
for (i in seq(n))
{
    # CRP probabilities: existing tables vs. opening a new table
    probs <- c(tables / (i - 1 + alpha), alpha / (i - 1 + alpha))
    table <- rdiscrete(1, probs)
    if (table > curr.tab)
    {
        # open a new table and draw its mean from the base distribution
        curr.tab <- curr.tab + 1
        tables <- c(tables, 0)
        mu <- mvtnorm::rmvnorm(1, c(0, 0), 10 * diag(p))
        mus <- rbind(mus, mu)
    }
    # seat customer i and sample an observation from its table's Gaussian
    Z[i] <- table
    X[i,] <- mvtnorm::rmvnorm(1, mus[Z[i], ], sigma * diag(p))
    tables[table] <- tables[table] + 1
}

Let's see how many clusters and how many data points per cluster we have.

In [8]:
data.frame(table(Z))%>%
    ggplot() +
    geom_col(aes(Z, Freq), width=.5) +
    theme_minimal()
In [9]:
data.frame(X=X, Z=as.factor(Z)) %>%
    ggplot() +
    geom_point(aes(X.1, X.2, col=Z)) +
    theme_minimal()

Posterior inference using Gibbs sampling

Let's infer the posteriors.

We initialize the cluster assignments by seating all customers at table 1. The hyperparameter $\alpha$ controls the probability of opening a new table.

In [10]:
# initialization of the cluster assignments
K <- 1
zs <- rep(K, n)
alpha <- 5
tables <- n

Define the priors of the model. We set the covariances to be fixed.

In [11]:
# cluster means (one row per cluster; updated during sampling) and the
# fixed likelihood covariance with its precision
mu.prior <- matrix(c(0, 0), ncol = 2)
sigma.prior <- diag(p)
q.prior <- solve(sigma.prior)

Base distribution $\mathcal{G}_0$:

In [12]:
sigma0 <- diag(p)
prec0 <- solve(sigma0)
mu0 <- rep(0, p)

To infer the posterior we use the Gibbs sampler described above. Here, I am only interested in the most likely assignment, i.e. the MAP of $\mathbf{z}$, so in step 1 we take the arg max over the conditional probabilities instead of sampling from them.
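
For reference, with the covariances fixed, the conditional posterior of a cluster mean in step 2 is the standard conjugate Gaussian update (here $N_k$ is the number of points currently assigned to cluster $k$ and $\bar{\mathbf{x}}_k$ their mean):

\begin{align*} \boldsymbol \mu_k \mid \mathbf{z}, \mathbf{X} \sim \mathcal{N}\left(\boldsymbol \Lambda_k \left( \boldsymbol \Sigma_0^{-1} \boldsymbol \mu_0 + N_k \boldsymbol \Sigma^{-1} \bar{\mathbf{x}}_k \right), \boldsymbol \Lambda_k\right), \qquad \boldsymbol \Lambda_k = \left(\boldsymbol \Sigma_0^{-1} + N_k \boldsymbol \Sigma^{-1}\right)^{-1} \end{align*}

Since we chose $\boldsymbol \mu_0 = \mathbf{0}$, the prior-mean term drops out, which is the form the code below samples from.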

In [13]:
for (iter in seq(100))
{
  for (i in seq(n))
  {
    # look at data x_i and remove its statistics from the clustering
    zi <- zs[i]
    tables[zi] <- tables[zi] - 1
    if (tables[zi] == 0) {
      # cluster zi is now empty: remove it and relabel the remaining clusters
      K <- K - 1
      zs[zs > zi] <- zs[zs > zi] - 1
      tables <- tables[-zi]
      mu.prior <- mu.prior[-zi, , drop = FALSE]
    }
    
    # compute posterior probabilities P(z_i | z_-i, ...) for the existing clusters
    no_i <- seq(n)[-i]
    probs <- sapply(seq(K), function(k) {
      crp <- sum(zs[no_i] == k) / (n + alpha - 1)
      lik <- mvtnorm::dmvnorm(X[i, ], mu.prior[k, ], sigma.prior)
      crp * lik
    })
    
    # probability for opening up a new cluster (marginal likelihood under G_0)
    crp <- alpha / (n + alpha - 1)
    lik <- mvtnorm::dmvnorm(X[i, ], mu0, sigma.prior + sigma0)
    probs <- c(probs, crp * lik)
    probs <- probs / sum(probs)
    
    # take the MAP assignment instead of sampling, since we only want the most likely z
    z_new <- which.max(probs)
    if (z_new > K) {
      K <- K + 1
      tables <- c(tables, 0)
      mu.prior <- rbind(mu.prior, mvtnorm::rmvnorm(1, mu0, sigma0))
    }
    zs[i] <- z_new
    tables[z_new] <- tables[z_new] + 1
    
    # sample the cluster means from the conditional posterior P(mu_k | z, X),
    # the conjugate Gaussian update above with mu0 = 0
    for (k in seq(K)) {
      Xk <- X[zs == k, , drop = FALSE]
      lambda <- solve(prec0 + tables[k] * q.prior)
      numerator <- tables[k] * q.prior %*% apply(Xk, 2, mean)
      mu.prior[k, ] <- mvtnorm::rmvnorm(1, lambda %*% numerator, lambda)
    }
  }
}

Let's see if that worked out!

In [14]:
data.frame(X=X, Z=as.factor(zs)) %>%
    ggplot() +
    geom_point(aes(X.1, X.2, col=Z)) +
    theme_minimal()

Cool, except for the lone guy on top, the clustering worked nicely.

Stick breaking construction

In order to make the DPMM with stick-breaking work in Stan, we need to supply a maximum number of clusters $K$ from which we can choose. Setting $K=n$ would mean we allow every data point to define its own cluster. For the sake of the exercise I'll set the maximum number of clusters to $10$. The hyperparameter $\alpha$ parameterizes the Beta distribution from which we sample the stick lengths. We use the same data we already generated above.
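
As a rough sanity check for the truncation (an aside, not part of the model): under a $\text{GEM}(\alpha)$ prior the expected weights decay geometrically, $\mathbb{E}[\pi_k] = \frac{1}{1+\alpha}\left(\frac{\alpha}{1+\alpha}\right)^{k-1}$, so the expected prior mass beyond the first $K$ sticks is $\left(\frac{\alpha}{1+\alpha}\right)^{K}$; with $\alpha=2$ and $K=10$ that is $\left(\frac{2}{3}\right)^{10} \approx 0.017$.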

In [15]:
K <- 10
alpha <- 2

The model is a bit more verbose in comparison to the finite case (Bayesian mixture models). We only need to add the stick-breaking part in the transformed parameters block; the rest stays the same. We again use an LKJ prior for the correlation matrix of the single components and a fixed prior scale of $1$. In order to get nice, unimodal posteriors, we also impose an ordering on the mean vectors.

In [16]:
stan.file <- "_models/dirichlet_process_mixture.stan"
cat(readLines(stan.file), sep="\n")
data {
	int<lower=0> K;
	int<lower=0> n;
	int<lower=1> p;
	row_vector[p] x[n];
	real alpha;
}

parameters {    	
  	ordered[p] mu[K];
	cholesky_factor_corr[p] L;
	real <lower=0, upper=1> nu[K];
}

transformed parameters {
  simplex[K] pi;
  pi[1] = nu[1];
  for(j in 2:(K-1)) 
  {
      pi[j] = nu[j] * (1 - nu[j - 1]) * pi[j - 1] / nu[j - 1]; 
  }

  pi[K] = 1 - sum(pi[1:(K - 1)]);
}

model {
  	real mix[K];

  	L ~ lkj_corr_cholesky(5);
  	nu ~ beta(1, alpha);	
	for (i in 1:K) 
	{
		mu[i] ~ normal(0, 5);
	}

  
  	for(i in 1:n) 
  	{
		for(k in 1:K) 
		{
			mix[k] = log(pi[k]) + multi_normal_cholesky_lpdf(x[i] | mu[k], L);
		}
		target += log_sum_exp(mix);
  	}
}
In [17]:
fit <- stan(stan.file, data = list(K=K, n=n, x=X, p=p, alpha=alpha), iter = 10000, warmup = 1000, chains = 1)
SAMPLING FOR MODEL 'dirichlet_process_mixture' NOW (CHAIN 1).
Chain 1: 
Chain 1: Gradient evaluation took 0.000845 seconds
Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 8.45 seconds.
Chain 1: Adjust your expectations accordingly!
Chain 1: 
Chain 1: 
Chain 1: Iteration:    1 / 10000 [  0%]  (Warmup)
Chain 1: Iteration: 1000 / 10000 [ 10%]  (Warmup)
Chain 1: Iteration: 1001 / 10000 [ 10%]  (Sampling)
Chain 1: Iteration: 2000 / 10000 [ 20%]  (Sampling)
Chain 1: Iteration: 3000 / 10000 [ 30%]  (Sampling)
Chain 1: Iteration: 4000 / 10000 [ 40%]  (Sampling)
Chain 1: Iteration: 5000 / 10000 [ 50%]  (Sampling)
Chain 1: Iteration: 6000 / 10000 [ 60%]  (Sampling)
Chain 1: Iteration: 7000 / 10000 [ 70%]  (Sampling)
Chain 1: Iteration: 8000 / 10000 [ 80%]  (Sampling)
Chain 1: Iteration: 9000 / 10000 [ 90%]  (Sampling)
Chain 1: Iteration: 10000 / 10000 [100%]  (Sampling)
Chain 1: 
Chain 1:  Elapsed Time: 129.032 seconds (Warm-up)
Chain 1:                361.426 seconds (Sampling)
Chain 1:                490.457 seconds (Total)
Chain 1: 
In [18]:
fit
Inference for Stan model: dirichlet_process_mixture.
1 chains, each with iter=10000; warmup=1000; thin=1; 
post-warmup draws per chain=9000, total post-warmup draws=9000.

            mean se_mean   sd    2.5%     25%     50%     75%   97.5% n_eff
mu[1,1]    -0.40    0.00 0.04   -0.48   -0.43   -0.40   -0.37   -0.32  4824
mu[1,2]    -0.37    0.00 0.04   -0.45   -0.40   -0.37   -0.34   -0.29  6646
mu[2,1]    -3.83    0.00 0.16   -4.13   -3.94   -3.83   -3.72   -3.52  3802
mu[2,2]     0.83    0.00 0.16    0.52    0.72    0.83    0.94    1.14  4326
mu[3,1]     1.31    0.00 0.06    1.19    1.27    1.31    1.35    1.43  7955
mu[3,2]     1.36    0.00 0.06    1.24    1.32    1.36    1.40    1.48  7130
mu[4,1]    -2.92    0.05 3.96  -11.19   -5.21   -2.94   -0.33    4.66  6931
mu[4,2]     2.54    0.04 3.86   -4.64    0.10    2.10    5.02   10.67  8998
mu[5,1]    -2.80    0.05 3.94  -10.86   -5.27   -2.72   -0.20    4.89  6896
mu[5,2]     2.66    0.05 4.06   -4.88    0.00    2.32    5.27   11.13  6487
mu[6,1]    -2.75    0.06 4.01  -10.89   -5.28   -2.72   -0.04    5.05  5280
mu[6,2]     2.78    0.04 3.97   -4.72    0.16    2.52    5.34   11.02  9468
mu[7,1]    -2.78    0.05 4.06  -10.99   -5.39   -2.71   -0.08    5.08  6357
mu[7,2]     2.85    0.05 4.10   -4.87    0.06    2.67    5.55   11.16  6168
mu[8,1]    -2.77    0.05 4.09  -11.17   -5.41   -2.68    0.00    5.03  5898
mu[8,2]     2.76    0.05 4.11   -5.10    0.00    2.62    5.43   11.09  6646
mu[9,1]    -2.87    0.06 4.16  -11.28   -5.61   -2.76   -0.05    4.94  5026
mu[9,2]     2.77    0.04 4.07   -4.85    0.00    2.64    5.46   11.13  8978
mu[10,1]   -2.78    0.05 4.11  -11.17   -5.49   -2.75    0.02    4.94  5819
mu[10,2]    2.86    0.05 4.06   -4.78    0.09    2.69    5.53   11.12  8017
L[1,1]      1.00     NaN 0.00    1.00    1.00    1.00    1.00    1.00   NaN
L[1,2]      0.00     NaN 0.00    0.00    0.00    0.00    0.00    0.00   NaN
L[2,1]     -0.88    0.00 0.02   -0.91   -0.89   -0.88   -0.87   -0.84  6392
L[2,2]      0.48    0.00 0.03    0.42    0.46    0.48    0.50    0.55  6371
nu[1]       0.43    0.00 0.05    0.33    0.39    0.43    0.46    0.52  8297
nu[2]       0.66    0.00 0.06    0.52    0.62    0.66    0.70    0.77  1675
nu[3]       0.88    0.00 0.09    0.65    0.84    0.90    0.94    0.98   684
nu[4]       0.35    0.01 0.25    0.01    0.13    0.30    0.53    0.89  1700
nu[5]       0.34    0.00 0.24    0.01    0.13    0.29    0.50    0.86  6359
nu[6]       0.34    0.00 0.24    0.01    0.14    0.29    0.51    0.83  7177
nu[7]       0.33    0.00 0.23    0.01    0.13    0.29    0.50    0.84  6252
nu[8]       0.33    0.00 0.24    0.01    0.13    0.29    0.50    0.85  7674
nu[9]       0.34    0.00 0.24    0.01    0.14    0.30    0.51    0.85  6661
nu[10]      0.34    0.00 0.24    0.01    0.13    0.29    0.51    0.85  3142
pi[1]       0.43    0.00 0.05    0.33    0.39    0.43    0.46    0.52  8297
pi[2]       0.38    0.00 0.05    0.28    0.34    0.38    0.41    0.47  2450
pi[3]       0.17    0.00 0.04    0.11    0.15    0.17    0.20    0.25  6441
pi[4]       0.01    0.00 0.01    0.00    0.00    0.00    0.01    0.05   217
pi[5]       0.01    0.00 0.01    0.00    0.00    0.00    0.01    0.03  1981
pi[6]       0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.02  3094
pi[7]       0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.01  6097
pi[8]       0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.01  3677
pi[9]       0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.01  5116
pi[10]      0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.01  4733
lp__     -406.17    0.09 4.44 -415.70 -408.94 -405.81 -403.05 -398.50  2575
         Rhat
mu[1,1]  1.00
mu[1,2]  1.00
mu[2,1]  1.00
mu[2,2]  1.00
mu[3,1]  1.00
mu[3,2]  1.00
mu[4,1]  1.00
mu[4,2]  1.00
mu[5,1]  1.00
mu[5,2]  1.00
mu[6,1]  1.00
mu[6,2]  1.00
mu[7,1]  1.00
mu[7,2]  1.00
mu[8,1]  1.00
mu[8,2]  1.00
mu[9,1]  1.00
mu[9,2]  1.00
mu[10,1] 1.00
mu[10,2] 1.00
L[1,1]    NaN
L[1,2]    NaN
L[2,1]   1.00
L[2,2]   1.00
nu[1]    1.00
nu[2]    1.00
nu[3]    1.00
nu[4]    1.00
nu[5]    1.00
nu[6]    1.00
nu[7]    1.00
nu[8]    1.00
nu[9]    1.00
nu[10]   1.00
pi[1]    1.00
pi[2]    1.00
pi[3]    1.00
pi[4]    1.01
pi[5]    1.00
pi[6]    1.00
pi[7]    1.00
pi[8]    1.00
pi[9]    1.00
pi[10]   1.00
lp__     1.00

Samples were drawn using NUTS(diag_e) at Fri Mar  1 12:00:24 2019.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

High effective sample sizes, no divergent transitions and $\hat{R}$s of one look good. Our model seems well specified! Let's look at some plots though, first the posterior distributions of the mixing weights.

In [19]:
posterior <- extract(fit)
In [20]:
options(repr.plot.width=4, repr.plot.height=8)
data.frame(posterior$pi)  %>% 
    set_names(paste0("PI_", 1:10)) %>%
    tidyr::gather(key, value) %>%
    ggplot(aes(x = value, y = ..density.., fill=key), position="dodge") + 
    facet_grid(key ~ ., scales="free") +
    geom_histogram(bins=50) +
    theme_minimal()

From the plot above it looks as if Stan believes three components are sufficient, as the mixing weights of the seven other components are fairly low or even zero. Let's extract the posterior means of the component means and assign each data point to a cluster.

In [21]:
post.mus <- do.call(
    "rbind", 
    lapply(1:10, function(i) apply(posterior$mu[,i,], 2, mean)))
In [22]:
probs <- purrr::map_dfc(seq(10), function(i) {
    mvtnorm::dmvnorm(X, post.mus[i,], diag(2))}) %>% 
    set_names(paste0("Z", seq(10)))
In [23]:
zs.stan <- apply(probs, 1, which.max)

And the final plot:

In [24]:
options(repr.fig.width=4, repr.plot.height=3)
In [25]:
data.frame(X=X, Z=as.factor(zs.stan)) %>%
    ggplot() +
    geom_point(aes(X.1, X.2, col=Z)) +
    theme_minimal()

Cool, our small hack using Stan and stick-breaking worked even better than our CRP implementation. Here, we managed to give every point its correct label.

Multivariate Bernoullis

Next, we derive a Dirichlet process mixture for a multivariate Bernoulli distribution (or whatever name is more suitable here). We again use the CRP with a Gibbs sampler and the stick-breaking construction using Stan.

CRP

We model every observation $\mathbf{x} \in \{0, 1 \}^p$ as:

\begin{align*} \pi_{{z_i}_j} & \sim \text{Beta}(a, b)\\ z_i \mid z_{1:i-1} & \sim \text{CRP} \\ x_{i, j} \mid z_i & \sim \text{Bernoulli}(\pi_{{z_i}_j}), \end{align*}

with hyperparameters $a$ and $b$. Thus the likelihood for one datum is the product over $p$ independent Bernoullis:

$$P(\mathbf{x} \mid z, \boldsymbol \Pi) = \prod_{j=1}^p \pi_{{z_i}_j}^{x_j} \cdot (1 - \pi_{{z_i}_j})^{(1 - x_j)}$$
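
A direct transcription of this likelihood in log space (just an illustration with a hypothetical helper name; the sampler further below uses its own helper):

bernoulli_loglik <- function(x, ps)
{
  # log P(x | pi) = sum_j x_j * log(pi_j) + (1 - x_j) * log(1 - pi_j)
  sum(x * log(ps) + (1 - x) * log(1 - ps))
}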

First we again generate some data. This time we sample $n=200$ points $\mathbf{X}$ from a $p=3$-dimensional mixture of $k=3$ multivariate Bernoullis, together with their latent class assignments $\mathbf{z}$, directly from a finite mixture (since we already saw how one can generate data using a CRP).

In [26]:
n <- 200
p <- 3
alpha <- 0.5
k <- 3
Z <- sample(1:k, n, replace = T)
table(Z)
Z
 1  2  3 
67 66 67 

Then we generate the true success probabilities for every Bernoulli; there are $k \cdot p$ of them. We simulate an easy scenario where, within a class, every dimension has the same success probability.

In [27]:
probs.true <- matrix(seq(0.1, 0.9, length.out=k), k, p)
probs.true
0.1  0.1  0.1
0.5  0.5  0.5
0.9  0.9  0.9

Then we randomly generate the data using these probabilities:

In [28]:
probs.matrix <- probs.true[Z, ]
X  <- (probs.matrix > matrix(runif(n * p), n, p)) * 1L
head(X)
1  0  0
0  0  0
0  0  0
0  0  0
1  1  1
1  1  1

Let's have a look at it using $t$-SNE. Note that since the data are Bernoulli, the axes are kinda hard to interpret and we shouldn't expect as clear a separation of the clusters as in the Gaussian case (the same is true for PCA, too).

In [29]:
tsne.data <- tsne(X, perplexity = 50, max_iter = 1500)
sigma summary: Min. : 2.98023223876953e-08 |1st Qu. : 2.98023223876953e-08 |Median : 2.98023223876953e-08 |Mean : 0.328735702315581 |3rd Qu. : 0.87608321405121 |Max. : 0.881540922682334 |
Epoch: Iteration #100 error is: 10.5341680925935
Epoch: Iteration #200 error is: 0.114879642075053
Epoch: Iteration #300 error is: 0.107652843585134
Epoch: Iteration #400 error is: 0.107640277108229
Epoch: Iteration #500 error is: 0.107640268222389
Epoch: Iteration #600 error is: 0.107640267970035
Epoch: Iteration #700 error is: 0.107640267854842
Epoch: Iteration #800 error is: 0.107640267767748
Epoch: Iteration #900 error is: 0.107640267688465
Epoch: Iteration #1000 error is: 0.107640267615484
Epoch: Iteration #1100 error is: 0.107640267553738
Epoch: Iteration #1200 error is: 0.107640267497753
Epoch: Iteration #1300 error is: 0.107640267445983
Epoch: Iteration #1400 error is: 0.107640267399936
Epoch: Iteration #1500 error is: 0.107640267359697
In [30]:
plot(tsne.data, col=Z)

We use a concentration of $\alpha=1$ for the CRP and $a=b=1$ as hyperparameters for the Beta, which gives a uniform prior on the success probabilities.

In [31]:
alpha <- 1
a <- b <- 1

As in the Gaussian case, we need to compute the likelihood of $\mathbf{x}_i$ evaluated at its cluster $k$. Since we factorize over $p$ independent Bernoullis, we write down this likelihood manually:

In [32]:
# likelihood of x for an existing cluster with success probabilities ps
ppd <- function(x, ps)
{
  exp(sum(x * log(ps) + (1 - x) * log(1 - ps)))
}

# likelihood of x for a random new cluster (probabilities drawn from the Beta prior)
ppde <- function(x)
{
  ps <- rbeta(p, a, b)
  ppd(x, ps)
}

For the sampler we start with the following initial settings: one cluster, $K=1$, i.e. all latent class assignments are $\mathbf{z} = \{1\}^n$, and $p$ random draws from the Beta prior for the success probabilities of the Bernoullis of that cluster.

In [33]:
K <- 1
zs <- rep(K, n)
tables <- n
priors <- array(rbeta(p, a, b), dim = c(1, p, 1))
priors
  1. 0.151305898092687
  2. 0.909731121268123
  3. 0.82791749550961

Then we implement the Gibbs sampler (or rather an ECM-like variant, since we again take the most likely assignment):

In [34]:
for (it in seq(100))
{
  for (i in seq(n))
  {
    # look at data x_i and remove its statistics from the clustering
    zi <- zs[i]
    tables[zi] <- tables[zi] - 1
    if (tables[zi] == 0) {
      # cluster zi is now empty: remove it and relabel the remaining clusters
      K <- K - 1
      zs[zs > zi] <- zs[zs > zi] - 1
      tables <- tables[-zi]
      priors <- priors[, , -zi, drop = FALSE]
    }

    # compute posterior probabilities P(z_i | z_-i, ...) for the existing clusters
    no_i <- seq(n)[-i]
    probs <- sapply(seq(K), function(k) {
      crp <- sum(zs[no_i] == k) / (n + alpha - 1)
      lik <- ppd(X[i, ], priors[, , k])
      crp * lik
    })

    # probability for opening up a new cluster
    crp <- alpha / (n + alpha - 1)
    lik <- ppde(X[i, ])
    probs <- c(probs, crp * lik)
    probs <- probs / sum(probs)

    # take the MAP assignment instead of sampling
    z_new <- which.max(probs)
    if (z_new > K) {
      K <- K + 1
      tables <- c(tables, 0)
      priors <- abind::abind(priors, array(rbeta(p, a, b), dim = c(1, p, 1)))
    }
    zs[i] <- z_new
    tables[z_new] <- tables[z_new] + 1

    # sample the success probabilities from their conditional posterior,
    # a Beta(a + successes, b + failures) for every dimension
    for (k in seq(K))
    {
      Xk <- X[zs == k, , drop = FALSE]
      priors[, , k] <- sapply(colSums(Xk), function(s) rbeta(1, s + a, nrow(Xk) - s + b))
    }
  }
}

Let's compare the inferred assignments with the true ones.

In [35]:
par(mfrow=c(1, 2))
plot(tsne.data, col=Z)
plot(tsne.data, col=zs)

Stick-breaking construction

As above, we also try to implement the mixture in Stan. We use a truncated DP again, set the maximum number of clusters to $K=5$ (because the model is quite hard to fit and we don't use a lot of data), and set the concentration to $\alpha=2$.

In [36]:
K <- 5
alpha <- 2

The respective Stan file is similar to the Gaussian case. There is one more tricky part, though. In the Gaussian case we ordered the mean vectors to get identifiable posteriors. For the binary case we in addition need to make sure that the parameters are probabilities, i.e. lie between $0$ and $1$. We do that by first declaring a $K \times p$-dimensional parameter rates which we order: ordered[p] rates[K]. Then, in order to turn these into probabilities, we apply the inverse logit function to every element: prob[k, ps] = inv_logit(rates[k, ps]). That should do the trick. The complete Stan file is shown below.
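
As a quick sanity check (just an illustration, not part of the model), the inverse logit, plogis in R, keeps the ordering while mapping the real-valued rates into $(0, 1)$:

rates <- c(-2, 0, 1.5)   # ordered on the real line
plogis(rates)            # 0.119 0.500 0.818 -- still ordered, now valid probabilities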

In [37]:
stan.file <- "_models/binary_dirichlet_process_mixture.stan"
cat(readLines(stan.file), sep="\n")
data {
	int<lower=1> K;
	int<lower=1> n;
	int<lower=1> p;
	int<lower=0,upper=1> x[n, p];
	real<lower=1> alpha;
}

parameters {    	
  	ordered[p] rates[K];
	real<lower=0, upper=1> nu[K];	
}

transformed parameters {
  simplex[K] pi;
  vector<lower=0, upper=1>[p] prob[K];

  pi[1] = nu[1];
  for(j in 2:(K-1)) 
  {
      pi[j] = nu[j] * (1 - nu[j - 1]) * pi[j - 1] / nu[j - 1]; 
  }

  pi[K] = 1 - sum(pi[1:(K - 1)]);
  for (k in 1:K) 
  {
  	for (ps in 1:p) 
  	{
  		prob[k, ps] = inv_logit(rates[k, ps]);  
  	}
  }   
}

model {
  	real mix[K];
  	nu ~ beta(1, alpha);	
  
  	for(i in 1:n) 
  	{
		for(k in 1:K) 
		{
			mix[k] = log(pi[k]);
			for (ps in 1:p)
			{
				mix[k] += bernoulli_lpmf(x[i, ps] | prob[k, ps]);
			}
		}
		target += log_sum_exp(mix);
  	}
}
In [38]:
fit <- stan(stan.file, data = list(K=K, n=n, x=matrix(as.integer(X), n, p), p=p, alpha=alpha),
            iter = 12000, warmup = 2000, chains = 1, control = list(adapt_delta = 0.99))
hash mismatch so recompiling; make sure Stan code ends with a blank line
SAMPLING FOR MODEL 'binary_dirichlet_process_mixture' NOW (CHAIN 1).
Chain 1: 
Chain 1: Gradient evaluation took 0.000431 seconds
Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 4.31 seconds.
Chain 1: Adjust your expectations accordingly!
Chain 1: 
Chain 1: 
Chain 1: Iteration:     1 / 12000 [  0%]  (Warmup)
Chain 1: Iteration:  1200 / 12000 [ 10%]  (Warmup)
Chain 1: Iteration:  2001 / 12000 [ 16%]  (Sampling)
Chain 1: Iteration:  3200 / 12000 [ 26%]  (Sampling)
Chain 1: Iteration:  4400 / 12000 [ 36%]  (Sampling)
Chain 1: Iteration:  5600 / 12000 [ 46%]  (Sampling)
Chain 1: Iteration:  6800 / 12000 [ 56%]  (Sampling)
Chain 1: Iteration:  8000 / 12000 [ 66%]  (Sampling)
Chain 1: Iteration:  9200 / 12000 [ 76%]  (Sampling)
Chain 1: Iteration: 10400 / 12000 [ 86%]  (Sampling)
Chain 1: Iteration: 11600 / 12000 [ 96%]  (Sampling)
Chain 1: Iteration: 12000 / 12000 [100%]  (Sampling)
Chain 1: 
Chain 1:  Elapsed Time: 101.608 seconds (Warm-up)
Chain 1:                2557.83 seconds (Sampling)
Chain 1:                2659.44 seconds (Total)
Chain 1: 
Warning message:
“There were 1873 divergent transitions after warmup. Increasing adapt_delta above 0.99 may help. See
http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup”Warning message:
“There were 8127 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded”Warning message:
“Examine the pairs() plot to diagnose sampling problems
”

That is unpleasant! A lot of divergent transitions, and almost all of the transitions exceeded the maximum tree depth (see the Stan manual). The divergent transitions are the more severe problem of the two, so let's look at some diagnostic plots.

In [39]:
fit
Inference for Stan model: binary_dirichlet_process_mixture.
1 chains, each with iter=12000; warmup=2000; thin=1; 
post-warmup draws per chain=10000, total post-warmup draws=10000.

              mean se_mean    sd    2.5%     25%     50%     75%   97.5% n_eff
rates[1,1]   -0.24    0.06  0.22   -0.70   -0.39   -0.24   -0.09    0.21    15
rates[1,2]   -0.04    0.08  0.27   -0.64   -0.21   -0.02    0.15    0.42    12
rates[1,3]    0.02    0.08  0.28   -0.57   -0.16    0.04    0.21    0.53    13
rates[2,1]  -41.03   14.23 22.24  -78.46  -59.41  -40.12  -19.71  -10.36     2
rates[2,2]  -31.83    9.33 18.03  -71.97  -43.00  -27.37  -18.92   -7.30     4
rates[2,3]  -18.72    5.90 14.48  -56.57  -24.71  -14.24   -7.80   -3.36     6
rates[3,1]  -10.89   11.09 19.34  -48.01  -25.62  -12.53    6.75   20.63     3
rates[3,2]    3.78   12.85 22.06  -33.19  -16.41    4.73   24.57   34.55     3
rates[3,3]   10.79   11.42 20.01  -25.96   -6.86   17.62   29.58   35.92     3
rates[4,1]  -59.64    5.63 15.40  -88.10  -68.41  -60.86  -55.09  -22.39     7
rates[4,2]  -29.59   13.20 29.39  -76.24  -54.26  -32.01   -2.97   24.18     5
rates[4,3]   -9.62   11.32 25.54  -57.00  -28.04  -10.72   12.35   34.58     5
rates[5,1]   12.24    1.93  7.34    2.26    6.53   10.36   17.46   29.67    14
rates[5,2]   15.33    2.42  7.48    4.01    9.28   14.21   21.35   31.04    10
rates[5,3]   27.77    0.69  6.10   14.90   23.35   28.58   33.00   36.37    78
nu[1]         0.48    0.01  0.05    0.38    0.46    0.49    0.51    0.56    17
nu[2]         0.35    0.06  0.12    0.17    0.24    0.35    0.44    0.57     4
nu[3]         0.21    0.07  0.18    0.02    0.06    0.17    0.30    0.72     6
nu[4]         0.21    0.03  0.12    0.03    0.11    0.20    0.30    0.45    12
nu[5]         0.28    0.05  0.18    0.04    0.14    0.25    0.40    0.67    14
pi[1]         0.48    0.01  0.05    0.38    0.46    0.49    0.51    0.56    17
pi[2]         0.18    0.03  0.07    0.08    0.12    0.18    0.24    0.31     4
pi[3]         0.06    0.02  0.05    0.01    0.02    0.06    0.09    0.18     8
pi[4]         0.06    0.02  0.05    0.00    0.02    0.05    0.10    0.17     8
pi[5]         0.21    0.02  0.06    0.06    0.18    0.22    0.25    0.31     8
prob[1,1]     0.44    0.01  0.05    0.33    0.40    0.44    0.48    0.55    15
prob[1,2]     0.49    0.02  0.07    0.35    0.45    0.49    0.54    0.60    12
prob[1,3]     0.51    0.02  0.07    0.36    0.46    0.51    0.55    0.63    13
prob[2,1]     0.00    0.00  0.00    0.00    0.00    0.00    0.00    0.00    24
prob[2,2]     0.00    0.00  0.00    0.00    0.00    0.00    0.00    0.00   157
prob[2,3]     0.00    0.00  0.01    0.00    0.00    0.00    0.00    0.03   755
prob[3,1]     0.37    0.27  0.47    0.00    0.00    0.00    1.00    1.00     3
prob[3,2]     0.55    0.28  0.49    0.00    0.00    0.99    1.00    1.00     3
prob[3,3]     0.61    0.27  0.48    0.00    0.00    1.00    1.00    1.00     3
prob[4,1]     0.00    0.00  0.00    0.00    0.00    0.00    0.00    0.00   343
prob[4,2]     0.23    0.19  0.41    0.00    0.00    0.00    0.05    1.00     5
prob[4,3]     0.30    0.24  0.45    0.00    0.00    0.00    1.00    1.00     4
prob[5,1]     0.99    0.00  0.05    0.91    1.00    1.00    1.00    1.00    85
prob[5,2]     1.00    0.00  0.01    0.98    1.00    1.00    1.00    1.00   148
prob[5,3]     1.00    0.00  0.00    1.00    1.00    1.00    1.00    1.00   348
lp__       -358.00    1.00  2.91 -363.63 -360.14 -358.00 -355.75 -352.76     8
           Rhat
rates[1,1] 1.03
rates[1,2] 1.08
rates[1,3] 1.08
rates[2,1] 3.06
rates[2,2] 1.72
rates[2,3] 1.35
rates[3,1] 2.52
rates[3,2] 2.30
rates[3,3] 2.12
rates[4,1] 1.41
rates[4,2] 1.40
rates[4,3] 1.40
rates[5,1] 1.02
rates[5,2] 1.05
rates[5,3] 1.04
nu[1]      1.15
nu[2]      1.96
nu[3]      1.28
nu[4]      1.22
nu[5]      1.06
pi[1]      1.15
pi[2]      2.02
pi[3]      1.20
pi[4]      1.34
pi[5]      1.29
prob[1,1]  1.03
prob[1,2]  1.08
prob[1,3]  1.08
prob[2,1]  1.05
prob[2,2]  1.01
prob[2,3]  1.00
prob[3,1]  2.11
prob[3,2]  2.40
prob[3,3]  2.22
prob[4,1]  1.00
prob[4,2]  1.38
prob[4,3]  1.57
prob[5,1]  1.02
prob[5,2]  1.02
prob[5,3]  1.00
lp__       1.23

Samples were drawn using NUTS(diag_e) at Fri Mar  1 12:46:17 2019.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).
In [40]:
posterior_cp_pi <- as.array(fit, pars = c("pi"))
posterior_cp_prob <- as.array(fit, pars = c("prob"))
np_cp <- nuts_params(fit)
In [41]:
mcmc_trace(posterior_cp_pi,np = np_cp)

Overall, the transitions diverge at almost all points, and consequently the traces are not very nice either. Let's also look at the effective sample sizes.

In [42]:
ratios_cp <- neff_ratio(fit)
mcmc_neff(ratios_cp)

Yuk! Most of the effective sample sizes are extremely low. Before we go on, we should try to change our model, since badly set-up models often cause unpleasant posterior diagnostics. In case Stan doesn't like the nonparametric part (the truncated stick-breaking), we can test the same model with a fixed number of clusters.

In [43]:
stan.file <- "_models/binary_mixture.stan"
cat(readLines(stan.file), sep="\n")
data {
	int<lower=1> K;
	int<lower=1> n;
	int<lower=1> p;	
	int<lower=0,upper=1> x[n, p];
	vector<lower=0>[K] alpha;
}

parameters {    	
  ordered[p] rates[K];	
  simplex[K] pi;
}

transformed parameters {  
  vector<lower=0, upper=1>[p] prob[K];

  for (k in 1:K) 
  {
  	for (ps in 1:p) 
  	{
  		prob[k, ps] = inv_logit(rates[k, ps]);  
  	}
  }   
}

model {
  	real mix[K];  	
    pi ~ dirichlet(alpha);
  
  	for(i in 1:n) 
  	{
		for(k in 1:K) 
		{
			mix[k] = log(pi[k]);
			for (ps in 1:p)
			{
				mix[k] += bernoulli_lpmf(x[i, ps] | prob[k, ps]);
			}
		}
		target += log_sum_exp(mix);
  	}
}
In [47]:
fit_fixed_K <- stan(stan.file, data = list(K=K, n=n, x=X, p=p, alpha=rep(1.0, K)), 
                    iter = 10000, warmup = 1000, chains = 1, control = list(adapt_delta = 0.99))
SAMPLING FOR MODEL 'binary_mixture' NOW (CHAIN 1).
Chain 1: 
Chain 1: Gradient evaluation took 0.000325 seconds
Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 3.25 seconds.
Chain 1: Adjust your expectations accordingly!
Chain 1: 
Chain 1: 
Chain 1: Iteration:    1 / 10000 [  0%]  (Warmup)
Chain 1: Iteration: 1000 / 10000 [ 10%]  (Warmup)
Chain 1: Iteration: 1001 / 10000 [ 10%]  (Sampling)
Chain 1: Iteration: 2000 / 10000 [ 20%]  (Sampling)
Chain 1: Iteration: 3000 / 10000 [ 30%]  (Sampling)
Chain 1: Iteration: 4000 / 10000 [ 40%]  (Sampling)
Chain 1: Iteration: 5000 / 10000 [ 50%]  (Sampling)
Chain 1: Iteration: 6000 / 10000 [ 60%]  (Sampling)
Chain 1: Iteration: 7000 / 10000 [ 70%]  (Sampling)
Chain 1: Iteration: 8000 / 10000 [ 80%]  (Sampling)
Chain 1: Iteration: 9000 / 10000 [ 90%]  (Sampling)
Chain 1: Iteration: 10000 / 10000 [100%]  (Sampling)
Chain 1: 
Chain 1:  Elapsed Time: 46.5187 seconds (Warm-up)
Chain 1:                785.536 seconds (Sampling)
Chain 1:                832.055 seconds (Total)
Chain 1: 
Warning message:
“There were 8673 divergent transitions after warmup. Increasing adapt_delta above 0.99 may help. See
http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup”Warning message:
“There were 235 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded”Warning message:
“Examine the pairs() plot to diagnose sampling problems
”
In [48]:
fit_fixed_K
Inference for Stan model: binary_mixture.
1 chains, each with iter=10000; warmup=1000; thin=1; 
post-warmup draws per chain=9000, total post-warmup draws=9000.

              mean se_mean     sd    2.5%     25%     50%     75%   97.5% n_eff
rates[1,1] -290.13  106.55 183.54 -500.22 -444.41 -357.05  -73.30   18.60     3
rates[1,2] -192.72   72.04 150.88 -441.52 -320.67 -208.29  -43.91   28.82     4
rates[1,3]  -90.81   38.02 106.25 -335.29 -161.16  -64.79   -7.99   35.86     8
rates[2,1] -204.38   96.89 151.38 -402.90 -343.11 -277.98  -33.23   21.32     2
rates[2,2] -129.93   65.80 121.15 -358.00 -233.50 -122.27  -18.63   31.11     3
rates[2,3]  -65.22   37.07  88.97 -285.94 -113.61  -41.04    6.48   35.32     6
rates[3,1]   10.36    1.64   5.79    2.10    5.98    9.87   13.62   24.47    12
rates[3,2]   19.11    0.69   7.34    6.44   13.60   18.37   24.49   34.07   113
rates[3,3]   27.16    0.59   7.11   10.31   22.51   28.53   33.10   36.34   146
rates[4,1]   -0.28    0.02   0.26   -0.82   -0.45   -0.27   -0.09    0.19   128
rates[4,2]   -0.04    0.02   0.25   -0.53   -0.21   -0.04    0.12    0.44   173
rates[4,3]    0.04    0.02   0.25   -0.46   -0.13    0.04    0.21    0.52   169
rates[5,1] -317.91  142.92 229.78 -702.01 -522.26 -209.27  -99.68  -37.97     3
rates[5,2] -217.58   97.06 181.71 -607.87 -376.39 -137.15  -64.00  -17.91     4
rates[5,3] -108.69   51.76 126.31 -442.64 -166.99  -51.82  -19.75    8.30     6
pi[1]         0.09    0.00   0.06    0.00    0.04    0.08    0.13    0.23   225
pi[2]         0.08    0.01   0.06    0.00    0.03    0.07    0.12    0.23   117
pi[3]         0.22    0.02   0.07    0.03    0.19    0.23    0.27    0.32     9
pi[4]         0.49    0.00   0.05    0.40    0.46    0.49    0.52    0.59   490
pi[5]         0.12    0.03   0.08    0.00    0.04    0.10    0.18    0.29     9
prob[1,1]     0.17    0.15   0.36    0.00    0.00    0.00    0.00    1.00     6
prob[1,2]     0.19    0.17   0.39    0.00    0.00    0.00    0.00    1.00     5
prob[1,3]     0.20    0.17   0.40    0.00    0.00    0.00    0.00    1.00     5
prob[2,1]     0.21    0.19   0.40    0.00    0.00    0.00    0.00    1.00     5
prob[2,2]     0.24    0.20   0.42    0.00    0.00    0.00    0.00    1.00     4
prob[2,3]     0.26    0.20   0.43    0.00    0.00    0.00    1.00    1.00     5
prob[3,1]     0.99    0.00   0.04    0.89    1.00    1.00    1.00    1.00    63
prob[3,2]     1.00    0.00   0.00    1.00    1.00    1.00    1.00    1.00  1042
prob[3,3]     1.00    0.00   0.00    1.00    1.00    1.00    1.00    1.00   420
prob[4,1]     0.43    0.01   0.06    0.31    0.39    0.43    0.48    0.55   129
prob[4,2]     0.49    0.00   0.06    0.37    0.45    0.49    0.53    0.61   173
prob[4,3]     0.51    0.00   0.06    0.39    0.47    0.51    0.55    0.63   169
prob[5,1]     0.00    0.00   0.00    0.00    0.00    0.00    0.00    0.00  1868
prob[5,2]     0.00    0.00   0.06    0.00    0.00    0.00    0.00    0.00   559
prob[5,3]     0.03    0.02   0.18    0.00    0.00    0.00    0.00    1.00    60
lp__       -346.05    3.57   6.83 -359.50 -351.98 -344.82 -340.44 -335.84     4
           Rhat
rates[1,1] 1.93
rates[1,2] 1.50
rates[1,3] 1.29
rates[2,1] 2.94
rates[2,2] 1.84
rates[2,3] 1.45
rates[3,1] 1.07
rates[3,2] 1.00
rates[3,3] 1.01
rates[4,1] 1.02
rates[4,2] 1.02
rates[4,3] 1.02
rates[5,1] 3.57
rates[5,2] 2.15
rates[5,3] 1.52
pi[1]      1.00
pi[2]      1.03
pi[3]      1.16
pi[4]      1.00
pi[5]      1.02
prob[1,1]  1.24
prob[1,2]  1.26
prob[1,3]  1.26
prob[2,1]  1.32
prob[2,2]  1.36
prob[2,3]  1.32
prob[3,1]  1.04
prob[3,2]  1.00
prob[3,3]  1.00
prob[4,1]  1.02
prob[4,2]  1.02
prob[4,3]  1.02
prob[5,1]  1.00
prob[5,2]  1.00
prob[5,3]  1.04
lp__       1.90

Samples were drawn using NUTS(diag_e) at Fri Mar  1 13:05:12 2019.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

That didn't seem to help at all, so the problem apparently stems from the multivariate Bernoulli likelihood itself, which makes some sense: binary data aren't very informative in the first place, so making informed inferences on such data sets is difficult. For binary data, collapsing the parameters and only sampling the latent class assignments seems preferable. Thus, the next notebook will be on efficient inference of class assignments using particle MCMC.