Mixture Density Network

Mixture Density Networks (MDNs) [1] are mixture models in which the parameters, such as the means, covariances and mixture proportions, are predicted by a neural network. MDNs combine a structured output representation (a mixture density) with unstructured parameter inference (an MLP). An MDN learns the mixture parameters by maximizing the log-likelihood, or equivalently by minimizing a negative log-likelihood loss.

Assuming a Gaussian Mixture Model (GMM) with $K$ components, we can write the conditional density of a target point $y_i$ given its input $x$ as follows:

\begin{equation} p(y_i|x) = \sum_{k=1}^{K}\pi_k(x) N\big(y_i| \mu_k(x), \Sigma_k(x)\big) \end{equation}

where the parameters $\mu_k, \sigma_k, \pi_k$ are produced by a neural network (e.g. a Multi-Layer Perceptron (MLP)) parameterized by $\theta$:

\begin{equation} \mu_k, \sigma_k, \pi_k = \mathrm{NN}(x; \theta) \end{equation}

As a result, the Neural Network (NN) is a multi-output model, subject to the following constraints on its outputs:

\begin{eqnarray} \sigma_{k}(x) &>& 0 \quad \forall k\\ \sum_{k=1}^{K} \pi_k(x) &=& 1 \end{eqnarray}

The first constraint can be enforced with exponential activations, while the second can be enforced with softmax activations. Finally, making use of the i.i.d. assumption, we want to minimize the following loss function:

\begin{eqnarray} \min_{\theta} L(\theta) &=& \mathrm{NLLLoss}(\theta) = - \log \prod_{i=1}^{n} p(y_i|x_i) = -\sum_{i=1}^{n} \log p(y_i|x_i) \\ &=& - \sum_{i=1}^{n}\log \bigg[\sum_{k=1}^{K} \pi_k(x_i, \theta) N\big(y_i|\mu_k(x_i,\theta), \Sigma_{k}(x_i,\theta)\big) \bigg] \end{eqnarray}
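To make the constraints and the loss concrete, here is a minimal NumPy sketch (not part of the notebook's model code; all values are hand-picked for illustration) that applies the exp and softmax transforms and evaluates the negative log-likelihood of a single 2-D point under a toy two-component isotropic mixture:

import numpy as np

#toy, hand-picked "network outputs" for one input x (illustrative values only)
raw_sigma = np.array([0.1, -0.3])                #unconstrained -> exp makes them positive
raw_pi    = np.array([1.0,  0.0])                #unconstrained -> softmax makes them sum to 1
mu        = np.array([[2.0, 2.0], [-2.0, 2.0]])  #K=2 component means, D=2

sigma = np.exp(raw_sigma)                        #sigma_k > 0
pi    = np.exp(raw_pi) / np.exp(raw_pi).sum()    #sum_k pi_k = 1

y = np.array([1.5, 2.5])                         #a single target point

#isotropic Gaussian density of y under each component
norm_const = (2 * np.pi * sigma**2) ** (-len(y) / 2.0)
dens = norm_const * np.exp(-np.sum((y - mu)**2, axis=1) / (2 * sigma**2))

nll = -np.log(np.sum(pi * dens))                 #this point's contribution to the loss
print(nll)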

In the example below, we assume an isotropic covariance $\Sigma_k = \sigma_k^{2} I$, so the $D$-dimensional Gaussian factorizes into a product of univariate Gaussians:

\begin{equation} N(y_i | \mu_k, \Sigma_k) = \frac{1}{(2\pi)^{D/2}|\Sigma_k|^{1/2}} \exp \bigg(-\frac{1}{2}(y_i-\mu_k)^{T}\Sigma_{k}^{-1}(y_i-\mu_k) \bigg) \end{equation}

\begin{equation} N(y_i | \mu_k, \Sigma_k) = \frac{1}{(2\pi \sigma_{k}^{2})^{D/2}} \exp \bigg[-\frac{1}{2\sigma_{k}^{2}} \sum_{d=1}^{D}(y_{i,d} - \mu_{k,d})^2\bigg] \end{equation}

\begin{equation} N(y_i | \mu_k, \Sigma_k) = \prod_{d=1}^{D} \frac{1}{\sigma_k \sqrt{2\pi}} \exp \bigg[-\frac{1}{2\sigma_{k}^{2}}(y_{i,d}-\mu_{k,d})^{2} \bigg] \end{equation}
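As a quick numerical check of the factorization above (a sketch assuming SciPy is available; the values are arbitrary), the product of $D$ univariate Gaussian densities matches the $D$-dimensional isotropic density:

import numpy as np
from scipy.stats import norm, multivariate_normal

D, sigma_k = 3, 0.7
mu_k = np.array([1.0, -2.0, 0.5])
y_i  = np.array([0.8, -1.5, 0.0])

#product of D univariate densities
prod_univariate = np.prod(norm.pdf(y_i, loc=mu_k, scale=sigma_k))

#D-dimensional Gaussian with isotropic covariance sigma_k^2 * I
iso_multivariate = multivariate_normal.pdf(y_i, mean=mu_k, cov=sigma_k**2 * np.eye(D))

print(prod_univariate, iso_multivariate)   #the two values agree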

Let's implement a Gaussian MDN using Keras [3]!

In [2]:
%matplotlib inline
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.cm as cm
import matplotlib.pyplot as plt

import math
import tensorflow as tf

import keras
from keras import optimizers
from keras import backend as K
from keras import regularizers
from keras.models import Sequential, Model
from keras.layers import concatenate, Input
from keras.layers import Dense, Activation, Dropout, Flatten
from keras.layers import BatchNormalization

from keras.utils import np_utils
from keras.utils import plot_model
from keras.models import load_model

from keras.callbacks import ModelCheckpoint
from keras.callbacks import TensorBoard
from keras.callbacks import LearningRateScheduler 
from keras.callbacks import EarlyStopping

from sklearn.datasets import make_blobs
from sklearn.metrics import adjusted_rand_score
from sklearn.metrics import normalized_mutual_info_score 
from sklearn.model_selection import train_test_split

np.random.seed(0)
sns.set_style('whitegrid')
Using TensorFlow backend.

We'll use a synthetic dataset drawn from a mixture of K=4 two-dimensional Gaussians for ease of visualization.

In [3]:
def generate_data(N):
    #ground-truth mixture: proportions, component means and diagonal covariance entries
    pi = np.array([0.2, 0.4, 0.3, 0.1])
    mu = [[2,2], [-2,2], [-2,-2], [2,-2]]
    std = [[0.5,0.5], [1.0,1.0], [0.5,0.5], [1.0,1.0]]
    x = np.zeros((N,2), dtype=np.float32)  #observations
    y = np.zeros((N,2), dtype=np.float32)  #regression targets (the component means)
    z = np.zeros((N,1), dtype=np.int32)    #component labels
    for n in range(N):
        k = np.argmax(np.random.multinomial(1, pi))  #pick a component
        x[n,:] = np.random.multivariate_normal(mu[k], np.diag(std[k]))  #draw a point from it
        y[n,:] = mu[k]
        z[n,:] = k
    #end for
    z = z.flatten()
    return x, y, z, pi, mu, std
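A quick usage sketch (the `_demo` variable names are only for illustration): the empirical component frequencies of a generated sample should be close to the mixture proportions pi.

#usage sketch: empirical cluster frequencies should be close to pi
x_demo, y_demo, z_demo, pi_demo, mu_demo, std_demo = generate_data(1000)
print(np.bincount(z_demo) / 1000.0)   #roughly [0.2, 0.4, 0.3, 0.1]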

Let's define a function for computing a D-dimensional Gaussian with isotropic covariance:

In [4]:
def tf_normal(y, mu, sigma):
    y_tile = K.stack([y]*num_clusters, axis=1) #[batch_size, K, D]
    result = y_tile - mu
    sigma_tile = K.stack([sigma]*data_dim, axis=-1) #[batch_size, K, D]
    result = result * 1.0/(sigma_tile+1e-8)
    result = -K.square(result)/2.0
    oneDivSqrtTwoPI = 1.0/math.sqrt(2*math.pi)    
    result = K.exp(result) * (1.0/(sigma_tile + 1e-8))*oneDivSqrtTwoPI
    result = K.prod(result, axis=-1)    #[batch_size, K] iid Gaussians
    return result
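As a sanity check, the sketch below (not part of the original notebook; it assumes num_clusters and data_dim have already been set, as in the data-generation cell further down) compares tf_normal against SciPy's isotropic Gaussian density for a single point, with all component means at the origin and unit sigmas:

#sanity check (sketch): tf_normal should match scipy's isotropic Gaussian density
from scipy.stats import multivariate_normal

y_np  = np.array([[0.5, -1.0]], dtype=np.float32)                 #one point, D = 2
mu_np = np.zeros((1, num_clusters, data_dim), dtype=np.float32)   #all means at the origin
sg_np = np.ones((1, num_clusters), dtype=np.float32)              #unit sigmas

out = K.eval(tf_normal(tf.constant(y_np), tf.constant(mu_np), tf.constant(sg_np)))
ref = multivariate_normal.pdf(y_np[0], mean=np.zeros(data_dim), cov=np.eye(data_dim))
print(out[0], ref)   #each of the K entries of out[0] should be close to ref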

Finally, we'll define the negative log-likelihood loss which takes in a $d$-dimensional $y_{true}$ vector and MDN output parameters $y_{pred}$:

In [5]:
def NLLLoss(y_true, y_pred):
    #split the concatenated network output into mixture parameters
    out_mu = y_pred[:,:num_clusters*data_dim]
    out_sigma = y_pred[:,num_clusters*data_dim : num_clusters*(data_dim+1)]
    out_pi = y_pred[:,num_clusters*(data_dim+1):]

    out_mu = K.reshape(out_mu, [-1, num_clusters, data_dim])

    result = tf_normal(y_true, out_mu, out_sigma) #[batch_size, K] component densities
    result = result * out_pi                      #weight by mixture proportions
    result = K.sum(result, axis=1, keepdims=True) #mixture density p(y_i|x_i)
    result = -K.log(result + 1e-8)                #negative log-likelihood
    result = K.mean(result)
    return tf.maximum(result, 0)  #clamp at zero (the NLL of a density can be negative)

Let's generate and visualize training and test data:

In [6]:
#generate data
X_data, y_data, z_data, pi_true, mu_true, sigma_true = generate_data(1024)

data_dim = X_data.shape[1]
num_clusters = len(mu_true)
#X_data, y_data = make_blobs(n_samples=1000, centers=num_clusters, n_features=data_dim, random_state=0)
#X_train, X_test, y_train, y_test = train_test_split(X_data, y_data, random_state=0, test_size=0.7)

num_train = 512 
X_train, X_test, y_train, y_test = X_data[:num_train,:], X_data[num_train:,:], y_data[:num_train,:], y_data[num_train:,:]
z_train, z_test = z_data[:num_train], z_data[num_train:]

#visualize data
plt.figure()
plt.scatter(X_train[:,0], X_train[:,1], c=z_train, cmap=cm.bwr)
plt.title('training data')
plt.savefig('./figures/mdn_training_data.png')

We can now define the training parameters as well as MDN parameters and architecture:

In [7]:
#training params
batch_size = 128 
num_epochs = 128 

#model parameters
hidden_size = 32
weight_decay = 1e-4

#MDN architecture
input_data = Input(shape=(data_dim,))
x = Dense(hidden_size, activation='relu')(input_data)
x = Dropout(0.2)(x)
x = BatchNormalization()(x)
x = Dense(hidden_size, activation='relu')(x)
x = Dropout(0.2)(x)
x = BatchNormalization()(x)

mu = Dense(num_clusters * data_dim, activation='linear')(x) #cluster means
sigma = Dense(num_clusters, activation=K.exp)(x)            #isotropic std dev per cluster
pi = Dense(num_clusters, activation='softmax')(x)           #mixture proportions
out = concatenate([mu, sigma, pi], axis=-1)

model = Model(input_data, out)
adam = optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
#note: accuracy is not a meaningful metric for MDN outputs; the NLL loss is what matters
model.compile(loss=NLLLoss, optimizer=adam, metrics=['accuracy'])
model.summary()
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_1 (InputLayer)             (None, 2)             0                                            
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 32)            96          input_1[0][0]                    
____________________________________________________________________________________________________
dropout_1 (Dropout)              (None, 32)            0           dense_1[0][0]                    
____________________________________________________________________________________________________
batch_normalization_1 (BatchNorm (None, 32)            128         dropout_1[0][0]                  
____________________________________________________________________________________________________
dense_2 (Dense)                  (None, 32)            1056        batch_normalization_1[0][0]      
____________________________________________________________________________________________________
dropout_2 (Dropout)              (None, 32)            0           dense_2[0][0]                    
____________________________________________________________________________________________________
batch_normalization_2 (BatchNorm (None, 32)            128         dropout_2[0][0]                  
____________________________________________________________________________________________________
dense_3 (Dense)                  (None, 8)             264         batch_normalization_2[0][0]      
____________________________________________________________________________________________________
dense_4 (Dense)                  (None, 4)             132         batch_normalization_2[0][0]      
____________________________________________________________________________________________________
dense_5 (Dense)                  (None, 4)             132         batch_normalization_2[0][0]      
____________________________________________________________________________________________________
concatenate_1 (Concatenate)      (None, 16)            0           dense_3[0][0]                    
                                                                   dense_4[0][0]                    
                                                                   dense_5[0][0]                    
====================================================================================================
Total params: 1,936
Trainable params: 1,808
Non-trainable params: 128
____________________________________________________________________________________________________

Let's train the model with early stopping:

In [8]:
early_stopping = EarlyStopping(monitor='val_loss', min_delta=0.1, patience=32, verbose=1)
callbacks_list = [early_stopping]

#model training
hist = model.fit(X_train, y_train, batch_size=batch_size, epochs=num_epochs, callbacks=callbacks_list, validation_split=0.2, shuffle=True, verbose=2)
Train on 409 samples, validate on 103 samples
Epoch 1/128
1s - loss: 5.9070 - acc: 0.0342 - val_loss: 6.0739 - val_acc: 0.0000e+00
Epoch 2/128
0s - loss: 5.5468 - acc: 0.0244 - val_loss: 5.9862 - val_acc: 0.0000e+00
Epoch 3/128
0s - loss: 5.5059 - acc: 0.0293 - val_loss: 5.8783 - val_acc: 0.0000e+00
Epoch 4/128
0s - loss: 5.2641 - acc: 0.0269 - val_loss: 5.7378 - val_acc: 0.0000e+00
Epoch 5/128
0s - loss: 5.0044 - acc: 0.0220 - val_loss: 5.5889 - val_acc: 0.0097
Epoch 6/128
0s - loss: 4.7956 - acc: 0.0196 - val_loss: 5.4150 - val_acc: 0.0194
Epoch 7/128
0s - loss: 4.7717 - acc: 0.0293 - val_loss: 5.2195 - val_acc: 0.0194
Epoch 8/128
0s - loss: 4.5822 - acc: 0.0293 - val_loss: 5.0472 - val_acc: 0.0291
Epoch 9/128
0s - loss: 4.4746 - acc: 0.0293 - val_loss: 4.9046 - val_acc: 0.0291
Epoch 10/128
0s - loss: 4.3982 - acc: 0.0513 - val_loss: 4.7775 - val_acc: 0.0388
Epoch 11/128
0s - loss: 4.3942 - acc: 0.0440 - val_loss: 4.6663 - val_acc: 0.0388
Epoch 12/128
0s - loss: 4.1395 - acc: 0.0611 - val_loss: 4.5731 - val_acc: 0.0388
Epoch 13/128
0s - loss: 4.0839 - acc: 0.0416 - val_loss: 4.4729 - val_acc: 0.0485
Epoch 14/128
0s - loss: 4.2564 - acc: 0.0367 - val_loss: 4.3763 - val_acc: 0.0583
Epoch 15/128
0s - loss: 4.0929 - acc: 0.0660 - val_loss: 4.2881 - val_acc: 0.0583
Epoch 16/128
0s - loss: 3.9556 - acc: 0.0440 - val_loss: 4.2074 - val_acc: 0.0583
Epoch 17/128
0s - loss: 4.0760 - acc: 0.0342 - val_loss: 4.1286 - val_acc: 0.0583
Epoch 18/128
0s - loss: 3.9408 - acc: 0.0367 - val_loss: 4.0593 - val_acc: 0.0777
Epoch 19/128
0s - loss: 3.8878 - acc: 0.0587 - val_loss: 3.9933 - val_acc: 0.0777
Epoch 20/128
0s - loss: 3.8631 - acc: 0.0489 - val_loss: 3.9211 - val_acc: 0.0777
Epoch 21/128
0s - loss: 3.7755 - acc: 0.0587 - val_loss: 3.8560 - val_acc: 0.1068
Epoch 22/128
0s - loss: 3.6820 - acc: 0.0611 - val_loss: 3.7960 - val_acc: 0.1068
Epoch 23/128
0s - loss: 3.6805 - acc: 0.0636 - val_loss: 3.7322 - val_acc: 0.1068
Epoch 24/128
0s - loss: 3.6072 - acc: 0.0562 - val_loss: 3.6619 - val_acc: 0.1165
Epoch 25/128
0s - loss: 3.5914 - acc: 0.0611 - val_loss: 3.5933 - val_acc: 0.1165
Epoch 26/128
0s - loss: 3.5474 - acc: 0.0660 - val_loss: 3.5266 - val_acc: 0.1165
Epoch 27/128
0s - loss: 3.6166 - acc: 0.0636 - val_loss: 3.4607 - val_acc: 0.1165
Epoch 28/128
0s - loss: 3.5039 - acc: 0.0709 - val_loss: 3.4030 - val_acc: 0.1165
Epoch 29/128
0s - loss: 3.4022 - acc: 0.0562 - val_loss: 3.3472 - val_acc: 0.1165
Epoch 30/128
0s - loss: 3.3929 - acc: 0.0685 - val_loss: 3.2874 - val_acc: 0.1165
Epoch 31/128
0s - loss: 3.2822 - acc: 0.0758 - val_loss: 3.2226 - val_acc: 0.1165
Epoch 32/128
0s - loss: 3.4340 - acc: 0.0391 - val_loss: 3.1615 - val_acc: 0.1165
Epoch 33/128
0s - loss: 3.2833 - acc: 0.0660 - val_loss: 3.0998 - val_acc: 0.1165
Epoch 34/128
0s - loss: 3.3057 - acc: 0.0611 - val_loss: 3.0442 - val_acc: 0.1165
Epoch 35/128
0s - loss: 3.2296 - acc: 0.0660 - val_loss: 2.9851 - val_acc: 0.1165
Epoch 36/128
0s - loss: 3.1976 - acc: 0.0709 - val_loss: 2.9262 - val_acc: 0.1165
Epoch 37/128
0s - loss: 3.3379 - acc: 0.0416 - val_loss: 2.8677 - val_acc: 0.1165
Epoch 38/128
0s - loss: 3.1349 - acc: 0.0587 - val_loss: 2.8123 - val_acc: 0.1165
Epoch 39/128
0s - loss: 3.2889 - acc: 0.0465 - val_loss: 2.7493 - val_acc: 0.1165
Epoch 40/128
0s - loss: 3.1267 - acc: 0.0440 - val_loss: 2.6958 - val_acc: 0.1165
Epoch 41/128
0s - loss: 3.1133 - acc: 0.0391 - val_loss: 2.6460 - val_acc: 0.1165
Epoch 42/128
0s - loss: 3.0921 - acc: 0.0807 - val_loss: 2.5883 - val_acc: 0.1165
Epoch 43/128
0s - loss: 2.9775 - acc: 0.0489 - val_loss: 2.5306 - val_acc: 0.1165
Epoch 44/128
0s - loss: 3.0426 - acc: 0.0562 - val_loss: 2.4894 - val_acc: 0.1165
Epoch 45/128
0s - loss: 2.9756 - acc: 0.0636 - val_loss: 2.4565 - val_acc: 0.1165
Epoch 46/128
0s - loss: 2.9206 - acc: 0.0538 - val_loss: 2.4081 - val_acc: 0.1165
Epoch 47/128
0s - loss: 2.8931 - acc: 0.0513 - val_loss: 2.3660 - val_acc: 0.1165
Epoch 48/128
0s - loss: 2.8887 - acc: 0.0587 - val_loss: 2.3295 - val_acc: 0.1165
Epoch 49/128
0s - loss: 2.7834 - acc: 0.0562 - val_loss: 2.2978 - val_acc: 0.1165
Epoch 50/128
0s - loss: 2.7840 - acc: 0.0489 - val_loss: 2.2597 - val_acc: 0.1165
Epoch 51/128
0s - loss: 2.7646 - acc: 0.0440 - val_loss: 2.2191 - val_acc: 0.0971
Epoch 52/128
0s - loss: 2.7916 - acc: 0.0391 - val_loss: 2.1742 - val_acc: 0.0874
Epoch 53/128
0s - loss: 2.6324 - acc: 0.0538 - val_loss: 2.1436 - val_acc: 0.0680
Epoch 54/128
0s - loss: 2.6104 - acc: 0.0709 - val_loss: 2.1323 - val_acc: 0.0583
Epoch 55/128
0s - loss: 2.5474 - acc: 0.0465 - val_loss: 2.0941 - val_acc: 0.0485
Epoch 56/128
0s - loss: 2.6980 - acc: 0.0416 - val_loss: 2.0505 - val_acc: 0.0291
Epoch 57/128
0s - loss: 2.5632 - acc: 0.0538 - val_loss: 2.0112 - val_acc: 0.0194
Epoch 58/128
0s - loss: 2.4390 - acc: 0.0342 - val_loss: 1.9608 - val_acc: 0.0000e+00
Epoch 59/128
0s - loss: 2.6240 - acc: 0.0489 - val_loss: 1.9134 - val_acc: 0.0000e+00
Epoch 60/128
0s - loss: 2.6062 - acc: 0.0465 - val_loss: 1.8807 - val_acc: 0.0000e+00
Epoch 61/128
0s - loss: 2.4623 - acc: 0.0318 - val_loss: 1.8648 - val_acc: 0.0000e+00
Epoch 62/128
0s - loss: 2.4274 - acc: 0.0293 - val_loss: 1.8435 - val_acc: 0.0000e+00
Epoch 63/128
0s - loss: 2.3820 - acc: 0.0416 - val_loss: 1.8149 - val_acc: 0.0000e+00
Epoch 64/128
0s - loss: 2.5219 - acc: 0.0293 - val_loss: 1.7766 - val_acc: 0.0000e+00
Epoch 65/128
0s - loss: 2.4784 - acc: 0.0318 - val_loss: 1.7594 - val_acc: 0.0000e+00
Epoch 66/128
0s - loss: 2.3320 - acc: 0.0562 - val_loss: 1.7520 - val_acc: 0.0000e+00
Epoch 67/128
0s - loss: 2.2655 - acc: 0.0269 - val_loss: 1.6908 - val_acc: 0.0000e+00
Epoch 68/128
0s - loss: 2.2435 - acc: 0.0367 - val_loss: 1.6086 - val_acc: 0.0000e+00
Epoch 69/128
0s - loss: 2.3730 - acc: 0.0367 - val_loss: 1.5474 - val_acc: 0.0000e+00
Epoch 70/128
0s - loss: 2.1912 - acc: 0.0293 - val_loss: 1.5112 - val_acc: 0.0000e+00
Epoch 71/128
0s - loss: 2.2566 - acc: 0.0538 - val_loss: 1.4697 - val_acc: 0.0000e+00
Epoch 72/128
0s - loss: 2.0497 - acc: 0.0293 - val_loss: 1.4552 - val_acc: 0.0000e+00
Epoch 73/128
0s - loss: 2.2115 - acc: 0.0367 - val_loss: 1.4480 - val_acc: 0.0000e+00
Epoch 74/128
0s - loss: 1.9978 - acc: 0.0244 - val_loss: 1.4215 - val_acc: 0.0000e+00
Epoch 75/128
0s - loss: 1.9611 - acc: 0.0220 - val_loss: 1.3279 - val_acc: 0.0000e+00
Epoch 76/128
0s - loss: 1.9493 - acc: 0.0244 - val_loss: 1.2173 - val_acc: 0.0000e+00
Epoch 77/128
0s - loss: 2.0401 - acc: 0.0220 - val_loss: 1.1278 - val_acc: 0.0000e+00
Epoch 78/128
0s - loss: 1.9477 - acc: 0.0293 - val_loss: 1.0995 - val_acc: 0.0000e+00
Epoch 79/128
0s - loss: 1.8232 - acc: 0.0342 - val_loss: 1.0629 - val_acc: 0.0000e+00
Epoch 80/128
0s - loss: 1.8441 - acc: 0.0342 - val_loss: 1.0179 - val_acc: 0.0000e+00
Epoch 81/128
0s - loss: 1.7945 - acc: 0.0196 - val_loss: 1.0220 - val_acc: 0.0000e+00
Epoch 82/128
0s - loss: 1.7765 - acc: 0.0196 - val_loss: 1.0123 - val_acc: 0.0000e+00
Epoch 83/128
0s - loss: 1.7779 - acc: 0.0196 - val_loss: 0.9859 - val_acc: 0.0000e+00
Epoch 84/128
0s - loss: 1.8818 - acc: 0.0220 - val_loss: 0.9265 - val_acc: 0.0000e+00
Epoch 85/128
0s - loss: 1.7221 - acc: 0.0269 - val_loss: 0.9135 - val_acc: 0.0000e+00
Epoch 86/128
0s - loss: 1.6196 - acc: 0.0244 - val_loss: 0.8858 - val_acc: 0.0000e+00
Epoch 87/128
0s - loss: 1.7099 - acc: 0.0196 - val_loss: 0.9017 - val_acc: 0.0194
Epoch 88/128
0s - loss: 1.7219 - acc: 0.0220 - val_loss: 0.9613 - val_acc: 0.0194
Epoch 89/128
0s - loss: 1.5476 - acc: 0.0293 - val_loss: 0.9461 - val_acc: 0.0097
Epoch 90/128
0s - loss: 1.5125 - acc: 0.0220 - val_loss: 0.8579 - val_acc: 0.0000e+00
Epoch 91/128
0s - loss: 1.4719 - acc: 0.0293 - val_loss: 0.7834 - val_acc: 0.0097
Epoch 92/128
0s - loss: 1.4103 - acc: 0.0147 - val_loss: 0.7143 - val_acc: 0.0097
Epoch 93/128
0s - loss: 1.6006 - acc: 0.0171 - val_loss: 0.7103 - val_acc: 0.0097
Epoch 94/128
0s - loss: 1.2985 - acc: 0.0122 - val_loss: 0.7105 - val_acc: 0.0097
Epoch 95/128
0s - loss: 1.3594 - acc: 0.0244 - val_loss: 0.6380 - val_acc: 0.0097
Epoch 96/128
0s - loss: 1.3300 - acc: 0.0122 - val_loss: 0.5237 - val_acc: 0.0000e+00
Epoch 97/128
0s - loss: 1.1956 - acc: 0.0147 - val_loss: 0.4397 - val_acc: 0.0000e+00
Epoch 98/128
0s - loss: 1.1126 - acc: 0.0049 - val_loss: 0.3710 - val_acc: 0.0000e+00
Epoch 99/128
0s - loss: 1.1429 - acc: 0.0244 - val_loss: 0.3235 - val_acc: 0.0000e+00
Epoch 100/128
0s - loss: 1.3316 - acc: 0.0244 - val_loss: 0.3533 - val_acc: 0.0000e+00
Epoch 101/128
0s - loss: 1.2698 - acc: 0.0171 - val_loss: 0.3225 - val_acc: 0.0000e+00
Epoch 102/128
0s - loss: 1.0162 - acc: 0.0318 - val_loss: 0.2567 - val_acc: 0.0000e+00
Epoch 103/128
0s - loss: 1.0019 - acc: 0.0244 - val_loss: 0.1909 - val_acc: 0.0000e+00
Epoch 104/128
0s - loss: 1.1213 - acc: 0.0073 - val_loss: 0.1690 - val_acc: 0.0000e+00
Epoch 105/128
0s - loss: 1.2172 - acc: 0.0220 - val_loss: 0.1825 - val_acc: 0.0000e+00
Epoch 106/128
0s - loss: 0.9417 - acc: 0.0293 - val_loss: 0.2106 - val_acc: 0.0000e+00
Epoch 107/128
0s - loss: 1.1023 - acc: 0.0098 - val_loss: 0.2005 - val_acc: 0.0097
Epoch 108/128
0s - loss: 0.9184 - acc: 0.0269 - val_loss: 0.1490 - val_acc: 0.0097
Epoch 109/128
0s - loss: 0.9715 - acc: 0.0293 - val_loss: 0.0899 - val_acc: 0.0097
Epoch 110/128
0s - loss: 0.8439 - acc: 0.0220 - val_loss: 0.0525 - val_acc: 0.0097
Epoch 111/128
0s - loss: 0.7072 - acc: 0.0147 - val_loss: 0.0605 - val_acc: 0.0000e+00
Epoch 112/128
0s - loss: 0.6803 - acc: 0.0098 - val_loss: 0.0146 - val_acc: 0.0000e+00
Epoch 113/128
0s - loss: 0.7736 - acc: 0.0171 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00
Epoch 114/128
0s - loss: 0.5061 - acc: 0.0073 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00
Epoch 115/128
0s - loss: 0.7100 - acc: 0.0122 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00
Epoch 116/128
0s - loss: 0.6968 - acc: 0.0196 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00
Epoch 117/128
0s - loss: 0.4926 - acc: 0.0269 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00
Epoch 118/128
0s - loss: 0.5244 - acc: 0.0269 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00
Epoch 119/128
0s - loss: 0.5382 - acc: 0.0196 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00
Epoch 120/128
0s - loss: 0.6631 - acc: 0.0171 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00
Epoch 121/128
0s - loss: 0.5660 - acc: 0.0171 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00
Epoch 122/128
0s - loss: 0.6036 - acc: 0.0269 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00
Epoch 123/128
0s - loss: 0.3026 - acc: 0.0220 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00
Epoch 124/128
0s - loss: 0.3879 - acc: 0.0244 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00
Epoch 125/128
0s - loss: 0.4804 - acc: 0.0318 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00
Epoch 126/128
0s - loss: 0.4000 - acc: 0.0293 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00
Epoch 127/128
0s - loss: 0.3020 - acc: 0.0342 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00
Epoch 128/128
0s - loss: 0.3045 - acc: 0.0171 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00

We can now make predictions on test data:

In [9]:
print "predicting on test data..."
y_pred = model.predict(X_test)

mu_pred = y_pred[:,:num_clusters*data_dim]
mu_pred = np.reshape(mu_pred, [-1, num_clusters, data_dim])
sigma_pred = y_pred[:,num_clusters*data_dim : num_clusters*(data_dim+1)] 
pi_pred = y_pred[:,num_clusters*(data_dim+1):]
z_pred = np.argmax(pi_pred, axis=-1)
predicting on test data...
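Beyond picking the most likely component, the fitted MDN defines a full conditional density $p(y|x)$, so we can also draw samples from it. The helper below (sample_from_mdn) is an illustrative sketch, not part of the original notebook:

#sketch: draw samples from p(y|x) = sum_k pi_k N(y | mu_k, sigma_k^2 I) for each test input
def sample_from_mdn(mu, sigma, pi, n_samples=1):
    N = mu.shape[0]
    samples = np.zeros((N, n_samples, data_dim), dtype=np.float32)
    for i in range(N):
        p = pi[i] / pi[i].sum()                         #guard against round-off
        for s in range(n_samples):
            k = np.argmax(np.random.multinomial(1, p))  #pick a mixture component
            samples[i, s, :] = mu[i, k, :] + sigma[i, k] * np.random.randn(data_dim)
    return samples

y_samples = sample_from_mdn(mu_pred, sigma_pred, pi_pred, n_samples=1)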

We'll evaluate clustering performance by computing the adjusted Rand score and normalized mutual information, and by comparing the true and predicted means and standard deviations:

In [10]:
rand_score = adjusted_rand_score(z_test, z_pred)
print("adjusted rand score: ", rand_score)

nmi_score = normalized_mutual_info_score(z_test, z_pred)
print("normalized MI score: ", nmi_score)
adjusted rand score:  0.842059401234
normalized MI score:  0.840175516041
In [11]:
mu_pred_list = []
sigma_pred_list = []
for label in np.unique(z_pred):
    z_idx = np.where(z_pred == label)[0]
    mu_pred_lbl = np.mean(mu_pred[z_idx,label,:], axis=0)
    mu_pred_list.append(mu_pred_lbl)

    sigma_pred_lbl = np.mean(sigma_pred[z_idx,label], axis=0)
    sigma_pred_list.append(sigma_pred_lbl)
#end for

print "true means: "
print np.array(mu_true)

print "predicted means: "
print np.array(mu_pred_list)

print "true sigmas: "
print np.array(sigma_true)

print "predicted sigmas: "
print np.array(sigma_pred_list)
true means: 
[[ 2  2]
 [-2  2]
 [-2 -2]
 [ 2 -2]]
predicted means: 
[[ 2.13656688 -1.97987711]
 [-1.47537231 -2.03052163]
 [ 1.95446241  2.02965617]
 [-1.944049    1.96065593]]
true sigmas: 
[[ 0.5  0.5]
 [ 1.   1. ]
 [ 0.5  0.5]
 [ 1.   1. ]]
predicted sigmas: 
[ 1.26810288  0.30546659  0.61887544  0.15010625]

Notice that the clustering scores are quite high and the predicted means are close to the true means; the predicted sigmas, however, are somewhat off, possibly due to differences between the training and test data.
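Since mixture components come out in an arbitrary order, a component-wise comparison is easier after matching each true mean with its nearest predicted mean. A small sketch of such a matching (illustrative; a strict one-to-one assignment would require e.g. the Hungarian algorithm):

from scipy.spatial.distance import cdist

mu_pred_arr = np.array(mu_pred_list)
#index of the nearest predicted mean for each true mean (greedy nearest-neighbour match)
order = np.argmin(cdist(np.array(mu_true, dtype=np.float32), mu_pred_arr), axis=1)

print("predicted means, reordered to match the true components:")
print(mu_pred_arr[order])
print("predicted sigmas, reordered to match the true components:")
print(np.array(sigma_pred_list)[order])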

In [12]:
#generate plots
plt.figure()
plt.scatter(X_test[:,0], X_test[:,1], c=z_pred, cmap=cm.bwr)
plt.scatter(np.array(mu_pred_list)[:,0], np.array(mu_pred_list)[:,1], s=100, marker='x', lw=4.0, color='k')
plt.title('test data')
plt.savefig('./figures/mdn_test_data.png')

The black crosses in the figure above show the predicted means overlaid on the test data. Visually, the means are close to the cluster centers.

In [13]:
plt.figure()
plt.plot(hist.history['loss'], c='b', lw=2.0, label='train')
plt.plot(hist.history['val_loss'], c='r', lw=2.0, label='val')
plt.title('Mixture Density Network')
plt.xlabel('Epochs')
plt.ylabel('Negative Log Likelihood Loss')
plt.legend(loc='upper left')
plt.savefig('./figures/mdn_loss.png')

We can see that both the training and validation loss decrease with the number of epochs. The validation loss flattens at zero in the later epochs because NLLLoss clamps the loss from below at zero.

References

[1] C. Bishop, "Pattern Recognition and Machine Learning", 2006
[2] EdwardLib, "Mixture Density Networks", http://edwardlib.org/tutorials/mixture-density-network
[3] F. Chollet, "Keras: The Python Deep Learning library", https://keras.io/