This is one of the 100 recipes of the IPython Cookbook, the definitive guide to high-performance scientific computing and data science in Python.

8.8. Detecting hidden structures in a dataset with clustering

  1. Let's import the libraries.
In [ ]:
from itertools import permutations
import numpy as np
import sklearn
import sklearn.decomposition as dec
import sklearn.cluster as clu
import sklearn.datasets as ds
import matplotlib.pyplot as plt
%matplotlib inline
  1. Let's generate a random dataset with three clusters.
In [ ]:
X, y = ds.make_blobs(n_samples=200, n_features=2, centers=3)
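Conceptually, `make_blobs` draws each point from an isotropic Gaussian centred on one of the cluster centres. The following is a minimal NumPy sketch of that idea, not scikit-learn's implementation. It assumes equal-sized clusters and a single standard deviation, and it does not shuffle the points (scikit-learn shuffles by default). The helper name `sample_blobs` is hypothetical.

```python
import numpy as np

def sample_blobs(centers, n_per_cluster, std, seed=0):
    """Hypothetical helper: draw isotropic Gaussian blobs around given centres."""
    rng = np.random.default_rng(seed)
    centers = np.asarray(centers, dtype=float)
    n_centers, n_features = centers.shape
    # One label per point: n_per_cluster copies of each cluster index.
    labels = np.repeat(np.arange(n_centers), n_per_cluster)
    # Each point is its cluster centre plus independent Gaussian noise per feature.
    X = centers[labels] + std * rng.standard_normal((labels.size, n_features))
    return X, labels

X_demo, y_demo = sample_blobs([(-5, -5), (0, 5), (5, -5)], n_per_cluster=50, std=0.5)
print(X_demo.shape, np.bincount(y_demo))  # (150, 2) [50 50 50]
```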
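  1. The cluster ids returned by a clustering algorithm are arbitrary, so we need a couple of functions: one to relabel a clustering so that its colors match the original classes, and one to display the results of the clustering algorithms.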
In [ ]:
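def relabel(cl):
    """Relabel a clustering with three clusters
    to match the original classes."""
    # Only relabel clusterings whose labels are exactly 0, 1, 2
    # (e.g. skip clusterings containing the DBSCAN noise label -1).
    if np.min(cl) != 0 or np.max(cl) != 2:
        return cl
    # Try all 3! = 6 permutations of the cluster ids and keep the one
    # that disagrees with the true labels y on the fewest points.
    perms = np.array(list(permutations((0, 1, 2))))
    i = np.argmin([np.sum(perm[cl] != y)
                   for perm in perms])
    p = perms[i]
    return p[cl]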
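`relabel` tries all 3! = 6 permutations, which is fine for three clusters but grows factorially with the number of clusters. One alternative, a swapped-in technique rather than the recipe's method, is to treat the matching as an assignment problem and solve it with SciPy's `linear_sum_assignment`, maximising the number of points whose predicted label agrees with the true one. This sketch assumes both label arrays use the integers 0..k-1 (no noise label); the function name `relabel_hungarian` is hypothetical.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def relabel_hungarian(pred, true):
    """Hypothetical helper: map predicted cluster ids onto true class ids
    so that the number of agreeing points is maximal."""
    pred = np.asarray(pred)
    true = np.asarray(true)
    k = max(pred.max(), true.max()) + 1
    # Contingency table: counts[i, j] = number of points with predicted i and true j.
    counts = np.zeros((k, k), dtype=int)
    np.add.at(counts, (pred, true), 1)
    # linear_sum_assignment minimises the total cost, so negate the counts
    # to maximise the total agreement instead.
    rows, cols = linear_sum_assignment(-counts)
    mapping = np.empty(k, dtype=int)
    mapping[rows] = cols
    return mapping[pred]

true = np.array([0, 0, 1, 1, 2, 2])
pred = np.array([2, 2, 0, 0, 1, 1])  # same partition, different cluster ids
print(relabel_hungarian(pred, true))  # [0 0 1 1 2 2]
```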
In [ ]:
def display_clustering(labels, title):
    """Plot the data points with the cluster colors."""
    # We relabel the classes when there are 3 clusters.
    labels = relabel(labels)
    # Display the points with the true labels on the left,
    # and with the clustering labels on the right.
    for i, (c, t) in enumerate(zip(
            [y, labels], ["True labels", title])):
        plt.subplot(121 + i)
        # The colormap is an assumption: the original call was truncated.
        plt.scatter(X[:, 0], X[:, 1], c=c, s=30, cmap=plt.cm.rainbow)
        plt.title(t)
        plt.xticks([]); plt.yticks([])
  1. Now, we cluster the dataset with the K-means algorithm, a classic and simple clustering algorithm.
In [ ]:
# By default, scikit-learn's KMeans looks for 8 clusters.
km = clu.KMeans()
km.fit(X)
display_clustering(km.labels_, "KMeans")
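K-means is usually computed with Lloyd's algorithm, which alternates two steps: assign each point to its nearest centroid, then move each centroid to the mean of the points assigned to it. Below is a bare-bones NumPy sketch that assumes random initial centroids drawn from the data and a fixed iteration cap. scikit-learn's `KMeans` adds k-means++ initialisation and several restarts (`n_init`), so this is an illustration of the idea, not its implementation. `kmeans_lloyd` is a hypothetical name.

```python
import numpy as np

def kmeans_lloyd(X, k, n_iter=100, seed=0):
    """Hypothetical helper: plain Lloyd iterations for k-means."""
    rng = np.random.default_rng(seed)
    # Initialise the centroids with k distinct data points chosen at random.
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        # Assignment step: squared distance from every point to every centroid.
        d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        labels = d2.argmin(axis=1)
        # Update step: each centroid moves to the mean of its assigned points
        # (an empty cluster keeps its previous centroid).
        new = np.array([X[labels == j].mean(axis=0) if np.any(labels == j)
                        else centroids[j] for j in range(k)])
        if np.allclose(new, centroids):
            break
        centroids = new
    return labels, centroids

# Two well-separated groups on a line: {0, 0.1, 0.2} and {10, 10.1, 10.2}.
pts = np.array([[0.0], [0.1], [0.2], [10.0], [10.1], [10.2]])
labels, centroids = kmeans_lloyd(pts, k=2)
print(labels, centroids.ravel())
```

Like scikit-learn's version, this can converge to a local minimum on less clear-cut data, which is why restarts from several initialisations are common in practice.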
  1. This algorithm requires the number of clusters at initialization time. In general, however, we do not necessarily know the number of clusters in the dataset. Here, let's try n_clusters=3 (that's cheating, because we happen to know that there are 3 clusters!).
In [ ]:
km = clu.KMeans(n_clusters=3)
km.fit(X)
display_clustering(km.labels_, "KMeans(3)")
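When the number of clusters is unknown, one common heuristic (not part of this recipe) is to fit K-means for several values of k and keep the one with the highest mean silhouette coefficient, as computed by `sklearn.metrics.silhouette_score`. This sketch uses synthetic blobs with explicitly placed, well-separated centres; that choice is an assumption that makes the answer unambiguous, and on real data the silhouette curve is often much flatter.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score

# Three well-separated blobs; centres and spread are chosen for illustration.
X_demo, _ = make_blobs(n_samples=300, centers=[(-5, -5), (0, 5), (5, -5)],
                       cluster_std=0.5, random_state=0)

scores = {}
for k in range(2, 7):
    # n_init=10 restarts make the result less sensitive to initialisation.
    labels = KMeans(n_clusters=k, n_init=10, random_state=0).fit_predict(X_demo)
    scores[k] = silhouette_score(X_demo, labels)

best_k = max(scores, key=scores.get)
print(scores)
print("best k:", best_k)
```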
  1. Let's try a few other clustering algorithms implemented in scikit-learn. The simplicity of the API makes it really easy to try different methods: it is just a matter of changing the name of the class.
In [ ]:
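plt.figure(figsize=(12, 8));
plt.subplot(231);
plt.scatter(X[:, 0], X[:, 1], c=y, s=30, cmap=plt.cm.rainbow);
plt.xticks([]); plt.yticks([]);
plt.title("True labels");
# Assumed list of five estimators (the original list was lost).
for i, est in enumerate([
        clu.SpectralClustering(n_clusters=3),
        clu.AgglomerativeClustering(n_clusters=3),
        clu.MeanShift(), clu.AffinityPropagation(), clu.DBSCAN()]):
    est.fit(X)
    c = relabel(est.labels_)
    plt.subplot(232 + i)
    plt.scatter(X[:, 0], X[:, 1], c=c, s=30, cmap=plt.cm.rainbow)
    plt.xticks([]); plt.yticks([])
    plt.title(est.__class__.__name__)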
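Some scikit-learn estimators infer the number of clusters from the data. For instance, DBSCAN groups points that lie in dense regions and gives isolated points the label -1 (noise), so the number of clusters it found can be read off `labels_`. In this sketch, `eps` and `min_samples` are picked by hand for well-separated synthetic blobs; that is an assumption, and these parameters usually need tuning on real data.

```python
import numpy as np
from sklearn.cluster import DBSCAN
from sklearn.datasets import make_blobs

X_demo, y_demo = make_blobs(n_samples=300, centers=[(-5, -5), (0, 5), (5, -5)],
                            cluster_std=0.5, random_state=0)

db = DBSCAN(eps=1.0, min_samples=5).fit(X_demo)
labels = db.labels_
# Label -1 marks noise points, which do not form a cluster.
n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
n_noise = int(np.sum(labels == -1))
print("clusters:", n_clusters, "noise points:", n_noise)
```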

The first two algorithms required the number of clusters as input. The next two did not, but they were able to find the right number of clusters, 3. The last two failed to find the correct number of clusters (overclustering: they found too many clusters).
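Visual inspection can be complemented by a score that does not depend on the arbitrary cluster ids, such as the adjusted Rand index from `sklearn.metrics.adjusted_rand_score`. It equals 1.0 for identical partitions, whatever the label values, and is close to 0.0 for unrelated ones, so overclustering shows up as a lower score. The label arrays below are made up for illustration.

```python
from sklearn.metrics import adjusted_rand_score

true = [0, 0, 1, 1, 2, 2]
same_partition = [2, 2, 0, 0, 1, 1]   # same groups, permuted ids
overclustered = [0, 1, 2, 3, 4, 5]    # every point in its own cluster

ari_same = adjusted_rand_score(true, same_partition)
ari_over = adjusted_rand_score(true, overclustered)
print(ari_same)  # 1.0
print(ari_over)  # 0.0
```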

You'll find all the explanations, figures, references, and much more in the book (to be released later this summer).

IPython Cookbook, by Cyrille Rossant, Packt Publishing, 2014 (500 pages).