Mixture Density Networks (MDNs) [1] are mixture models in which the parameters (means, covariances, and mixture proportions) are output by a neural network. MDNs combine a structured data representation (a mixture density) with unstructured parameter inference (an MLP). An MDN learns the mixture parameters by maximizing the log-likelihood, or equivalently minimizing a negative log-likelihood loss.
Assuming a Gaussian Mixture Model (GMM) with $K$ components, we can write down the probability of a data point $y_i$ conditioned on the input $x_i$ as follows:

$$p(y_i \,|\, x_i) = \sum_{k=1}^{K} \pi_k(x_i)\, \mathcal{N}\left(y_i \,|\, \mu_k(x_i), \sigma_k^2(x_i)\right)$$
where the parameters $\mu_k, \sigma_k, \pi_k$ are learned by a neural network (e.g. a Multi-Layer Perceptron (MLP)) parameterized by $\theta$:

$$\mu_k = \mu_k^{\theta}(x), \quad \sigma_k = \sigma_k^{\theta}(x), \quad \pi_k = \pi_k^{\theta}(x)$$
As a result, the Neural Network (NN) is a multi-output model, subject to the following constraints on the output:

$$\sigma_k > 0, \quad \sum_{k=1}^{K} \pi_k = 1, \quad \pi_k \geq 0$$
The first constraint can be achieved by using exponential activations, while the second constraint can be achieved by using softmax activations. Finally, by making use of the iid assumption, we want to minimize the following negative log-likelihood loss:

$$\mathcal{L}(\theta) = -\frac{1}{N} \sum_{i=1}^{N} \log \sum_{k=1}^{K} \pi_k^{\theta}(x_i)\, \mathcal{N}\left(y_i \,|\, \mu_k^{\theta}(x_i), \sigma_k^{\theta}(x_i)\right)$$
In the example below, we assume an isotropic covariance $\Sigma_k = \sigma_k^{2} I$, thus we can write a $d$-dimensional Gaussian as a product of one-dimensional Gaussians:

$$\mathcal{N}\left(y \,|\, \mu_k, \sigma_k^2 I\right) = \prod_{j=1}^{d} \frac{1}{\sqrt{2\pi}\,\sigma_k} \exp\left(-\frac{(y_j - \mu_{k,j})^2}{2 \sigma_k^2}\right)$$
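We can check this factorization numerically (an illustrative aside, not part of the original post; it assumes scipy is available): the product of per-dimension 1-D Gaussian densities matches scipy's multivariate pdf with an isotropic covariance.

import numpy as np
from scipy.stats import norm, multivariate_normal

y, mu_k, sigma_k = np.array([0.5, -1.0]), np.array([0.0, 0.5]), 0.8
prod_1d = np.prod(norm.pdf(y, loc=mu_k, scale=sigma_k))  #product of 1-D Gaussians
iso = multivariate_normal.pdf(y, mean=mu_k, cov=sigma_k**2 * np.eye(2))  #d-dim isotropic Gaussian
assert np.isclose(prod_1d, iso)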
Let's implement a Gaussian MDN using Keras [3]!
%matplotlib inline
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import math
import tensorflow as tf
import keras
from keras import optimizers
from keras import backend as K
from keras import regularizers
from keras.models import Sequential, Model
from keras.layers import concatenate, Input
from keras.layers import Dense, Activation, Dropout, Flatten
from keras.layers import BatchNormalization
from keras.utils import np_utils
from keras.utils import plot_model
from keras.models import load_model
from keras.callbacks import ModelCheckpoint
from keras.callbacks import TensorBoard
from keras.callbacks import LearningRateScheduler
from keras.callbacks import EarlyStopping
from sklearn.datasets import make_blobs
from sklearn.metrics import adjusted_rand_score
from sklearn.metrics import normalized_mutual_info_score
from sklearn.model_selection import train_test_split
np.random.seed(0)
sns.set_style('whitegrid')
Using TensorFlow backend.
We'll use a synthetic dataset of $K=4$ two-dimensional Gaussians for ease of visualization.
def generate_data(N):
    pi = np.array([0.2, 0.4, 0.3, 0.1])                 #mixture proportions
    mu = [[2,2], [-2,2], [-2,-2], [2,-2]]               #component means
    std = [[0.5,0.5], [1.0,1.0], [0.5,0.5], [1.0,1.0]]  #diagonal covariance entries
    x = np.zeros((N,2), dtype=np.float32)
    y = np.zeros((N,2), dtype=np.float32)
    z = np.zeros((N,1), dtype=np.int32)
    for n in range(N):
        k = np.argmax(np.random.multinomial(1, pi))     #sample a component
        #note: np.diag(std[k]) is used as the covariance matrix, so the
        #entries of std act as variances, not standard deviations
        x[n,:] = np.random.multivariate_normal(mu[k], np.diag(std[k]))
        y[n,:] = mu[k]                                  #regression target: true mean
        z[n,:] = k                                      #true cluster label
    #end for
    z = z.flatten()
    return x, y, z, pi, mu, std
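As a quick sanity check (illustrative; not in the original notebook), the empirical cluster proportions of a large sample should approach pi:

#sanity check: empirical cluster proportions should approach pi
x_chk, y_chk, z_chk, pi_chk, _, _ = generate_data(10000)
print np.bincount(z_chk) / float(len(z_chk))  #approximately [0.2, 0.4, 0.3, 0.1]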
Let's define a function for computing a D-dimensional Gaussian with isotropic covariance:
def tf_normal(y, mu, sigma):
    #y: [batch_size, D], mu: [batch_size, K, D], sigma: [batch_size, K]
    y_tile = K.stack([y]*num_clusters, axis=1)       #[batch_size, K, D]
    result = y_tile - mu
    sigma_tile = K.stack([sigma]*data_dim, axis=-1)  #[batch_size, K, D]
    result = result * 1.0/(sigma_tile + 1e-8)        #standardize
    result = -K.square(result)/2.0
    oneDivSqrtTwoPI = 1.0/math.sqrt(2*math.pi)
    result = K.exp(result) * (1.0/(sigma_tile + 1e-8))*oneDivSqrtTwoPI
    result = K.prod(result, axis=-1)                 #[batch_size, K] product of iid Gaussians
    return result
Finally, we'll define the negative log-likelihood loss which takes in a $d$-dimensional $y_{true}$ vector and MDN output parameters $y_{pred}$:
def NLLLoss(y_true, y_pred):
    #unpack the concatenated MDN output: [mu | sigma | pi]
    out_mu = y_pred[:,:num_clusters*data_dim]
    out_sigma = y_pred[:,num_clusters*data_dim : num_clusters*(data_dim+1)]
    out_pi = y_pred[:,num_clusters*(data_dim+1):]
    out_mu = K.reshape(out_mu, [-1, num_clusters, data_dim])
    result = tf_normal(y_true, out_mu, out_sigma)    #[batch_size, K]
    result = result * out_pi                         #weight by mixture proportions
    result = K.sum(result, axis=1, keepdims=True)    #mixture likelihood
    result = -K.log(result + 1e-8)                   #negative log-likelihood
    result = K.mean(result)
    #clamp at zero; note this floors the reported loss, so the validation
    #loss below saturates at 0 once the mixture likelihood exceeds 1
    return tf.maximum(result, 0)
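Computing the likelihood in probability space (K.exp followed by K.prod) can underflow for larger $d$; a numerically more stable alternative (a sketch of mine, not the loss used in this notebook) works in log space via log-sum-exp:

#sketch of a log-space NLL using log-sum-exp (assumes the same
#[mu | sigma | pi] output layout as NLLLoss above)
def NLLLoss_logspace(y_true, y_pred):
    out_mu = K.reshape(y_pred[:, :num_clusters*data_dim], [-1, num_clusters, data_dim])
    out_sigma = y_pred[:, num_clusters*data_dim : num_clusters*(data_dim+1)]
    out_pi = y_pred[:, num_clusters*(data_dim+1):]
    y_tile = K.stack([y_true]*num_clusters, axis=1)      #[batch_size, K, D]
    sigma_tile = K.stack([out_sigma]*data_dim, axis=-1)  #[batch_size, K, D]
    #log-density of D iid Gaussians per component: [batch_size, K]
    log_comp = K.sum(-0.5*K.square((y_tile - out_mu)/(sigma_tile + 1e-8))
                     - K.log(sigma_tile + 1e-8)
                     - 0.5*math.log(2*math.pi), axis=-1)
    #-log sum_k pi_k N_k, computed stably with log-sum-exp
    return K.mean(-K.logsumexp(log_comp + K.log(out_pi + 1e-8), axis=1))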
Let's generate and visualize training and test data:
#generate data
X_data, y_data, z_data, pi_true, mu_true, sigma_true = generate_data(1024)
data_dim = X_data.shape[1]
num_clusters = len(mu_true)
#X_data, y_data = make_blobs(n_samples=1000, centers=num_clusters, n_features=data_dim, random_state=0)
#X_train, X_test, y_train, y_test = train_test_split(X_data, y_data, random_state=0, test_size=0.7)
num_train = 512
X_train, X_test, y_train, y_test = X_data[:num_train,:], X_data[num_train:,:], y_data[:num_train,:], y_data[num_train:,:]
z_train, z_test = z_data[:num_train], z_data[num_train:]
#visualize data
plt.figure()
plt.scatter(X_train[:,0], X_train[:,1], c=z_train, cmap=cm.bwr)
plt.title('training data')
plt.savefig('./figures/mdn_training_data.png')
We can now define the training parameters as well as MDN parameters and architecture:
#training params
batch_size = 128
num_epochs = 128
#model parameters
hidden_size = 32
weight_decay = 1e-4
#MDN architecture
input_data = Input(shape=(data_dim,))
x = Dense(hidden_size, activation='relu')(input_data)
x = Dropout(0.2)(x)
x = BatchNormalization()(x)
x = Dense(hidden_size, activation='relu')(x)
x = Dropout(0.2)(x)
x = BatchNormalization()(x)
mu = Dense(num_clusters * data_dim, activation='linear')(x)  #cluster means
sigma = Dense(num_clusters, activation=K.exp)(x)  #one sigma per component (isotropic); exp keeps it positive
pi = Dense(num_clusters, activation='softmax')(x) #mixture proportions
out = concatenate([mu, sigma, pi], axis=-1)
model = Model(input_data, out)
adam = optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
model.compile(loss=NLLLoss, optimizer=adam, metrics=['accuracy'])  #accuracy is not meaningful for MDN outputs
model.summary()
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to
====================================================================================================
input_1 (InputLayer)             (None, 2)             0
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 32)            96          input_1[0][0]
____________________________________________________________________________________________________
dropout_1 (Dropout)              (None, 32)            0           dense_1[0][0]
____________________________________________________________________________________________________
batch_normalization_1 (BatchNorm (None, 32)            128         dropout_1[0][0]
____________________________________________________________________________________________________
dense_2 (Dense)                  (None, 32)            1056        batch_normalization_1[0][0]
____________________________________________________________________________________________________
dropout_2 (Dropout)              (None, 32)            0           dense_2[0][0]
____________________________________________________________________________________________________
batch_normalization_2 (BatchNorm (None, 32)            128         dropout_2[0][0]
____________________________________________________________________________________________________
dense_3 (Dense)                  (None, 8)             264         batch_normalization_2[0][0]
____________________________________________________________________________________________________
dense_4 (Dense)                  (None, 4)             132         batch_normalization_2[0][0]
____________________________________________________________________________________________________
dense_5 (Dense)                  (None, 4)             132         batch_normalization_2[0][0]
____________________________________________________________________________________________________
concatenate_1 (Concatenate)      (None, 16)            0           dense_3[0][0]
                                                                   dense_4[0][0]
                                                                   dense_5[0][0]
====================================================================================================
Total params: 1,936
Trainable params: 1,808
Non-trainable params: 128
____________________________________________________________________________________________________
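One caveat: the exponential activation on the sigma head can overflow early in training. A common alternative (an assumption on my part, not what this notebook uses) is a softplus activation, which also guarantees positive sigmas but grows only linearly:

#alternative sigma head (not used above): softplus keeps sigma > 0
#while avoiding the overflow that exp can produce
sigma = Dense(num_clusters, activation='softplus')(x)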
Let's train the model with early stopping:
early_stopping = EarlyStopping(monitor='val_loss', min_delta=0.1, patience=32, verbose=1)
callbacks_list = [early_stopping]
#model training
hist = model.fit(X_train, y_train, batch_size=batch_size, epochs=num_epochs, callbacks=callbacks_list, validation_split=0.2, shuffle=True, verbose=2)
Train on 409 samples, validate on 103 samples
Epoch 1/128
1s - loss: 5.9070 - acc: 0.0342 - val_loss: 6.0739 - val_acc: 0.0000e+00
Epoch 2/128
0s - loss: 5.5468 - acc: 0.0244 - val_loss: 5.9862 - val_acc: 0.0000e+00
Epoch 3/128
0s - loss: 5.5059 - acc: 0.0293 - val_loss: 5.8783 - val_acc: 0.0000e+00
...
Epoch 126/128
0s - loss: 0.4000 - acc: 0.0293 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00
Epoch 127/128
0s - loss: 0.3020 - acc: 0.0342 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00
Epoch 128/128
0s - loss: 0.3045 - acc: 0.0171 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00
We can now make predictions on test data:
print "predicting on test data..."
y_pred = model.predict(X_test)
mu_pred = y_pred[:,:num_clusters*data_dim]
mu_pred = np.reshape(mu_pred, [-1, num_clusters, data_dim])
sigma_pred = y_pred[:,num_clusters*data_dim : num_clusters*(data_dim+1)]
pi_pred = y_pred[:,num_clusters*(data_dim+1):]
z_pred = np.argmax(pi_pred, axis=-1)
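Beyond hard cluster assignments, we can also draw samples from each predicted mixture (an illustrative extension, not in the original): pick a component according to pi_pred, then sample from its isotropic Gaussian.

#illustrative: draw one sample per test point from its predicted mixture
def sample_from_mdn(mu_pred, sigma_pred, pi_pred):
    num_test = pi_pred.shape[0]
    samples = np.zeros((num_test, data_dim), dtype=np.float32)
    for i in range(num_test):
        k = np.random.choice(num_clusters, p=pi_pred[i]/pi_pred[i].sum())
        samples[i,:] = np.random.normal(mu_pred[i,k,:], sigma_pred[i,k])
    return samples

X_sampled = sample_from_mdn(mu_pred, sigma_pred, pi_pred)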
predicting on test data...
We'll evaluate clustering performance by computing the adjusted rand score, normalized mutual information and by comparing true and predicted values for means and standard deviations:
rand_score = adjusted_rand_score(z_test, z_pred)
print "adjusted rand score: ", rand_score
nmi_score = normalized_mutual_info_score(z_test, z_pred)
print "normalized MI score: ", nmi_score
adjusted rand score:  0.842059401234
normalized MI score:  0.840175516041
mu_pred_list = []
sigma_pred_list = []
for label in np.unique(z_pred):
    z_idx = np.where(z_pred == label)[0]
    mu_pred_lbl = np.mean(mu_pred[z_idx,label,:], axis=0)      #avg predicted mean per cluster
    mu_pred_list.append(mu_pred_lbl)
    sigma_pred_lbl = np.mean(sigma_pred[z_idx,label], axis=0)  #avg predicted sigma per cluster
    sigma_pred_list.append(sigma_pred_lbl)
#end for
print "true means: "
print np.array(mu_true)
print "predicted means: "
print np.array(mu_pred_list)
print "true sigmas: "
print np.array(sigma_true)
print "predicted sigmas: "
print np.array(sigma_pred_list)
true means:
[[ 2  2]
 [-2  2]
 [-2 -2]
 [ 2 -2]]
predicted means:
[[ 2.13656688 -1.97987711]
 [-1.47537231 -2.03052163]
 [ 1.95446241  2.02965617]
 [-1.944049    1.96065593]]
true sigmas:
[[ 0.5  0.5]
 [ 1.   1. ]
 [ 0.5  0.5]
 [ 1.   1. ]]
predicted sigmas:
[ 1.26810288  0.30546659  0.61887544  0.15010625]
Notice that the clustering scores are quite high and the predicted means are close to the true means. The predicted sigmas, however, are off. Part of the gap is an artifact of the generator: generate_data passes std to np.diag as a covariance matrix, so its entries act as variances and the effective standard deviations are $\sqrt{0.5} \approx 0.71$ and $1.0$; the small training set likely accounts for the rest.
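For a fairer comparison (my observation, not from the original post), take the square root of the generator's std entries, since they act as variances:

#np.diag(std[k]) is used as a covariance matrix, so the effective
#per-component standard deviations are sqrt(std)
print np.sqrt(np.array(sigma_true))  #approx [[0.71 0.71] [1. 1.] [0.71 0.71] [1. 1.]]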
#generate plots
plt.figure()
plt.scatter(X_test[:,0], X_test[:,1], c=z_pred, cmap=cm.bwr)
plt.scatter(np.array(mu_pred_list)[:,0], np.array(mu_pred_list)[:,1], s=100, marker='x', lw=4.0, color='k')
plt.title('test data')
plt.savefig('./figures/mdn_test_data.png')
The black crosses in the figure above show the predicted means overlaid on the test data. Visually, the predicted means are close to the cluster centers.
plt.figure()
plt.plot(hist.history['loss'], c='b', lw=2.0, label='train')
plt.plot(hist.history['val_loss'], c='r', lw=2.0, label='val')
plt.title('Mixture Density Network')
plt.xlabel('Epochs')
plt.ylabel('Negative Log Likelihood Loss')
plt.legend(loc='upper left')
plt.savefig('./figures/mdn_loss.png')
We can see that both training and validation loss decrease with the number of epochs. Note that the validation loss flattens at exactly zero toward the end because NLLLoss clamps the loss at zero via tf.maximum.
[1] C. Bishop, "Pattern Recognition and Machine Learning", 2006
[2] EdwardLib, "Mixture Density Networks", http://edwardlib.org/tutorials/mixture-density-network
[3] F. Chollet, "Keras: The Python Deep Learning library", https://keras.io/