#!/usr/bin/env python
# coding: utf-8

# > This is one of the 100 recipes of the [IPython Cookbook](http://ipython-books.github.io/), the definitive guide to high-performance scientific computing and data science in Python.

# # 8.8. Detecting hidden structures in a dataset with clustering

# 1. Let's import the libraries.

# In[ ]:

from itertools import permutations

import numpy as np
import sklearn.cluster as clu
import sklearn.datasets as ds
import matplotlib.pyplot as plt
get_ipython().run_line_magic('matplotlib', 'inline')

# 2. Let's generate a random dataset with three clusters.

# In[ ]:

X, y = ds.make_blobs(n_samples=200, n_features=2, centers=3)

# 3. We will need a couple of functions to relabel and display the results of the clustering algorithms. (A permutation-free variant of the relabeling step is sketched at the end of this recipe.)

# In[ ]:

def relabel(cl):
    """Relabel a clustering with three clusters to match the original classes."""
    if np.max(cl) != 2:
        return cl
    # Try all permutations of (0, 1, 2) and keep the one whose
    # relabeling is closest to the true labels y.
    perms = np.array(list(permutations((0, 1, 2))))
    i = np.argmin([np.sum(np.abs(perm[cl] - y)) for perm in perms])
    p = perms[i]
    return p[cl]

# In[ ]:

def display_clustering(labels, title):
    """Plot the data points with the cluster colors."""
    # We relabel the classes when there are 3 clusters.
    labels = relabel(labels)
    plt.figure(figsize=(8, 3));
    # Display the points with the true labels on the left,
    # and with the clustering labels on the right.
    for i, (c, t) in enumerate(zip([y, labels],
                                   ["True labels", title])):
        plt.subplot(121 + i);
        plt.scatter(X[:, 0], X[:, 1], c=c, s=30,
                    linewidths=0, cmap=plt.cm.rainbow);
        plt.xticks([]); plt.yticks([]);
        plt.title(t);

# 4. Now, we cluster the dataset with the **K-means** algorithm, a classic and simple clustering algorithm.

# In[ ]:

km = clu.KMeans()
km.fit(X);
display_clustering(km.labels_, "KMeans")

# 5. This algorithm requires the number of clusters at initialization time. In general, however, we do not necessarily know the number of clusters in the dataset. Here, let's try with `n_clusters=3` (that's cheating, because we happen to know that there are 3 clusters!). A heuristic for choosing `n_clusters` without peeking at the true labels is sketched at the end of this recipe.

# In[ ]:

km = clu.KMeans(n_clusters=3)
km.fit(X);
display_clustering(km.labels_, "KMeans(3)")

# 6. Let's try a few other clustering algorithms implemented in scikit-learn. The simplicity of the API makes it really easy to try different methods: it is just a matter of changing the name of the class.

# In[ ]:

plt.figure(figsize=(8, 5));
plt.subplot(231);
plt.scatter(X[:, 0], X[:, 1], c=y, s=30,
            linewidths=0, cmap=plt.cm.rainbow);
plt.xticks([]); plt.yticks([]);
plt.title("True labels");
for i, est in enumerate([clu.SpectralClustering(3),
                         clu.AgglomerativeClustering(3),
                         clu.MeanShift(),
                         clu.AffinityPropagation(),
                         clu.DBSCAN()]):
    est.fit(X);
    c = relabel(est.labels_)
    plt.subplot(232 + i);
    plt.scatter(X[:, 0], X[:, 1], c=c, s=30,
                linewidths=0, cmap=plt.cm.rainbow);
    plt.xticks([]); plt.yticks([]);
    plt.title(est.__class__.__name__);

# The first two algorithms required the number of clusters as input. The next two did not, but they were still able to find the right number, 3. The last two failed to find the correct number of clusters (*overclustering*: they returned too many clusters). A quantitative way to score these results against the true labels is sketched at the end of this recipe.

# > You'll find all the explanations, figures, references, and much more in the book (to be released later this summer).

# > [IPython Cookbook](http://ipython-books.github.io/), by [Cyrille Rossant](http://cyrille.rossant.net), Packt Publishing, 2014 (500 pages).
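# The brute-force `relabel` function above enumerates all 3! = 6 permutations, which does not scale to many clusters. Here is a minimal sketch of a permutation-free variant using the Hungarian algorithm from SciPy (`scipy.optimize.linear_sum_assignment`); the function name `relabel_hungarian` is our own, and the sketch assumes non-negative integer labels (so it would not handle DBSCAN's `-1` noise label as-is).

# In[ ]:

from scipy.optimize import linear_sum_assignment

def relabel_hungarian(cl, true):
    """Map predicted cluster labels onto the true classes with a
    maximum-overlap assignment (works for any number of clusters)."""
    cl, true = np.asarray(cl), np.asarray(true)
    k = int(max(cl.max(), true.max())) + 1
    # cost[i, j] = -(number of points with predicted label i and
    # true label j); minimizing this cost maximizes the overlap.
    cost = np.zeros((k, k))
    for i, j in zip(cl, true):
        cost[i, j] -= 1
    rows, cols = linear_sum_assignment(cost)
    mapping = dict(zip(rows, cols))
    return np.array([mapping[c] for c in cl])

# For example, `relabel_hungarian(km.labels_, y)` should agree with `relabel(km.labels_)` on this three-cluster dataset.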
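# In step 5 we cheated by passing `n_clusters=3`. Here is a minimal sketch of a standard heuristic for choosing the number of clusters without the true labels, using scikit-learn's `silhouette_score`; the candidate range 2-7 is an arbitrary choice for this example.

# In[ ]:

from sklearn.metrics import silhouette_score

# Fit K-means for several candidate cluster counts and keep the one
# with the highest mean silhouette score (higher means tighter,
# better-separated clusters).
scores = {k: silhouette_score(X, clu.KMeans(n_clusters=k).fit(X).labels_)
          for k in range(2, 8)}
best_k = max(scores, key=scores.get)
print(scores)
print("Best number of clusters:", best_k)

# On well-separated blobs like these, `best_k` should usually come out as 3, but the silhouette criterion is a heuristic, not a guarantee.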
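# Finally, since we happen to know the true labels `y` here, we can also score each algorithm of step 6 quantitatively. A minimal sketch using scikit-learn's `adjusted_rand_score` (1.0 means perfect agreement with `y`, values near 0 mean chance-level agreement); the index is invariant to label permutations, so no relabeling is needed.

# In[ ]:

from sklearn.metrics import adjusted_rand_score

# Score each clustering against the true labels y.
for est in [clu.KMeans(n_clusters=3),
            clu.SpectralClustering(3),
            clu.AgglomerativeClustering(3),
            clu.MeanShift(),
            clu.AffinityPropagation(),
            clu.DBSCAN()]:
    est.fit(X)
    print(est.__class__.__name__,
          adjusted_rand_score(y, est.labels_))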