# Gaussian Mixture Models

## 1. Load packages & helper functions

In [ ]:
# Distributions.jl provides the probability distributions; Turing.jl is our PPL (probabilistic programming language).
using Distributions, Turing

TPATH = Pkg.dir("Turing")

# gmm.helper.jl contains functions to build histogram plots
include(TPATH*"/example-models/nips-2017/gmm.helper.jl");


The `visualize` helper used in Section 3 takes the $x$ samples from the Gibbs and NUTS chains and plots their histograms together with the exact density.

## 2. Defining the model

In [ ]:
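# GMM samples the component indicator z explicitly, then x given z.
@model GMM(p, μ, σ) = begin
    z ~ Categorical(p)
    x ~ Normal(μ[z], σ[z])
end

# cGMM is the collapsed model: x is drawn from the mixture directly, with no discrete z.
# UnivariateGMM2 is not defined in this notebook (presumably it comes from gmm.helper.jl).
@model cGMM(p, μ, σ) = begin
    x ~ UnivariateGMM2(μ, σ, Categorical(p))
end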

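M = 5                                   # number of mixture components
p = [ 0.2,  0.2,   0.2, 0.2,  0.2]      # equal mixture weights

# Log standard deviations; exp.(s) applies exp elementwise.
s = [-0.5, -1.5, -0.75,  -2, -0.5]; σ = exp.(s);

# Component means: μ1 is widely spaced, μ2 is closely spaced.
μ1 = [   0,    1,     2, 3.5, 4.25] + 2.5 * collect(0:4);
μ2 = [   0,    1,     2, 3.5, 4.25] + 0.5 * collect(0:4);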
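The two models target the same marginal distribution of $x$: summing out $z$ in `GMM` gives $p(x) = \sum_k p_k \, \mathcal{N}(x; \mu_k, \sigma_k)$, which is what `cGMM` is meant to sample from directly. The sketch below checks that identity numerically with Distributions.jl's `MixtureModel`. It assumes Julia 1.x syntax, and `mog` and `mog_pdf` are hypothetical names introduced here for illustration; they are not the notebook's `UnivariateGMM2` or `make_norm_pdf`, whose definitions are not shown.

```julia
using Distributions

# Same parameters as in the cell above (Julia 1.x broadcasting syntax).
p  = fill(0.2, 5)
σ  = exp.([-0.5, -1.5, -0.75, -2, -0.5])
μ1 = [0, 1, 2, 3.5, 4.25] .+ 2.5 .* (0:4)
μ2 = [0, 1, 2, 3.5, 4.25] .+ 0.5 .* (0:4)

# Hypothetical stand-in for the collapsed distribution: a finite mixture of normals.
mog(p, μ, σ) = MixtureModel(Normal.(μ, σ), p)

# Marginal density written out by hand: p(x) = Σ_k p[k] * N(x; μ[k], σ[k]).
mog_pdf(x, p, μ, σ) = sum(p[k] * pdf(Normal(μ[k], σ[k]), x) for k in eachindex(p))

d1 = mog(p, μ1, σ)
println(pdf(d1, 3.5), " vs ", mog_pdf(3.5, p, μ1, σ))
```

The two printed values should agree up to floating-point error, since both evaluate the same weighted sum of normal densities.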

In [21]:
p1 = plot([make_norm_pdf(p, μ1, σ)], -20, 20, Theme(default_color=colors[1]),
          Guide.xlabel(nothing), Guide.ylabel("Density"), Guide.title("MoG with Separated Mixtures"));
p2 = plot([make_norm_pdf(p, μ2, σ)], -20, 20, Theme(default_color=colors[2]),
          Guide.xlabel(nothing), Guide.ylabel("Density"), Guide.title("MoG with Nearby Mixtures"));

In [22]:
draw(PNG(15cm, 10cm), vstack(p1, p2));


## 3. MCMC sampling

### Setting up the MCMC engines

In [ ]:
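N = 10000; K = 500;
# Gibbs over the explicit model: particle Gibbs (10 particles, 1 iteration) updates the discrete z,
# then HMC (K-1 iterations, step size 0.2, 4 leapfrog steps) updates the continuous x.
gibbs = Gibbs(round(Int,N/K), PG(10, 1, :z), HMC(K-1, 0.2, 4, :x); thin=false)
# NUTS over the collapsed model: N iterations with target acceptance rate 0.65.
nuts  = NUTS(N, 0.65);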
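The choice of `N` and `K` appears intended to give both samplers the same number of draws. A minimal sketch of that arithmetic, assuming (as in the Turing release this notebook targets) that the second argument of `PG` and the first argument of `HMC` are per-sweep iteration counts and that `thin=false` keeps every within-sweep draw:

```julia
# Sample budget implied by the settings above (plain arithmetic, no Turing needed).
N = 10000; K = 500
sweeps    = round(Int, N / K)   # number of Gibbs sweeps
per_sweep = 1 + (K - 1)         # 1 PG iteration + (K-1) HMC iterations per sweep
total     = sweeps * per_sweep
println("sweeps = $sweeps, draws per sweep = $per_sweep, total = $total")
```

Under those assumptions the Gibbs chain holds 20 × 500 = 10000 draws, matching the `N` iterations requested from NUTS.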


### Sampling from MoG with separated mixtures

In [ ]:
println("Running Gibbs")
chain_gibbs1 = sample(GMM(p, μ1, σ), gibbs)
x_gibbs1 = map(x_arr -> x_arr[1], chain_gibbs1[:x]);

println("Running NUTS")
chain_nuts1 = sample(cGMM(p, μ1, σ), nuts)
x_nuts1 = map(x_arr -> x_arr[1], chain_nuts1[:x]);

In [23]:
visualize(x_gibbs1, x_nuts1, μ1);


### Sampling from MoG with nearby mixtures

In [ ]:
println("Running Gibbs")
chain_gibbs2 = sample(GMM(p, μ2, σ), gibbs)
x_gibbs2 = map(x_arr -> x_arr[1], chain_gibbs2[:x]);

println("Running NUTS")
chain_nuts2 = sample(cGMM(p, μ2, σ), nuts)
x_nuts2 = map(x_arr -> x_arr[1], chain_nuts2[:x]);

In [24]:
visualize(x_gibbs2, x_nuts2, μ2, -2, 7);
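Alongside the histograms, a quick numerical check is to compare sample moments with the exact mixture moments, $\mathbb{E}[x] = \sum_k p_k \mu_k$ and $\mathrm{Var}[x] = \sum_k p_k (\sigma_k^2 + \mu_k^2) - \mathbb{E}[x]^2$. The sketch below is an illustration only: it assumes Julia 1.x, and it uses independent draws from Distributions.jl's `MixtureModel` as a stand-in for `x_gibbs2` or `x_nuts2`, which would be substituted for `x` in practice.

```julia
using Distributions, Statistics, Random

Random.seed!(1)

p  = fill(0.2, 5)
σ  = exp.([-0.5, -1.5, -0.75, -2, -0.5])
μ2 = [0, 1, 2, 3.5, 4.25] .+ 0.5 .* (0:4)

# Exact mixture moments: E[x] = Σ p_k μ_k, Var[x] = Σ p_k (σ_k² + μ_k²) - E[x]².
exact_mean = sum(p .* μ2)
exact_var  = sum(p .* (σ .^ 2 .+ μ2 .^ 2)) - exact_mean^2

# Stand-in for the MCMC samples: independent draws from the same mixture.
x = rand(MixtureModel(Normal.(μ2, σ), p), 10_000)

println("mean: exact = ", exact_mean, ", sample = ", mean(x))
println("var:  exact = ", exact_var,  ", sample = ", var(x))
```

A chain that mixes poorly between components will typically show a sample mean or variance that drifts away from the exact values, so this check complements the visual comparison rather than replacing it.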