Title: Basic Clustering with $k$-Means
Author: Thomas Breuel
Institution: UniKL
from pylab import *
import sys                                  # used below for progress output
import random as pyrandom
from scipy.spatial.distance import cdist
matplotlib.rc("image",cmap="gray")
from collections import Counter
Consider a collection of points that are sampled from three different densities, in this case normal densities with the same covariances but different means.
data = r_[10*randn(1000,2)+array([70,30]),
          10*randn(1000,2)+array([10,10]),
          10*randn(1000,2)+array([50,80])]
data = data[pyrandom.sample(range(len(data)),len(data))]   # shuffle the rows
Here is a scatterplot of this data.
We clearly see three clusters, corresponding to the three mixture components. How can we recover these clusters?
This is the job of clustering algorithms. One of the most useful clustering algorithms is the $k$-means algorithm.
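Formally, $k$-means seeks $k$ prototype vectors $\mu_1,\dots,\mu_k$ that minimize the total squared distance from each data point to its nearest prototype, $\sum_i \min_j \|x_i-\mu_j\|^2$. The procedure developed below is a simple iterative scheme that locally minimizes this objective.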
figure(figsize=(10,10))
plot(data[:,0],data[:,1],'b+')
Mixture densities arise in both unsupervised learning and in supervised learning. In both cases, they commonly represent a problem structure in which data is generated from a number of ideal prototypes (the cluster centers) but then corrupted by noise. In unsupervised learning, we may model the data as a density estimation problem with normal densities. If only some of the samples carry class labels, we can find clusters with a clustering algorithm and then assign labels to these clusters; this is a form of semi-supervised learning. And in supervised learning, a single class may itself be a mixture of multiple clusters; that is, each class is generated by multiple prototypes (think characters in different fonts). You can perform clustering either at the class level or across all samples and then label each cluster with its corresponding class label.
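As an illustration of the last point, here is a hypothetical sketch of labeling clusters by majority vote; it assumes an array labels of per-sample class labels (not defined in this notebook) and a cluster assignment closest like the one computed below, and it is what the Counter import above is useful for.
def label_clusters(closest,labels,k):
    # assign each cluster the most common class label among its members
    cluster_labels = {}
    for i in range(k):
        members = labels[closest==i]
        if len(members)>0:
            cluster_labels[i] = Counter(members).most_common(1)[0][0]
    return cluster_labels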
The $k$-means algorithm is an example of an expectation-maximization (EM) algorithm. Such algorithms generally have a form in which we need to find some parameters, but in order to find those parameters, we need to know some other values that we can't observe directly. In this case, we want to find the cluster centers (the protos array), but in order to compute the cluster center for each cluster, we would have to know which cluster each data point belongs to, which we don't. The EM approach is to just guess the result and then improve the guess iteratively.
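Putting the whole procedure together in advance, here is a compact sketch of the loop that the rest of this notebook develops step by step (it assumes that no cluster ever ends up empty):
def kmeans(data,k,rounds=100):
    protos = data[pyrandom.sample(range(len(data)),k)]      # initial guess: k random data points
    for _ in range(rounds):
        closest = argmin(cdist(protos,data),axis=0)         # assign each point to its nearest prototype
        for i in range(k):
            protos[i,:] = average(data[closest==i],axis=0)  # move each prototype to its cluster mean
    return protos,closest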
protos = 1.0*array([[30,30],[40,20],[0,90],[50,50]])   # float array, so cluster means aren't truncated
start = protos.copy()
figure(figsize=(10,10))
plot(data[:,0],data[:,1],'b+')
plot(protos[:,0],protos[:,1],'ro',markersize=10)
Obviously, those centers are wildly wrong, but let's keep going.
Now we compute the assignment of the data points to the prototypes (the closest array). This assignment is also wildly wrong, but we're going to be using it anyway.
figure(figsize=(10,10))
dists = cdist(protos,data)
closest = argmin(dists,axis=0)
for i in range(len(protos)):
    plot(data[closest==i,0],data[closest==i,1],['c+','g+','b+','y+'][i])
plot(protos[:,0],protos[:,1],'ro',markersize=10)
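As an aside, the same assignment can be computed without scipy using plain numpy broadcasting. This is just a sketch of an equivalent computation, and it materializes the full distance array in memory:
d2 = ((protos[:,None,:]-data[None,:,:])**2).sum(axis=2)   # squared distances, shape (k, n)
closest2 = argmin(d2,axis=0)                              # agrees with closest up to floating-point ties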
Now we pretend that the cluster assignments are correct and recompute the locations of the centers.
history = [protos.copy()]
for i in range(len(protos)):
    protos[i,:] = average(data[closest==i],axis=0)   # new center = mean of the points assigned to it
history.append(protos.copy())
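Note that average(data[closest==i],axis=0) misbehaves if cluster i ends up with no points assigned to it. A more defensive version of this update step (a sketch, not used in the rest of this notebook) keeps the old prototype in that case:
for i in range(len(protos)):
    members = data[closest==i]
    if len(members)>0:                        # guard against empty clusters
        protos[i,:] = average(members,axis=0)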
figure(figsize=(10,10))
plot(data[:,0],data[:,1],'b+')
harray = array(history)
for i in range(len(protos)):
    plot(harray[:,i,0],harray[:,i,1],'r')
plot(harray[0,:,0],harray[0,:,1],'ko',markersize=10)
plot(protos[:,0],protos[:,1],'ro',markersize=10)
As you can see, the centers have moved, and it looks like they have generally moved in the right direction.
Now let's just repeat this process multiple times.
for round in range(1000):
    if round%100==0: sys.stderr.write("%d "%round)      # progress indicator
    dists = cdist(protos,data)                          # distances from each prototype to each point
    closest = argmin(dists,axis=0)                      # assign each point to its nearest prototype
    for i in range(len(protos)):
        protos[i,:] = average(data[closest==i],axis=0)  # move each prototype to its cluster mean
    history.append(protos.copy())
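Running for a fixed 1000 rounds is wasteful; $k$-means typically converges after a handful of iterations. Here is a sketch of a variant of the loop above that stops early once the prototypes no longer move:
prev = protos.copy()
for round in range(1000):
    closest = argmin(cdist(protos,data),axis=0)
    for i in range(len(protos)):
        protos[i,:] = average(data[closest==i],axis=0)
    if allclose(protos,prev): break              # converged: the centers stopped moving
    prev = protos.copy()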
We can now plot the path that the prototype guesses have taken.
figure(figsize=(10,10))
plot(data[:,0],data[:,1],'b+')
history = array(history)
for i in range(len(protos)):
    plot(history[:,i,0],history[:,i,1],'r')
plot(history[0,:,0],history[0,:,1],'ko',markersize=10)
plot(protos[:,0],protos[:,1],'ro',markersize=10)
As you can see, the final locations of the prototype centers (red) lie nicely in the centers of the clusters. The algorithm doesn't recover the true cluster centers exactly, because there are three clusters but we postulated four prototypes.
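How should we pick the number of prototypes? A common heuristic (a sketch, not part of the walkthrough above) is to compute the total quantization error for several values of $k$ and look for a "knee" in the resulting curve:
def quantization_error(protos,data):
    dists = cdist(protos,data)
    return sum(dists.min(axis=0)**2)   # sum of squared distances to the nearest prototype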
We can also look at the partition of the data induced by these cluster centers.
figure(figsize=(10,10))
dists = cdist(protos,data)
closest = argmin(dists,axis=0)
for i in range(len(protos)):
    plot(data[closest==i,0],data[closest==i,1],['c+','g+','b+','y+'][i])
plot(protos[:,0],protos[:,1],'ro',markersize=10)
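In practice you usually don't have to implement $k$-means yourself. Assuming scikit-learn is installed, its KMeans class implements the same algorithm (with a smarter default initialization):
from sklearn.cluster import KMeans
km = KMeans(n_clusters=4).fit(data)
print(km.cluster_centers_)   # analogous to our protos
print(km.labels_[:10])       # analogous to our closest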