Truncated stick breaking in greta

Here, we briefly discuss truncated stick breaking (TSB). We use TSB as a finite-dimensional alternative to an infinite-dimensional prior over the number of mixture components. We do the analysis in greta (because the author is most comfortable with it). We use TSB because it lets us work entirely with continuous parameters, which in turn allows us to use Hamiltonian Monte Carlo (most probabilistic programming languages do not support discrete parameters anyway :)). Also, in practice, working with continuous parameters is far easier than working with discrete ones (as with the Chinese restaurant process).
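For reference, the construction truncated at $K$ components works as follows: draw stick proportions $v_k \sim \text{Beta}(1, \alpha)$ for $k = 1, \dots, K - 1$ (here we will use $\alpha = 1$) and set the mixture weights to $w_k = v_k \prod_{j < k} (1 - v_j)$, with the last weight $w_K = \prod_{j < K} (1 - v_j)$ absorbing whatever is left of the stick, so the weights are non-negative and sum to one.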

We first try TSB with a mixture of univariate normals, and then with a mixture of Poisson distributions.

Some required libraries:

In [1]:
suppressMessages(library("greta"))
suppressMessages(library("tensorflow"))
In [2]:
suppressMessages(library("tidyverse"))
suppressMessages(library("MASS"))
suppressMessages(library("bayesplot"))
suppressMessages(library("caret"))
options(repr.plot.width=8, repr.plot.height=3)

Reproducibility!

In [3]:
set.seed(23)

Normal data

We first create some data.

In [4]:
N <- 1000
K <- 3
In [5]:
mus <- c(-2, 0, 2)
sd <- .25
In [6]:
data <- vector(length = N * K)
Z <- factor(rep(seq(K), each=N))
for (k in seq(K)) {    
    idx <- seq(N) + ((k - 1) * N)
    data[idx] <- rnorm(N, mus[k], sd)
}
In [7]:
data.frame(data=data, idx=as.factor(rep(seq(K), each=N))) %>%
    ggplot(aes(data, fill=idx)) +
    geom_histogram(bins=30, position = "dodge") +
    scale_fill_viridis_d("Component", alpha = 1, begin=.3, end=.8) + 
    theme_minimal()

For this data set it should be fairly easy to find the correct number of clusters ($3$). For the TSB, we need to set the truncation level to a sufficiently high $K$, so that the error in comparison to the "true" infinite-dimensional prior becomes negligibly small.

See for instance https://projecteuclid.org/euclid.bj/1551862850 for a theoretical and practical justification for the truncation.

In [8]:
K <- 10
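As a rough sanity check (a back-of-the-envelope calculation, not the formal bound from the paper above): with Beta$(1, 1)$ sticks, the expected stick mass not covered by the first $K - 1$ components is $(1/2)^{K - 1}$, which for $K = 10$ is about $0.002$, small enough for our purposes.

# expected leftover stick mass E[prod(1 - v_k)] with v_k ~ Beta(1, 1):
# analytic value and a quick Monte Carlo check
0.5^(K - 1)
mean(replicate(1e4, prod(1 - rbeta(K - 1, 1, 1))))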

We will only try to estimate the vector of mean values of the Gaussians. To avoid non-identifiability (label switching), we use a small trick: we create a prior vector of length $K$ whose first element is unconstrained and whose remaining elements are non-negative, and take its cumulative sum (cumsum). This guarantees that the mean values are sorted in increasing order. For instance:

In [9]:
x <- c(runif(1, -1, 1), runif(5, 0, 1))
cumsum(x)
  1. 0.593454854562879
  2. 0.856230824021623
  3. 1.22660479671322
  4. 2.02785729477182
  5. 2.67813442042097
  6. 3.13572235242464

In greta that is:

In [10]:
prior_mu_ordered <- cumsum(
    c(greta::variable(lower = -5, upper = 5),
      greta::variable(lower = 0, upper = 5, dim = K - 1)))

Then we set a prior over the mixing weights. For this, as mentioned before, we use stick breaking. Luckily LaplacesDemon has a function for the sticks.

In [11]:
stick_breaking <- function(theta) {
   LaplacesDemon::Stick(theta)
}

# note dim = K - 1: LaplacesDemon::Stick expects K - 1 stick proportions
# and returns K weights
prior_stick <- greta::beta(1, 1, dim = K - 1)
prior_weights <- stick_breaking(prior_stick)
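For intuition, here is a plain-R version of the stick-breaking map from the introduction (for illustration only; the model above relies on LaplacesDemon::Stick to do this on the greta array):

stick_breaking_r <- function(v) {
    # stick remaining before each break: 1, (1 - v1), (1 - v1)(1 - v2), ...
    remaining <- c(1, cumprod(1 - v))
    # k-th weight is v_k times the stick remaining before break k;
    # the last weight is whatever is left after all K - 1 breaks
    c(v * remaining[-length(remaining)], remaining[length(remaining)])
}

# sanity check: K weights that sum to one
w <- stick_breaking_r(rbeta(K - 1, 1, 1))
length(w); sum(w)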

Then we set the mixture distribution and sample from the posterior. This is a little tedious, because greta's mixture requires us to list every component manually.

In [12]:
greta::distribution(data) <- greta::mixture(
    greta::normal(prior_mu_ordered[1], .25),    
    greta::normal(prior_mu_ordered[2], .25),
    greta::normal(prior_mu_ordered[3], .25),
    greta::normal(prior_mu_ordered[4], .25),
    greta::normal(prior_mu_ordered[5], .25),
    greta::normal(prior_mu_ordered[6], .25),
    greta::normal(prior_mu_ordered[7], .25),
    greta::normal(prior_mu_ordered[8], .25),
    greta::normal(prior_mu_ordered[9], .25),
    greta::normal(prior_mu_ordered[10], .25),
    weights = prior_weights
)
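Writing out all ten components by hand is verbose. As a replacement for the cell above (not in addition to it), the mixture could probably be assembled programmatically; this is only a sketch and assumes greta::mixture handles components passed via do.call just like literally listed ones:

# build the K component distributions in a list, then splice them into
# greta::mixture together with the stick-breaking weights (assumption:
# mixture accepts components supplied this way)
components <- lapply(seq_len(K), function(k)
    greta::normal(prior_mu_ordered[k], .25))
greta::distribution(data) <- do.call(
    greta::mixture, c(components, list(weights = prior_weights)))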
In [13]:
mod <- greta::model(prior_stick, prior_weights, prior_mu_ordered)
In [14]:
samples <- greta::mcmc(mod, n_cores = 1, chains = 1)
    warmup ====================================== 1000/1000 | eta:  0s          
  sampling ====================================== 1000/1000 | eta:  0s          

Let's see if we could identify the components.

In [15]:
bayesplot::mcmc_hist(samples, regex_pars = "weights")
`stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

Looks like three components suffice, as only three components of the weight posterior carry significant weight. Now, let's see if we could identify the means correctly. Since we effectively use only three components, we only need to consider these means.

In [16]:
bayesplot::mcmc_intervals(samples, regex_pars = "mu")

Nice! That worked great: we were able to recover the correct number of components. Now, let's cluster the data points. We only cluster with components 2, 3 and 4, since the posterior weights suggest these are the important ones.

In [17]:
posterior_matrix <- as.matrix(samples)
posterior_mus <- posterior_matrix[,sprintf("prior_mu_ordered[%i,1]", c(2, 3, 4))]

Here we assign each data point to its most likely component, using the posterior mean density under each of the three retained components.

In [18]:
clusters <- vector(length = length(data))
for (i in seq_along(data)) {
    # posterior mean density of point i under each retained component
    probs <- apply(posterior_mus, 2, function(.) mean(dnorm(data[i], ., .25)))
    clusters[i] <- which.max(probs)
}
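The rule above ignores the posterior mixing weights. That should be harmless here, since the three retained components should receive roughly equal weight for this balanced data set, but for completeness a weighted variant could look like this (a sketch; it assumes the weight columns in the samples follow the same "[k,1]" naming pattern as the means):

# posterior mean mixing weights of the retained components
# (column names assumed to follow the same pattern as the means)
posterior_w <- colMeans(
    posterior_matrix[, sprintf("prior_weights[%i,1]", c(2, 3, 4))])
clusters_w <- vector(length = length(data))
for (i in seq_along(data)) {
    # weight times posterior mean density, then pick the arg-max
    dens <- apply(posterior_mus, 2, function(.) mean(dnorm(data[i], ., .25)))
    clusters_w[i] <- which.max(posterior_w * dens)
}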

Finally we compute a confusion matrix to check our predictions.

In [19]:
caret::confusionMatrix(factor(clusters), Z)
Confusion Matrix and Statistics

          Reference
Prediction    1    2    3
         1 1000    0    0
         2    0 1000    0
         3    0    0 1000

Overall Statistics
                                     
               Accuracy : 1          
                 95% CI : (0.9988, 1)
    No Information Rate : 0.3333     
    P-Value [Acc > NIR] : < 2.2e-16  
                                     
                  Kappa : 1          
                                     
 Mcnemar's Test P-Value : NA         

Statistics by Class:

                     Class: 1 Class: 2 Class: 3
Sensitivity            1.0000   1.0000   1.0000
Specificity            1.0000   1.0000   1.0000
Pos Pred Value         1.0000   1.0000   1.0000
Neg Pred Value         1.0000   1.0000   1.0000
Prevalence             0.3333   0.3333   0.3333
Detection Rate         0.3333   0.3333   0.3333
Detection Prevalence   0.3333   0.3333   0.3333
Balanced Accuracy      1.0000   1.0000   1.0000

Poisson data

Next we try a Poisson mixture.

In [20]:
N <- 1000
K <- 3
mus <- exp(c(0, 1, 2))
In [21]:
data <- vector(length = N * K)
Z <- factor(rep(seq(K), each=N))
for (k in seq(K)) {    
    idx <- seq(N) + ((k - 1) * N)
    data[idx] <- rpois(N, mus[k])
}
In [22]:
data.frame(data=data, idx=as.factor(rep(seq(K), each=N))) %>%
    ggplot(aes(data, fill=idx)) +
    geom_histogram(bins=30, position = "dodge") +
    scale_fill_viridis_d("Component", alpha = 1, begin=.3, end=.8) + 
    theme_minimal()

We truncate the DP at $K=10$ as before and define the model.

In [23]:
K <- 10

We can use the same trick as before. Since we exponentiate the means later (to obtain the Poisson rates), we don't need to worry about negative values.

In [24]:
prior_mu_ordered <- cumsum(
    c(greta::variable(lower = -5, upper = 3),
      greta::variable(lower = 0, upper = 3, dim = K - 1)))
In [25]:
greta::distribution(data) <- greta::mixture(
    greta::poisson(exp(prior_mu_ordered[1])),
    greta::poisson(exp(prior_mu_ordered[2])),
    greta::poisson(exp(prior_mu_ordered[3])),
    greta::poisson(exp(prior_mu_ordered[4])),
    greta::poisson(exp(prior_mu_ordered[5])),
    greta::poisson(exp(prior_mu_ordered[6])),
    greta::poisson(exp(prior_mu_ordered[7])),
    greta::poisson(exp(prior_mu_ordered[8])),
    greta::poisson(exp(prior_mu_ordered[9])),
    greta::poisson(exp(prior_mu_ordered[10])),
    weights = prior_weights
)
In [26]:
mod <- greta::model(prior_stick, prior_weights, prior_mu_ordered)
In [27]:
samples <- greta::mcmc(mod, chains = 1, n_cores = 1)
    warmup ====================================== 1000/1000 | eta:  0s          
  sampling ====================================== 1000/1000 | eta:  0s          

How many components do we need?

In [28]:
bayesplot::mcmc_hist(samples, regex_pars = "weights")
`stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

Here, it's not so clear anymore. It looks as if we should use four components (one too many), but this could be due to a poor prior choice. What about the means?

In [29]:
bayesplot::mcmc_intervals(samples, regex_pars = "mu")

Finally, let's have a look at the posterior assignments of the data points to the components again. The weight now sits on components 2, 3, 4 and 5. To keep the labels consistent with our original assignment, we order the columns as (3, 4, 5, 2), i.e. we pick component 2 last. That way which.max returns the right index for each true cluster: component 3 gets index 1 (true cluster 1), component 4 gets index 2, component 5 gets index 3, and the superfluous component 2 gets index 4.

In [30]:
posterior_matrix <- as.matrix(samples)
posterior_mus <- posterior_matrix[,sprintf("prior_mu_ordered[%i,1]", c(3, 4, 5, 2))]
In [31]:
clusters <- vector(length = length(data))
for (i in seq_along(data)) {
    # posterior mean Poisson probability of point i under each retained component
    probs <- apply(posterior_mus, 2, function(.) mean(dpois(data[i], exp(.))))
    clusters[i] <- which.max(probs)
}

Since we now have four clusters to consider, we need to relevel our true latent assignments.

In [32]:
caret::confusionMatrix(factor(clusters), factor(Z, levels=c(levels(Z), 4)))
Confusion Matrix and Statistics

          Reference
Prediction   1   2   3   4
         1 375 175   9   0
         2 241 625 122   0
         3   3 149 869   0
         4 381  51   0   0

Overall Statistics
                                          
               Accuracy : 0.623           
                 95% CI : (0.6054, 0.6404)
    No Information Rate : 0.3333          
    P-Value [Acc > NIR] : < 2.2e-16       
                                          
                  Kappa : 0.4725          
                                          
 Mcnemar's Test P-Value : NA              

Statistics by Class:

                     Class: 1 Class: 2 Class: 3 Class: 4
Sensitivity            0.3750   0.6250   0.8690       NA
Specificity            0.9080   0.8185   0.9240    0.856
Pos Pred Value         0.6708   0.6326   0.8511       NA
Neg Pred Value         0.7440   0.8136   0.9338       NA
Prevalence             0.3333   0.3333   0.3333    0.000
Detection Rate         0.1250   0.2083   0.2897    0.000
Detection Prevalence   0.1863   0.3293   0.3403    0.144
Balanced Accuracy      0.6415   0.7218   0.8965       NA

Obviously, the assignment suffers from the fact that the Poisson components are not very well separated. However, the clustering is still fairly good.
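To get a feeling for how much the components overlap, we can compare the Poisson probabilities of small counts under the three true rates (just a quick illustration, independent of the model): counts such as 2, 3 or 4 have appreciable probability under more than one rate, so some misassignment is unavoidable.

# probabilities of counts 0..8 under the true rates exp(0), exp(1), exp(2)
round(sapply(exp(c(0, 1, 2)), function(lambda) dpois(0:8, lambda)), 3)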
