# K-Means Clustering¶

So far this semester we have been working with supervised and reinforcement learning algorithms. Another family of machine learning algorithms are unsupervised learning algorithms. These are algorithms designed to find patterns or groupings in a data set. No targets, or desired outputs, are involved.

## Old Faithful Dataset¶

For example, take a look at this data set of eruption durations and the waiting times in between eruptions of the Old Faithful Geyser in Yellowstone National Park.

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
!head faithful.csv

"","eruptions","waiting"
"1",3.6,79
"2",1.8,54
"3",3.333,74
"4",2.283,62
"5",4.533,85
"6",2.883,55
"7",4.7,88
"8",3.6,85
"9",1.95,51

In [3]:
datadf = pd.read_csv('faithful.csv', usecols=(1, 2))
datadf

Out[3]:
eruptions waiting
0 3.600 79
1 1.800 54
2 3.333 74
3 2.283 62
4 4.533 85
... ... ...
267 4.117 81
268 2.150 46
269 4.417 90
270 1.817 46
271 4.467 74

272 rows × 2 columns

In [4]:
data = datadf.values
data

Out[4]:
array([[ 3.6  , 79.   ],
[ 1.8  , 54.   ],
[ 3.333, 74.   ],
[ 2.283, 62.   ],
[ 4.533, 85.   ],
[ 2.883, 55.   ],
[ 4.7  , 88.   ],
[ 3.6  , 85.   ],
[ 1.95 , 51.   ],
[ 4.35 , 85.   ],
[ 1.833, 54.   ],
[ 3.917, 84.   ],
[ 4.2  , 78.   ],
[ 1.75 , 47.   ],
[ 4.7  , 83.   ],
[ 2.167, 52.   ],
[ 1.75 , 62.   ],
[ 4.8  , 84.   ],
[ 1.6  , 52.   ],
[ 4.25 , 79.   ],
[ 1.8  , 51.   ],
[ 1.75 , 47.   ],
[ 3.45 , 78.   ],
[ 3.067, 69.   ],
[ 4.533, 74.   ],
[ 3.6  , 83.   ],
[ 1.967, 55.   ],
[ 4.083, 76.   ],
[ 3.85 , 78.   ],
[ 4.433, 79.   ],
[ 4.3  , 73.   ],
[ 4.467, 77.   ],
[ 3.367, 66.   ],
[ 4.033, 80.   ],
[ 3.833, 74.   ],
[ 2.017, 52.   ],
[ 1.867, 48.   ],
[ 4.833, 80.   ],
[ 1.833, 59.   ],
[ 4.783, 90.   ],
[ 4.35 , 80.   ],
[ 1.883, 58.   ],
[ 4.567, 84.   ],
[ 1.75 , 58.   ],
[ 4.533, 73.   ],
[ 3.317, 83.   ],
[ 3.833, 64.   ],
[ 2.1  , 53.   ],
[ 4.633, 82.   ],
[ 2.   , 59.   ],
[ 4.8  , 75.   ],
[ 4.716, 90.   ],
[ 1.833, 54.   ],
[ 4.833, 80.   ],
[ 1.733, 54.   ],
[ 4.883, 83.   ],
[ 3.717, 71.   ],
[ 1.667, 64.   ],
[ 4.567, 77.   ],
[ 4.317, 81.   ],
[ 2.233, 59.   ],
[ 4.5  , 84.   ],
[ 1.75 , 48.   ],
[ 4.8  , 82.   ],
[ 1.817, 60.   ],
[ 4.4  , 92.   ],
[ 4.167, 78.   ],
[ 4.7  , 78.   ],
[ 2.067, 65.   ],
[ 4.7  , 73.   ],
[ 4.033, 82.   ],
[ 1.967, 56.   ],
[ 4.5  , 79.   ],
[ 4.   , 71.   ],
[ 1.983, 62.   ],
[ 5.067, 76.   ],
[ 2.017, 60.   ],
[ 4.567, 78.   ],
[ 3.883, 76.   ],
[ 3.6  , 83.   ],
[ 4.133, 75.   ],
[ 4.333, 82.   ],
[ 4.1  , 70.   ],
[ 2.633, 65.   ],
[ 4.067, 73.   ],
[ 4.933, 88.   ],
[ 3.95 , 76.   ],
[ 4.517, 80.   ],
[ 2.167, 48.   ],
[ 4.   , 86.   ],
[ 2.2  , 60.   ],
[ 4.333, 90.   ],
[ 1.867, 50.   ],
[ 4.817, 78.   ],
[ 1.833, 63.   ],
[ 4.3  , 72.   ],
[ 4.667, 84.   ],
[ 3.75 , 75.   ],
[ 1.867, 51.   ],
[ 4.9  , 82.   ],
[ 2.483, 62.   ],
[ 4.367, 88.   ],
[ 2.1  , 49.   ],
[ 4.5  , 83.   ],
[ 4.05 , 81.   ],
[ 1.867, 47.   ],
[ 4.7  , 84.   ],
[ 1.783, 52.   ],
[ 4.85 , 86.   ],
[ 3.683, 81.   ],
[ 4.733, 75.   ],
[ 2.3  , 59.   ],
[ 4.9  , 89.   ],
[ 4.417, 79.   ],
[ 1.7  , 59.   ],
[ 4.633, 81.   ],
[ 2.317, 50.   ],
[ 4.6  , 85.   ],
[ 1.817, 59.   ],
[ 4.417, 87.   ],
[ 2.617, 53.   ],
[ 4.067, 69.   ],
[ 4.25 , 77.   ],
[ 1.967, 56.   ],
[ 4.6  , 88.   ],
[ 3.767, 81.   ],
[ 1.917, 45.   ],
[ 4.5  , 82.   ],
[ 2.267, 55.   ],
[ 4.65 , 90.   ],
[ 1.867, 45.   ],
[ 4.167, 83.   ],
[ 2.8  , 56.   ],
[ 4.333, 89.   ],
[ 1.833, 46.   ],
[ 4.383, 82.   ],
[ 1.883, 51.   ],
[ 4.933, 86.   ],
[ 2.033, 53.   ],
[ 3.733, 79.   ],
[ 4.233, 81.   ],
[ 2.233, 60.   ],
[ 4.533, 82.   ],
[ 4.817, 77.   ],
[ 4.333, 76.   ],
[ 1.983, 59.   ],
[ 4.633, 80.   ],
[ 2.017, 49.   ],
[ 5.1  , 96.   ],
[ 1.8  , 53.   ],
[ 5.033, 77.   ],
[ 4.   , 77.   ],
[ 2.4  , 65.   ],
[ 4.6  , 81.   ],
[ 3.567, 71.   ],
[ 4.   , 70.   ],
[ 4.5  , 81.   ],
[ 4.083, 93.   ],
[ 1.8  , 53.   ],
[ 3.967, 89.   ],
[ 2.2  , 45.   ],
[ 4.15 , 86.   ],
[ 2.   , 58.   ],
[ 3.833, 78.   ],
[ 3.5  , 66.   ],
[ 4.583, 76.   ],
[ 2.367, 63.   ],
[ 5.   , 88.   ],
[ 1.933, 52.   ],
[ 4.617, 93.   ],
[ 1.917, 49.   ],
[ 2.083, 57.   ],
[ 4.583, 77.   ],
[ 3.333, 68.   ],
[ 4.167, 81.   ],
[ 4.333, 81.   ],
[ 4.5  , 73.   ],
[ 2.417, 50.   ],
[ 4.   , 85.   ],
[ 4.167, 74.   ],
[ 1.883, 55.   ],
[ 4.583, 77.   ],
[ 4.25 , 83.   ],
[ 3.767, 83.   ],
[ 2.033, 51.   ],
[ 4.433, 78.   ],
[ 4.083, 84.   ],
[ 1.833, 46.   ],
[ 4.417, 83.   ],
[ 2.183, 55.   ],
[ 4.8  , 81.   ],
[ 1.833, 57.   ],
[ 4.8  , 76.   ],
[ 4.1  , 84.   ],
[ 3.966, 77.   ],
[ 4.233, 81.   ],
[ 3.5  , 87.   ],
[ 4.366, 77.   ],
[ 2.25 , 51.   ],
[ 4.667, 78.   ],
[ 2.1  , 60.   ],
[ 4.35 , 82.   ],
[ 4.133, 91.   ],
[ 1.867, 53.   ],
[ 4.6  , 78.   ],
[ 1.783, 46.   ],
[ 4.367, 77.   ],
[ 3.85 , 84.   ],
[ 1.933, 49.   ],
[ 4.5  , 83.   ],
[ 2.383, 71.   ],
[ 4.7  , 80.   ],
[ 1.867, 49.   ],
[ 3.833, 75.   ],
[ 3.417, 64.   ],
[ 4.233, 76.   ],
[ 2.4  , 53.   ],
[ 4.8  , 94.   ],
[ 2.   , 55.   ],
[ 4.15 , 76.   ],
[ 1.867, 50.   ],
[ 4.267, 82.   ],
[ 1.75 , 54.   ],
[ 4.483, 75.   ],
[ 4.   , 78.   ],
[ 4.117, 79.   ],
[ 4.083, 78.   ],
[ 4.267, 78.   ],
[ 3.917, 70.   ],
[ 4.55 , 79.   ],
[ 4.083, 70.   ],
[ 2.417, 54.   ],
[ 4.183, 86.   ],
[ 2.217, 50.   ],
[ 4.45 , 90.   ],
[ 1.883, 54.   ],
[ 1.85 , 54.   ],
[ 4.283, 77.   ],
[ 3.95 , 79.   ],
[ 2.333, 64.   ],
[ 4.15 , 75.   ],
[ 2.35 , 47.   ],
[ 4.933, 86.   ],
[ 2.9  , 63.   ],
[ 4.583, 85.   ],
[ 3.833, 82.   ],
[ 2.083, 57.   ],
[ 4.367, 82.   ],
[ 2.133, 67.   ],
[ 4.35 , 74.   ],
[ 2.2  , 54.   ],
[ 4.45 , 83.   ],
[ 3.567, 73.   ],
[ 4.5  , 73.   ],
[ 4.15 , 88.   ],
[ 3.817, 80.   ],
[ 3.917, 71.   ],
[ 4.45 , 83.   ],
[ 2.   , 56.   ],
[ 4.283, 79.   ],
[ 4.767, 78.   ],
[ 4.533, 84.   ],
[ 1.85 , 58.   ],
[ 4.25 , 83.   ],
[ 1.983, 43.   ],
[ 2.25 , 60.   ],
[ 4.75 , 75.   ],
[ 4.117, 81.   ],
[ 2.15 , 46.   ],
[ 4.417, 90.   ],
[ 1.817, 46.   ],
[ 4.467, 74.   ]])
In [5]:
plt.plot(data[:, 0], data[:, 1], '.')
plt.xlabel('duration')
plt.ylabel('interval')

Out[5]:
Text(0, 0.5, 'interval')

We can clearly see two clusters here. For higher dimensional data, we cannot directly visualize the data to see the clusters. We need a mathematical way to detect clusters. This gives rise to the class of unsupervised learning methods called clustering algorithms.

A simple example of a clustering algorithm is the k-means algorithm. It results in identifying $k$ cluster centers. It is an iterative algorithm that starts with an initial assignment of $k$ centers. Then it proceeds by determining which centers each data sample is closest to and adjusts the centers to be the means of each of these data partitions. It then repeats.

Let's develop this algorithm one step at a time.

Each sample in the Old Faithful data has 2 attributes, so each sample is in 2-dimensional space. We know by looking at the above plot that our data nicely falls in two clusters, so we will start with $k=2$. We will initialize the two cluster centers by randomly choosing two of the data samples.

In [6]:
n_samples = data.shape[0]
np.random.choice(range(n_samples), 2, replace=False)

Out[6]:
array([226,  66])
In [7]:
centers = data[np.random.choice(range(n_samples), 2, replace=False), :]
centers

Out[7]:
array([[ 2.2, 60. ],
[ 4.7, 83. ]])

Now we must find all samples that are closest to the first center, and those that are closest to the second center.

Check out the wonders of numpy broadcasting.

In [8]:
a = np.array([1, 2, 3])
b = np.array([10, 20, 30])
a, b

Out[8]:
(array([1, 2, 3]), array([10, 20, 30]))
In [9]:
a - b

Out[9]:
array([ -9, -18, -27])

But what if we want to subtract every element of a with every element of b?

In [10]:
np.resize(a, (3, 3))

Out[10]:
array([[1, 2, 3],
[1, 2, 3],
[1, 2, 3]])
In [11]:
np.resize(b, (3, 3))

Out[11]:
array([[10, 20, 30],
[10, 20, 30],
[10, 20, 30]])
In [12]:
np.resize(a, (3, 3)).T

Out[12]:
array([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])
In [13]:
np.resize(a, (3, 3)).T - np.resize(b, (3, 3))

Out[13]:
array([[ -9, -19, -29],
[ -8, -18, -28],
[ -7, -17, -27]])

However, we can ask numpy to do this duplication for us if we reshape a to be a column vector and leave b as a row vector.

$$\begin{pmatrix} 1\\ 2\\ 3 \end{pmatrix} - \begin{pmatrix} 10 & 20 & 30 \end{pmatrix} \;\; = \;\; \begin{pmatrix} 1 & 1 & 1\\ 2 & 2 & 2\\ 3 & 3 & 3 \end{pmatrix} - \begin{pmatrix} 10 & 20 & 30\\ 10 & 20 & 30\\ 10 & 20 & 30 \end{pmatrix}$$
In [14]:
a[:, np.newaxis]

Out[14]:
array([[1],
[2],
[3]])
In [15]:
a[:, np.newaxis] - b

Out[15]:
array([[ -9, -19, -29],
[ -8, -18, -28],
[ -7, -17, -27]])

Now imagine that a is a cluster center and b contains data samples, one per row. The first step of calculating the distance from a to all samples in b is to subtract them component-wise.

In [16]:
a = np.array([1, 2, 3])
b = np.array([[10, 20, 30], [40, 50, 60]])
print(a)
print(b)

[1 2 3]
[[10 20 30]
[40 50 60]]

In [17]:
b - a

Out[17]:
array([[ 9, 18, 27],
[39, 48, 57]])

The single row vector a is duplicated for as many rows as there are in b! We can use this to calculate the squared distance between a center and every sample.

In [18]:
centers[0,:]

Out[18]:
array([ 2.2, 60. ])
In [19]:
sqdists_to_center_0 = np.sum((centers[0, :] - data)**2, axis=1)
sqdists_to_center_0

Out[19]:
array([3.62960000e+02, 3.61600000e+01, 1.97283689e+02, 4.00688900e+00,
6.30442889e+02, 2.54664890e+01, 7.90250000e+02, 6.26960000e+02,
8.10625000e+01, 6.29622500e+02, 3.61346890e+01, 5.78948089e+02,
3.28000000e+02, 1.69202500e+02, 5.35250000e+02, 6.40010890e+01,
4.20250000e+00, 5.82760000e+02, 6.43600000e+01, 3.65202500e+02,
8.11600000e+01, 1.69202500e+02, 3.25562500e+02, 8.17516890e+01,
2.01442889e+02, 5.30960000e+02, 2.50542890e+01, 2.59545689e+02,
3.26722500e+02, 3.65986289e+02, 1.73410000e+02, 2.94139289e+02,
3.73618890e+01, 4.03359889e+02, 1.98666689e+02, 6.40334890e+01,
1.44110889e+02, 4.06932689e+02, 1.13468900e+00, 9.06671889e+02,
4.04622500e+02, 4.10048900e+00, 5.81602689e+02, 4.20250000e+00,
1.74442889e+02, 5.30247689e+02, 1.86666890e+01, 4.90100000e+01,
4.89919489e+02, 1.04000000e+00, 2.31760000e+02, 9.06330256e+02,
3.61346890e+01, 4.06932689e+02, 3.62180890e+01, 5.36198489e+02,
1.23301289e+02, 1.62840890e+01, 2.94602689e+02, 4.45481689e+02,
1.00108900e+00, 5.81290000e+02, 1.44202500e+02, 4.90760000e+02,
1.46689000e-01, 1.02884000e+03, 3.27869089e+02, 3.30250000e+02,
2.50176890e+01, 1.75250000e+02, 4.87359889e+02, 1.60542890e+01,
3.66290000e+02, 1.24240000e+02, 4.04708900e+00, 2.64219689e+02,
3.34890000e-02, 3.29602689e+02, 2.58832489e+02, 5.30960000e+02,
2.28736489e+02, 4.88549689e+02, 1.03610000e+02, 2.51874890e+01,
1.72485689e+02, 7.91469289e+02, 2.59062500e+02, 4.05368489e+02,
1.44001089e+02, 6.79240000e+02, 0.00000000e+00, 9.04549689e+02,
1.00110889e+02, 3.30848689e+02, 9.13468900e+00, 1.48410000e+02,
5.82086089e+02, 2.27402500e+02, 8.11108890e+01, 4.91290000e+02,
4.08008900e+00, 7.88695889e+02, 1.21010000e+02, 5.34290000e+02,
4.44422500e+02, 1.69110889e+02, 5.82250000e+02, 6.41738890e+01,
6.83022500e+02, 4.43199289e+02, 2.31416089e+02, 1.01000000e+00,
8.48290000e+02, 3.65915089e+02, 1.25000000e+00, 4.46919489e+02,
1.00013689e+02, 6.30760000e+02, 1.14668900e+00, 7.33915089e+02,
4.91738890e+01, 8.44856890e+01, 2.93202500e+02, 1.60542890e+01,
7.89760000e+02, 4.43455489e+02, 2.25080089e+02, 4.89290000e+02,
2.50044890e+01, 9.06002500e+02, 2.25110889e+02, 5.32869089e+02,
1.63600000e+01, 8.45549689e+02, 1.96134689e+02, 4.88765489e+02,
8.11004890e+01, 6.83469289e+02, 4.90278890e+01, 3.63350089e+02,
4.45133089e+02, 1.08900000e-03, 4.89442889e+02, 2.95848689e+02,
2.60549689e+02, 1.04708900e+00, 4.05919489e+02, 1.21033489e+02,
1.30441000e+03, 4.91600000e+01, 2.97025889e+02, 2.92240000e+02,
2.50400000e+01, 4.46760000e+02, 1.22868689e+02, 1.03240000e+02,
4.46290000e+02, 1.09254569e+03, 4.91600000e+01, 8.44122289e+02,
2.25000000e+02, 6.79802500e+02, 4.04000000e+00, 3.26666689e+02,
3.76900000e+01, 2.61678689e+02, 9.02788900e+00, 7.91840000e+02,
6.40712890e+01, 1.09484189e+03, 1.21080089e+02, 9.01368900e+00,
2.94678689e+02, 6.52836890e+01, 4.44869089e+02, 4.45549689e+02,
1.74290000e+02, 1.00047089e+02, 6.28240000e+02, 1.99869089e+02,
2.51004890e+01, 2.94678689e+02, 5.33202500e+02, 5.31455489e+02,
8.10278890e+01, 3.28986289e+02, 5.79545689e+02, 1.96134689e+02,
5.33915089e+02, 2.50002890e+01, 4.47760000e+02, 9.13468900e+00,
2.62760000e+02, 5.79610000e+02, 2.92118756e+02, 4.45133089e+02,
7.30690000e+02, 2.93691556e+02, 8.10025000e+01, 3.30086089e+02,
1.00000000e-02, 4.88622500e+02, 9.64736489e+02, 4.91108890e+01,
3.29760000e+02, 1.96173889e+02, 2.93695889e+02, 5.78722500e+02,
1.21071289e+02, 5.34290000e+02, 1.21033489e+02, 4.06250000e+02,
1.21110889e+02, 2.27666689e+02, 1.74810890e+01, 2.60133089e+02,
4.90400000e+01, 1.16276000e+03, 2.50400000e+01, 2.59802500e+02,
1.00110889e+02, 4.88272489e+02, 3.62025000e+01, 2.30212089e+02,
3.27240000e+02, 3.64674889e+02, 3.27545689e+02, 3.28272489e+02,
1.02948089e+02, 3.66522500e+02, 1.03545689e+02, 3.60470890e+01,
6.79932289e+02, 1.00000289e+02, 9.05062500e+02, 3.61004890e+01,
3.61225000e+01, 2.93338889e+02, 3.64062500e+02, 1.60176890e+01,
2.28802500e+02, 1.69022500e+02, 6.83469289e+02, 9.49000000e+00,
6.30678689e+02, 4.86666689e+02, 9.01368900e+00, 4.88695889e+02,
4.90044890e+01, 2.00622500e+02, 3.60000000e+01, 5.34062500e+02,
1.70868689e+02, 1.74290000e+02, 7.87802500e+02, 4.02614689e+02,
1.23948089e+02, 5.34062500e+02, 1.60400000e+01, 3.65338889e+02,
3.30589489e+02, 5.81442889e+02, 4.12250000e+00, 5.33202500e+02,
2.89047089e+02, 2.50000000e-03, 2.31502500e+02, 4.44674889e+02,
1.96002500e+02, 9.04915089e+02, 1.96146689e+02, 2.01139289e+02])
In [20]:
sqdists_to_center_1 = np.sum((centers[1, :] - data)**2, axis=1)
sqdists_to_center_1

Out[20]:
array([1.72100000e+01, 8.49410000e+02, 8.28686890e+01, 4.46841889e+02,
4.02788900e+00, 7.87301489e+02, 2.50000000e+01, 5.21000000e+00,
1.03156250e+03, 4.12250000e+00, 8.49219689e+02, 1.61308900e+00,
2.52500000e+01, 1.30470250e+03, 0.00000000e+00, 9.67416089e+02,
4.49702500e+02, 1.01000000e+00, 9.70610000e+02, 1.62025000e+01,
1.03241000e+03, 1.30470250e+03, 2.65625000e+01, 1.98666689e+02,
8.10278890e+01, 1.21000000e+00, 7.91469289e+02, 4.93806890e+01,
2.57225000e+01, 1.60712890e+01, 1.00160000e+02, 3.60542890e+01,
2.90776889e+02, 9.44488900e+00, 8.17516890e+01, 9.68198489e+02,
1.23302589e+03, 9.01768900e+00, 5.84219689e+02, 4.90068890e+01,
9.12250000e+00, 6.32935489e+02, 1.01768900e+00, 6.33702500e+02,
1.00027889e+02, 1.91268900e+00, 3.61751689e+02, 9.06760000e+02,
1.00448900e+00, 5.83290000e+02, 6.40100000e+01, 4.90002560e+01,
8.49219689e+02, 9.01768900e+00, 8.49803089e+02, 3.34890000e-02,
1.44966289e+02, 3.70199089e+02, 3.60176890e+01, 4.14668900e+00,
5.82086089e+02, 1.04000000e+00, 1.23370250e+03, 1.01000000e+00,
5.37311689e+02, 8.10900000e+01, 2.52840890e+01, 2.50000000e+01,
3.30932689e+02, 1.00000000e+02, 1.44488900e+00, 7.36469289e+02,
1.60400000e+01, 1.44490000e+02, 4.48382089e+02, 4.91346890e+01,
5.36198489e+02, 2.50176890e+01, 4.96674890e+01, 1.21000000e+00,
6.43214890e+01, 1.13468900e+00, 1.69360000e+02, 3.28272489e+02,
1.00400689e+02, 2.50542890e+01, 4.95625000e+01, 9.03348900e+00,
1.23141609e+03, 9.49000000e+00, 5.35250000e+02, 4.91346890e+01,
1.09702589e+03, 2.50136890e+01, 4.08219689e+02, 1.21160000e+02,
1.00108900e+00, 6.49025000e+01, 1.03202589e+03, 1.04000000e+00,
4.45915089e+02, 2.51108890e+01, 1.16276000e+03, 4.00000000e-02,
4.42250000e+00, 1.30402589e+03, 1.00000000e+00, 9.69508889e+02,
9.02250000e+00, 5.03428900e+00, 6.40010890e+01, 5.81760000e+02,
3.60400000e+01, 1.60800890e+01, 5.85000000e+02, 4.00448900e+00,
1.09467869e+03, 4.01000000e+00, 5.84311689e+02, 1.60800890e+01,
9.04338889e+02, 1.96400689e+02, 3.62025000e+01, 7.36469289e+02,
2.50100000e+01, 4.87048900e+00, 1.45174509e+03, 1.04000000e+00,
7.89919489e+02, 4.90025000e+01, 1.45202589e+03, 2.84089000e-01,
7.32610000e+02, 3.61346890e+01, 1.37721969e+03, 1.10048900e+00,
1.03193549e+03, 9.05428900e+00, 9.07112889e+02, 1.69350890e+01,
4.21808900e+00, 5.35086089e+02, 1.02788900e+00, 3.60136890e+01,
4.91346890e+01, 5.83382089e+02, 9.00448900e+00, 1.16319849e+03,
1.69160000e+02, 9.08410000e+02, 3.61108890e+01, 3.64900000e+01,
3.29290000e+02, 4.01000000e+00, 1.45283689e+02, 1.69490000e+02,
4.04000000e+00, 1.00380689e+02, 9.08410000e+02, 3.65372890e+01,
1.45025000e+03, 9.30250000e+00, 6.32290000e+02, 2.57516890e+01,
2.90440000e+02, 4.90136890e+01, 4.05442889e+02, 2.50900000e+01,
9.68656289e+02, 1.00006889e+02, 1.16374509e+03, 6.82848689e+02,
3.60136890e+01, 2.26868689e+02, 4.28408900e+00, 4.13468900e+00,
1.00040000e+02, 1.09421209e+03, 4.49000000e+00, 8.12840890e+01,
7.91935489e+02, 3.60136890e+01, 2.02500000e-01, 8.70489000e-01,
1.03111289e+03, 2.50712890e+01, 1.38068900e+00, 1.37721969e+03,
8.00890000e-02, 7.90335289e+02, 4.01000000e+00, 6.84219689e+02,
4.90100000e+01, 1.36000000e+00, 3.65387560e+01, 4.21808900e+00,
1.74400000e+01, 3.61115560e+01, 1.03000250e+03, 2.50010890e+01,
5.35760000e+02, 1.12250000e+00, 6.43214890e+01, 9.08025889e+02,
2.50100000e+01, 1.37750889e+03, 3.61108890e+01, 1.72250000e+00,
1.16365629e+03, 4.00000000e-02, 1.49368489e+02, 9.00000000e+00,
1.16402589e+03, 6.47516890e+01, 3.62646089e+02, 4.92180890e+01,
9.05290000e+02, 1.21010000e+02, 7.91290000e+02, 4.93025000e+01,
1.09702589e+03, 1.18748900e+00, 8.49702500e+02, 6.40470890e+01,
2.54900000e+01, 1.63398890e+01, 2.53806890e+01, 2.51874890e+01,
1.69613089e+02, 1.60225000e+01, 1.69380689e+02, 8.46212089e+02,
9.26728900e+00, 1.09516529e+03, 4.90625000e+01, 8.48935489e+02,
8.49122500e+02, 3.61738890e+01, 1.65625000e+01, 3.66602689e+02,
6.43025000e+01, 1.30152250e+03, 9.05428900e+00, 4.03240000e+02,
4.01368900e+00, 1.75168900e+00, 6.82848689e+02, 1.11088900e+00,
2.62589489e+02, 8.11225000e+01, 8.47250000e+02, 6.25000000e-02,
1.01283689e+02, 1.00040000e+02, 2.53025000e+01, 9.77968900e+00,
1.44613089e+02, 6.25000000e-02, 7.36290000e+02, 1.61738890e+01,
2.50044890e+01, 1.02788900e+00, 6.33122500e+02, 2.02500000e-01,
1.60738209e+03, 5.35002500e+02, 6.40025000e+01, 4.33988900e+00,
1.37550250e+03, 4.90800890e+01, 1.37731169e+03, 8.10542890e+01])

And, which samples are closest to the first center?

In [21]:
sqdists_to_center_0 < sqdists_to_center_1

Out[21]:
array([False,  True, False,  True, False,  True, False, False,  True,
False,  True, False, False,  True, False,  True,  True, False,
True, False,  True,  True, False,  True, False, False,  True,
False, False, False, False, False,  True, False, False,  True,
True, False,  True, False, False,  True, False,  True, False,
False,  True,  True, False,  True, False, False,  True, False,
True, False,  True,  True, False, False,  True, False,  True,
False,  True, False, False, False,  True, False, False,  True,
False,  True,  True, False,  True, False, False, False, False,
False,  True,  True, False, False, False, False,  True, False,
True, False,  True, False,  True, False, False, False,  True,
False,  True, False,  True, False, False,  True, False,  True,
False, False, False,  True, False, False,  True, False,  True,
False,  True, False,  True,  True, False,  True, False, False,
True, False,  True, False,  True, False,  True, False,  True,
False,  True, False,  True, False, False,  True, False, False,
False,  True, False,  True, False,  True, False, False,  True,
False,  True,  True, False, False,  True, False,  True, False,
True, False,  True, False,  True, False,  True, False,  True,
True, False,  True, False, False, False,  True, False, False,
True, False, False, False,  True, False, False,  True, False,
True, False,  True, False, False, False, False, False, False,
True, False,  True, False, False,  True, False,  True, False,
False,  True, False,  True, False,  True, False,  True, False,
True, False,  True, False,  True, False,  True, False, False,
False, False, False,  True, False,  True,  True, False,  True,
False,  True,  True, False, False,  True, False,  True, False,
True, False, False,  True, False,  True, False,  True, False,
False, False, False, False,  True, False,  True, False, False,
False,  True, False,  True,  True, False, False,  True, False,
True, False])

This approach is easy for $k=2$, but what if $k$ is larger? Can we calculate all of the needed distances in one numpy expression? I bet we can!

In [22]:
centers[:,np.newaxis,:].shape, data.shape

Out[22]:
((2, 1, 2), (272, 2))
In [23]:
(centers[:,np.newaxis,:] - data).shape

Out[23]:
(2, 272, 2)
In [24]:
np.sum((centers[:,np.newaxis,:] - data)**2, axis=2).shape

Out[24]:
(2, 272)

These are the square distances between each of our two centers and each of the 272 samples. If we take the argmin across the two rows, we will have the index of the closest center for each of the 272 samples.

In [25]:
clusters = np.argmin(np.sum((centers[:,np.newaxis,:] - data)**2, axis=2), axis=0)
clusters

Out[25]:
array([1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0,
1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0,
1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1,
1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1,
0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1,
1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1,
0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1,
0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1,
1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1,
0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0,
1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1,
0, 0, 1, 1, 0, 1, 0, 1])

Now, to calculate the new values of our two centers, we just calculate the mean of the appropriate samples.

In [26]:
data[clusters == 0, :].mean(axis=0)

Out[26]:
array([ 2.26414286, 56.39285714])
In [27]:
data[clusters == 1, :].mean(axis=0)

Out[27]:
array([ 4.34433125, 81.05      ])

We can do both in a for loop.

In [28]:
k = 2
# Move every center to the mean of the samples currently assigned to it.
for cluster_id in range(k):
    members = clusters == cluster_id
    centers[cluster_id, :] = data[members, :].mean(axis=0)

In [29]:
centers

Out[29]:
array([[ 2.26414286, 56.39285714],
[ 4.34433125, 81.05      ]])

Now, we can wrap these steps in our first version of a kmeans function.

In [30]:
def kmeans(data, k = 2, n_iterations = 5):
    """Cluster `data` with the k-means algorithm.

    Parameters
    ----------
    data : ndarray of shape (n_samples, n_features)
        The samples to cluster, one per row.
    k : int, optional
        Number of cluster centers (default 2).
    n_iterations : int, optional
        Number of assignment/update iterations to run (default 5).

    Returns
    -------
    centers : ndarray of shape (k, n_features)
        The final cluster centers.
    """
    # Initial centers: k distinct, randomly chosen samples.  Fancy indexing
    # returns a copy, so updating `centers` never modifies `data`.
    centers = data[np.random.choice(range(data.shape[0]), k, replace=False), :]

    # Repeat n times
    for iteration in range(n_iterations):

        # Which center is each sample closest to?  Broadcasting gives the
        # (k, n_samples) array of squared distances; argmin over axis 0
        # picks the closest center index for every sample.
        closest = np.argmin(np.sum((centers[:, np.newaxis, :] - data)**2, axis=2), axis=0)

        # Update cluster centers to the mean of their assigned samples.
        for i in range(k):
            members = closest == i
            # Guard against an empty cluster: mean of an empty slice is NaN
            # (with a RuntimeWarning) and would poison the center forever.
            # Keep the previous center instead.
            if members.any():
                centers[i, :] = data[members, :].mean(axis=0)

    return centers

In [31]:
kmeans(data, 2)

Out[31]:
array([[ 2.09433   , 54.75      ],
[ 4.29793023, 80.28488372]])
In [32]:
kmeans(data, 2)

Out[32]:
array([[ 4.29793023, 80.28488372],
[ 2.09433   , 54.75      ]])

We need a measure of the quality of our clustering. For this, we define $J$, which is a performance measure being minimized by k-means. It is defined as $$J = \sum_{n=1}^N \sum_{k=1}^K r_{nk} ||\mathbf{x}_n - \mathbf{\mu}_k||^2$$ where $N$ is the number of samples, $K$ is the number of cluster centers, $\mathbf{x}_n$ is the $n^{th}$ sample and $\mathbf{\mu}_k$ is the $k^{th}$ center, each being an element of $\mathbf{R}^p$ where $p$ is the dimensionality of the data. $r_{nk}$ is 1 if $\mathbf{x}_n$ is closest to center $\mathbf{\mu}_k$, and 0 otherwise.

The sums can be computed using python for loops, but, as you know, for loops are much slower than matrix operations in python, so let's do the matrix magic. We already know how to calculate the difference between all samples and all centers.

In [33]:
sqdists = np.sum((centers[:,np.newaxis,:] - data)**2, axis=2)
sqdists.shape

Out[33]:
(2, 272)

The calculation of $J$ requires us to multiply the squared difference of each component by $r_{nk}$. Since we already have all of the squared distances, let's just sum up the minimum distances for each sample.

In [34]:
np.min(sqdists, axis=0)

Out[34]:
array([4.75652901e+00, 5.94119390e+00, 5.07252909e+01, 3.14404066e+01,
1.56380959e+01, 2.32303518e+00, 4.84290003e+01, 1.61565290e+01,
2.91815939e+01, 1.56025321e+01, 5.91164947e+00, 8.88511200e+00,
9.32333151e+00, 8.84901082e+01, 3.92900026e+00, 1.93066306e+01,
3.17043939e+01, 8.91013401e+00, 1.97382796e+01, 4.21139838e+00,
2.92983368e+01, 8.84901082e+01, 1.01023284e+01, 1.46834075e+02,
4.97380959e+01, 4.35652901e+00, 2.02834490e+00, 2.55707940e+01,
9.54686338e+00, 4.21036215e+00, 6.48044653e+01, 1.64175476e+01,
9.35134878e+01, 1.19942715e+00, 4.99639596e+01, 1.93582735e+01,
7.05977735e+01, 1.34129715e+00, 6.98307804e+00, 8.02949303e+01,
1.10253213e+00, 2.72817804e+00, 8.75208137e+00, 2.84725104e+00,
6.48380959e+01, 4.85790950e+00, 6.03299352e+01, 1.15384225e+01,
9.85829647e-01, 6.86696533e+00, 3.68101340e+01, 8.02406377e+01,
5.91164947e+00, 1.34129715e+00, 6.00787804e+00, 4.09266402e+00,
1.01396044e+02, 5.82252020e+01, 1.64520814e+01, 3.24699723e-03,
6.79816376e+00, 8.72673276e+00, 7.07043939e+01, 1.11013401e+00,
1.32114163e+01, 1.19905599e+02, 9.33394637e+00, 9.42900026e+00,
7.41217735e+01, 6.49290003e+01, 9.99427147e-01, 2.42630612e-01,
4.22673276e+00, 1.01121064e+02, 3.15190923e+01, 2.60247501e+01,
1.30725592e+01, 9.35208137e+00, 2.57153265e+01, 4.35652901e+00,
3.66471609e+01, 9.02628397e-01, 1.22162198e+02, 7.42189638e+01,
6.48794126e+01, 4.86490309e+01, 2.56579971e+01, 1.13231450e+00,
7.04494878e+01, 2.46210640e+01, 1.30155939e+01, 8.01026284e+01,
4.10263449e+01, 9.52591575e+00, 4.38402209e+01, 8.19044653e+01,
8.80661512e+00, 3.69557296e+01, 2.92406306e+01, 1.21126776e+00,
3.14879495e+01, 4.83030139e+01, 5.46812796e+01, 3.82673276e+00,
8.91308847e-02, 8.83834878e+01, 8.82900026e+00, 1.95286923e+01,
2.47582009e+01, 4.39859022e-01, 3.67535634e+01, 6.79847961e+00,
6.35112678e+01, 4.20778075e+00, 7.11545104e+00, 8.58296472e-02,
4.08714163e+01, 1.56678665e+01, 6.99713061e+00, 3.54077807e+01,
1.16359878e+01, 1.45279413e+02, 1.64113984e+01, 2.42630612e-01,
4.83678665e+01, 3.35811372e-01, 1.29917702e+02, 9.26732760e-01,
1.94005918e+00, 8.01959334e+01, 1.29954916e+02, 3.83394637e+00,
4.41479612e-01, 6.32026284e+01, 1.08197364e+02, 9.03995272e-01,
2.92281780e+01, 2.48490309e+01, 1.15649066e+01, 4.57622590e+00,
1.48946472e-02, 1.30124495e+01, 9.38095897e-01, 1.66259157e+01,
2.55026284e+01, 6.87623518e+00, 1.18582965e+00, 5.47154163e+01,
2.24073535e+02, 1.17269082e+01, 1.68767646e+01, 1.65210640e+01,
7.41013653e+01, 6.78665097e-02, 1.01606744e+02, 1.22221064e+02,
2.67327597e-02, 1.42870794e+02, 1.17269082e+01, 6.33448789e+01,
1.29801308e+02, 2.45402646e+01, 2.65267961e+00, 9.56395965e+00,
9.38245368e+01, 2.55594628e+01, 4.36649163e+01, 4.87324015e+01,
1.94068495e+01, 1.42876848e+02, 5.47748449e+01, 4.01435184e-01,
1.64594628e+01, 1.35868221e+02, 3.39463722e-02, 2.62839723e-03,
6.48267328e+01, 4.08919878e+01, 1.57210640e+01, 4.97339464e+01,
2.08532090e+00, 1.64594628e+01, 3.81139838e+00, 4.13581137e+00,
2.91363352e+01, 9.31036215e+00, 8.77079402e+00, 1.08197364e+02,
3.80778075e+00, 1.94663518e+00, 2.10134010e-01, 5.54506612e-01,
2.57101340e+01, 8.76219776e+00, 1.65456345e+01, 1.48946472e-02,
3.61153953e+01, 1.64029695e+01, 2.90831082e+01, 9.40661512e+00,
1.30384225e+01, 9.02532135e-01, 9.90471609e+01, 1.16692020e+01,
9.36786651e+00, 1.08242978e+02, 1.64030139e+01, 8.94686338e+00,
5.47639923e+01, 3.82673276e+00, 1.04849320e+02, 1.22900026e+00,
5.48120592e+01, 3.68639596e+01, 5.91977020e+01, 2.55148946e+01,
1.15299368e+01, 1.67910134e+02, 2.00982247e+00, 2.55402646e+01,
4.10263449e+01, 9.08480122e-01, 5.99010818e+00, 3.66217290e+01,
9.42106401e+00, 4.25417950e+00, 9.37079402e+00, 9.30848012e+00,
1.22285112e+02, 4.24479963e+00, 1.22170794e+02, 5.74913061e+00,
2.45285278e+01, 4.08708449e+01, 8.01136659e+01, 5.87103518e+00,
5.89727961e+00, 1.64062615e+01, 4.35799713e+00, 5.78733638e+01,
3.66402646e+01, 8.82331368e+01, 2.48490309e+01, 4.40586510e+01,
1.56594628e+01, 1.16395965e+00, 4.01435184e-01, 9.03013872e-01,
1.12528678e+02, 4.97025321e+01, 5.72987961e+00, 3.81366588e+00,
6.54067439e+01, 6.48267328e+01, 4.83402646e+01, 1.38057825e+00,
1.01185112e+02, 3.81366588e+00, 2.24108184e-01, 4.20626152e+00,
9.48114887e+00, 8.73809590e+00, 2.75442247e+00, 3.81139838e+00,
1.79447664e+02, 1.30116796e+01, 3.67670671e+01, 5.41794972e-02,
1.08024508e+02, 8.01077807e+01, 1.08211416e+02, 4.97175476e+01])
In [35]:
np.sum(np.min(sqdists, axis=0))

Out[35]:
9240.152878462477

Let's define a function named calcJ to do this calculation.

In [36]:
def calcJ(data, centers):
    """Return the k-means performance measure J.

    J is the sum, over all samples, of the squared distance from each
    sample to its nearest center.  Smaller values indicate tighter clusters.
    """
    diffs = centers[:, np.newaxis, :] - data     # (k, n_samples, n_features)
    sq = (diffs ** 2).sum(axis=2)                # (k, n_samples) squared distances
    return sq.min(axis=0).sum()

In [37]:
calcJ(data, centers)

Out[37]:
9240.152878462477

Now we can add this calculation to track the value of $J$ for each iteration as a kind of learning curve. $J$ measures the total "spread" within the clusters, so the smaller it is, the better.

In [38]:
def kmeans(data, k = 2, n_iterations = 5):
    """Cluster `data` with the k-means algorithm, tracking the J measure.

    Parameters
    ----------
    data : ndarray of shape (n_samples, n_features)
        The samples to cluster, one per row.
    k : int, optional
        Number of cluster centers (default 2).
    n_iterations : int, optional
        Number of assignment/update iterations to run (default 5).

    Returns
    -------
    centers : ndarray of shape (k, n_features)
        The final cluster centers.
    J : list of float
        The performance measure before each update, plus one final value
        after the last update (length n_iterations + 1).
    closest : ndarray of shape (n_samples,)
        Index of the closest center for each sample, from the last iteration.
    """
    # Initialize centers (a copy of k distinct random samples) and list J
    # to track the performance metric.
    centers = data[np.random.choice(range(data.shape[0]), k, replace=False), :]
    J = []

    for iteration in range(n_iterations):

        # Which center is each sample closest to?
        sqdistances = np.sum((centers[:, np.newaxis, :] - data)**2, axis=2)
        closest = np.argmin(sqdistances, axis=0)

        # J is the sum of each sample's squared distance to its nearest
        # center.  Reuse sqdistances rather than recomputing them all,
        # as a separate calcJ call would.
        J.append(np.sum(np.min(sqdistances, axis=0)))

        # Update cluster centers to the mean of their assigned samples.
        for i in range(k):
            members = closest == i
            # Guard against an empty cluster: mean of an empty slice is NaN
            # (with a RuntimeWarning).  Keep the previous center instead.
            if members.any():
                centers[i, :] = data[members, :].mean(axis=0)

    # Calculate J one final time, after the last center update.
    sqdistances = np.sum((centers[:, np.newaxis, :] - data)**2, axis=2)
    J.append(np.sum(np.min(sqdistances, axis=0)))
    return centers, J, closest

In [39]:
centers, J, closest = kmeans(data, 2)

In [40]:
J

Out[40]:
[53997.652259999995,
12341.837671828205,
9134.49185083075,
8904.34103114802,
8901.76872094721,
8901.76872094721]
In [41]:
plt.plot(J);

In [42]:
centers, J, closest = kmeans(data, 2, 10)
plt.plot(J);

In [43]:
centers

Out[43]:
array([[ 2.09433   , 54.75      ],
[ 4.29793023, 80.28488372]])
In [44]:
closest

Out[44]:
array([1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0,
1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0,
1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1,
1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,
0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1,
1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1,
0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1,
1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1,
1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1,
0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1,
0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0,
1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1,
0, 0, 1, 1, 0, 1, 0, 1])
In [45]:
centers, J, closest = kmeans(data, 2, 2)

plt.figure(figsize=(15, 8))

plt.subplot(1, 2, 1)
plt.scatter(data[:, 0], data[:, 1], s=80, c=closest, alpha=0.5)
plt.scatter(centers[:, 0], centers[:, 1], s=80, c="green")

plt.subplot(1, 2, 2)
plt.plot(J)

centers

Out[45]:
array([[ 2.09393939, 54.62626263],
[ 4.28541618, 80.20809249]])
In [46]:
centers, J, closest = kmeans(data, 2, 10)

plt.figure(figsize=(15, 8))

plt.subplot(1, 2, 1)
plt.scatter(data[:, 0], data[:, 1], s=80, c=closest, alpha=0.5)
plt.scatter(centers[:, 0], centers[:, 1], s=80, c="green")

plt.subplot(1, 2, 2)
plt.plot(J)

centers

Out[46]:
array([[ 2.09433   , 54.75      ],
[ 4.29793023, 80.28488372]])
In [47]:
centers, J, closest = kmeans(data, 3, 10)

plt.figure(figsize=(15, 8))

plt.subplot(1, 2, 1)
plt.scatter(data[:, 0], data[:, 1], s=80, c=closest, alpha=0.5)
plt.scatter(centers[:, 0], centers[:, 1], s=80, c="green")

plt.subplot(1, 2, 2)
plt.plot(J)

centers

Out[47]:
array([[ 2.00784848, 51.25757576],
[ 2.6673617 , 63.93617021],
[ 4.34461006, 81.10691824]])
In [48]:
centers, J, closest = kmeans(data, 4, 10)

plt.figure(figsize=(15, 8))

plt.subplot(1, 2, 1)
plt.scatter(data[:, 0], data[:, 1], s=80, c=closest, alpha=0.5)
plt.scatter(centers[:, 0], centers[:, 1], s=80, c="green")

plt.subplot(1, 2, 2)
plt.plot(J)

centers

Out[48]:
array([[ 4.3690119 , 84.91666667],
[ 4.2403908 , 75.95402299],
[ 2.26965789, 61.34210526],
[ 2.0082381 , 50.98412698]])
In [49]:
centers, J, closest = kmeans(data, 6, 20)

plt.figure(figsize=(15, 8))

plt.subplot(1, 2, 1)
plt.scatter(data[:, 0], data[:, 1], s=80, c=closest, alpha=0.5)
plt.scatter(centers[:, 0], centers[:, 1], s=80, c="green")

plt.subplot(1, 2, 2)
plt.plot(J)

centers

Out[49]:
array([[ 3.00344828, 66.48275862],
[ 2.02545098, 55.82352941],
[ 1.9745625 , 48.15625   ],
[ 4.29970588, 76.39705882],
[ 4.49343478, 89.91304348],
[ 4.3386087 , 82.68115942]])

## MNIST Dataset¶

So, clustering two-dimensional data is not all that exciting. How about 784-dimensional data, such as our good buddy the MNIST data set?

In [50]:
import gzip
import pickle

# Load the three pickled MNIST splits.  encoding='latin1' is required
# here (presumably because the pickle was produced under Python 2 --
# confirm against the data file's provenance).
with gzip.open('mnist.pkl.gz', 'rb') as f:
    train_set, valid_set, test_set = pickle.load(f, encoding='latin1')

# Each split is a (inputs, labels) pair; reshape the label vectors into
# column matrices so they line up with the classifier code below.
Xtrain = train_set[0]
Ttrain = train_set[1].reshape((-1,1))

Xtest = test_set[0]
Ttest = test_set[1].reshape((-1,1))

Xtrain.shape, Ttrain.shape, Xtest.shape, Ttest.shape

Out[50]:
((50000, 784), (50000, 1), (10000, 784), (10000, 1))

How many clusters shall we use?

In [51]:
centers, J, closest = kmeans(Xtrain, k=10, n_iterations=10)

In [52]:
plt.plot(J);

In [53]:
centers.shape

Out[53]:
(10, 784)
In [54]:
for i in range(10):
plt.subplot(2, 5, i + 1)
plt.imshow(-centers[i, :].reshape((28, 28)), cmap='gray')
plt.axis('off')

In [55]:
centers, J, closest = kmeans(Xtrain, k=10, n_iterations=20)
plt.plot(J)
plt.figure()
for i in range(10):
plt.subplot(2, 5, i + 1)
plt.imshow(-centers[i, :].reshape((28, 28)), cmap='gray')
plt.axis('off')

In [56]:
centers, J, closest = kmeans(Xtrain, k=20, n_iterations=20)
plt.plot(J)
plt.figure()
for i in range(20):
plt.subplot(4, 5, i + 1)
plt.imshow(-centers[i, :].reshape((28, 28)), interpolation='nearest', cmap='gray')
plt.axis('off')

In [57]:
centers, J, closest = kmeans(Xtrain, k=20, n_iterations=20)
plt.plot(J)
plt.figure()
for i in range(20):
plt.subplot(4, 5, i + 1)
plt.imshow(-centers[i, :].reshape((28, 28)), interpolation='nearest', cmap='gray')
plt.axis('off')

In [58]:
centers, J, closest = kmeans(Xtrain, k=40, n_iterations=20)
plt.plot(J)
plt.figure()
for i in range(40):
plt.subplot(4, 10, i + 1)
plt.imshow(-centers[i, :].reshape((28, 28)), interpolation='nearest', cmap='gray')
plt.axis('off')


How could you use the results of the kmeans clustering algorithm as the first step in a classification algorithm?

# K-Nearest-Neighbor Classification¶

Now that we have some experience in calculating distances between samples, we are a short step away from an implementation of a common classification algorithm called k-nearest-neighbor. This is a non-parametric algorithm, meaning that it does not involve parameters, like weights, to make its decisions. Instead, we could call it a memory-based method. The algorithm classifies a sample by determining the $k$ closest samples in the training set and returns the most common class label among those $k$ nearest samples.

Training is terribly simple. We just have to store the training samples. Classification is also trivial to code. We just calculate squared distances between training samples and the samples being classified and return the most common class label among the $k$ closest training samples.

Let's create a class named KNN to implement this algorithm.

First, let's practice our numpy-foo to see how to pick the most common class, with a minimum amount of code.

Remember that sqdists from above was n_centers x n_samples; here we will compute a new sqdists that is n_new_samples x n_train_samples.

Let's pretend we wish to classify the first three MNIST test samples.

In [59]:
sqdists = np.sum((Xtest[:3,np.newaxis,:] - Xtrain)**2, axis=2)
sqdists.shape

Out[59]:
(3, 50000)

Okay. Now all we have to do is find the $k$ smallest distances in each row. Let's use $k=5$.

In [60]:
k = 5
np.sort(sqdists[0, :])[:k]

Out[60]:
array([ 9.6193695, 11.355759 , 11.403915 , 12.214478 , 12.627594 ],
dtype=float32)

But, we need the indices of these values so we can look up their class labels in Ttrain.

In [61]:
k = 5
np.argsort(sqdists[0, :])[:k]

Out[61]:
array([38620, 16186, 27059, 47003, 14563])

Now we have to do this for each row in sqdists. Or do we? Wouldn't it be nice if np.argsort sorts each row independently so we can do this in one function call?

In [62]:
np.sort(sqdists, axis=1)

Out[62]:
array([[  9.6193695,  11.355759 ,  11.403915 , ..., 211.29517  ,
211.7034   , 235.06609  ],
[ 20.636139 ,  22.408554 ,  25.232117 , ..., 215.19962  ,
216.1485   , 221.52444  ],
[  1.6865845,   1.7748108,   2.063202 , ..., 227.82173  ,
245.48021  , 251.6908   ]], dtype=float32)

Yippee!

In [63]:
np.argsort(sqdists, axis=1)

Out[63]:
array([[38620, 16186, 27059, ..., 10259, 25321, 41358],
[28882, 49160, 24612, ..., 43452, 10237, 13650],
[46512, 15224, 47333, ..., 25321, 25285, 41358]])
In [64]:
indices = np.argsort(sqdists, axis=1)
indices

Out[64]:
array([[38620, 16186, 27059, ..., 10259, 25321, 41358],
[28882, 49160, 24612, ..., 43452, 10237, 13650],
[46512, 15224, 47333, ..., 25321, 25285, 41358]])
In [65]:
Ttrain[indices, :]

Out[65]:
array([[[7],
[7],
[7],
...,
[0],
[8],
[0]],

[[2],
[2],
[2],
...,
[4],
[0],
[4]],

[[1],
[1],
[1],
...,
[8],
[0],
[0]]])
In [66]:
Ttrain[indices, :][:, :, 0]

Out[66]:
array([[7, 7, 7, ..., 0, 8, 0],
[2, 2, 2, ..., 4, 0, 4],
[1, 1, 1, ..., 8, 0, 0]])

Cool! Now we just have to take the first $k$ columns of these and determine the most common label across the columns, for each row. We can use scipy.stats.mode for this!

In [67]:
import scipy.stats as ss
ss.mode([1, 2, 3, 4, 2, 2, 2])

Out[67]:
ModeResult(mode=array([2]), count=array([4]))
In [68]:
ss.mode([1, 2, 3, 4, 2, 2, 2])[0]

Out[68]:
array([2])
In [69]:
ss.mode([1, 2, 3, 4, 2, 2, 2])[0]

Out[69]:
array([2])
In [70]:
ss.mode(Ttrain[indices, :][:, :, 0], axis=1)[0]

Out[70]:
array([[1],
[1],
[1]])
In [71]:
Ttest[:3]

Out[71]:
array([[7],
[2],
[1]])

Well, maybe we will do better with different values of $k$.

Finally, we can now define our KNN class.

In [72]:
import numpy as np
import scipy.stats as ss  # for ss.mode

class KNN():
    """Memory-based k-nearest-neighbor classifier (first draft, no __repr__).

    train() just stores the (standardized) training samples; use()
    classifies new samples by majority vote among the k closest
    stored samples.
    """

    def __init__(self):

        self.X = None  # data will be stored here
        self.T = None  # class labels will be stored here
        self.Xmeans = None  # per-column means, captured on first train()
        self.Xstds = None   # per-column stds, captured on first train()

    def train(self, X, T):
        """Standardize and store X along with its class labels T."""

        # Statistics are computed only once, so later train() calls reuse
        # the standardization of the first call.
        if self.Xmeans is None:
            self.Xmeans = X.mean(axis=0)
            self.Xstds = X.std(axis=0)
            # Avoid division by zero for constant columns.
            self.Xstds[self.Xstds == 0] = 1

        self.X = self._standardizeX(X)
        self.T = T

    def _standardizeX(self, X):
        # Shift and scale columns using the statistics captured in train().
        return (X - self.Xmeans) / self.Xstds

    def use(self, Xnew, k = 1):
        """Return the majority class label of the k nearest neighbors of each row of Xnew."""
        self.k = k
        # Calc squared distance from all samples in Xnew with all stored in self.X
        sqdists = np.sum( (self._standardizeX(Xnew)[:,np.newaxis,:] - self.X)**2, axis=-1 )

        # sqdists is now n_new_samples x n_train_samples
        # Sort each row of squared distances from smallest to largest and select the first k.
        indices = np.argsort(sqdists, axis=1)[:, :k]

        # Determine most common class label in each row.
        classes = ss.mode(self.T[indices,:][:,:,0], axis=1)[0]

        return classes

In [73]:
knn = KNN()
knn

Out[73]:
<__main__.KNN at 0x7fed7e246910>

Oh, can't have that!!

In [74]:
import numpy as np
import scipy.stats as ss  # for ss.mode

class KNN():
    """Memory-based k-nearest-neighbor classifier.

    Training only standardizes and stores the samples; classification
    takes a majority vote among the k stored samples closest (in squared
    Euclidean distance) to each new sample.
    """

    def __init__(self):
        # Standardized training data and its class labels.
        self.X = None
        self.T = None
        # Per-column statistics, captured on the first call to train().
        self.Xmeans = None
        self.Xstds = None

    def __repr__(self):
        if self.X is None:
            return f'KNN() has not been trained.'
        return f'KNN(), trained with {self.X.shape[0]} samples having class labels {np.unique(self.T)}.'

    def train(self, X, T):
        """Standardize and store X with labels T; return self for chaining."""
        if self.Xmeans is None:
            # Capture column statistics once; constant columns get std 1
            # so standardization never divides by zero.
            self.Xmeans = X.mean(axis=0)
            self.Xstds = X.std(axis=0)
            self.Xstds[self.Xstds == 0] = 1

        self.X = self._standardizeX(X)
        self.T = T
        return self

    def _standardizeX(self, X):
        # Apply the shift/scale captured in train().
        return (X - self.Xmeans) / self.Xstds

    def use(self, Xnew, k = 1):
        """Return the majority class label among the k nearest neighbors of each row of Xnew."""
        if self.X is None:
            raise Exception('KNN object has not been trained yet.')

        self.k = k
        # Pairwise squared distances, n_new_samples x n_train_samples.
        deltas = self._standardizeX(Xnew)[:, np.newaxis, :] - self.X
        sqdists = (deltas ** 2).sum(axis=-1)

        # Row-wise indices of the k closest training samples.
        nearest = np.argsort(sqdists, axis=1)[:, :k]

        # Majority vote over the neighbors' labels, one row per new sample.
        neighbor_labels = self.T[nearest, :][:, :, 0]
        return ss.mode(neighbor_labels, axis=1)[0]

In [75]:
knn = KNN()
knn

Out[75]:
KNN() has not been trained.
In [76]:
knn.train(Xtrain, Ttrain)

Out[76]:
KNN(), trained with 50000 samples having class labels [0 1 2 3 4 5 6 7 8 9].

Boy, that took a long time to train! :) 200 ms.

Let's test it. First, use the default value for $k$ of 1.

In [77]:
knn.use(Xtest[:3, :])

Out[77]:
array([[7],
[2],
[1]])
In [78]:
Ttest[:3]

Out[78]:
array([[7],
[2],
[1]])

Well, that worked perfectly. Let's try more test samples.

In [79]:
knn.use(Xtest[:10, :])

Out[79]:
array([[7],
[2],
[1],
[0],
[4],
[1],
[4],
[4],
[4],
[9]])
In [80]:
Ttest[:10]

Out[80]:
array([[7],
[2],
[1],
[0],
[4],
[1],
[4],
[9],
[5],
[9]])

There are some mistakes. How about using more neighbors?

In [81]:
knn.use(Xtest[:10, :], k=5)

Out[81]:
array([[7],
[2],
[1],
[0],
[4],
[1],
[4],
[9],
[4],
[9]])
In [82]:
Ttest[:10]

Out[82]:
array([[7],
[2],
[1],
[0],
[4],
[1],
[4],
[9],
[5],
[9]])
In [83]:
def percent_correct(Predicted, T):
    """Return the percentage of entries in Predicted that equal T."""
    matches = Predicted == T
    return np.mean(matches) * 100

In [84]:
percent_correct(knn.use(Xtest[:10, :], k=5), Ttest[:10])

Out[84]:
90.0

Now we can try multiple values of $k$ with a for loop, and test all test samples.

In [85]:
# Sweep k and record test-set accuracy for each value.
# BUG FIX: the original used the name pc for BOTH the accumulator list
# and the per-k score, so the list was overwritten by a float and
# pc.append(...) would raise AttributeError on the first iteration
# (masked here by the MemoryError raised earlier).  Use results for the
# accumulator, consistent with the batched version below.
results = []
for k in range(1, 21):
    pc = percent_correct(knn.use(Xtest, k=k), Ttest)
    results.append([k, pc])

---------------------------------------------------------------------------
MemoryError                               Traceback (most recent call last)
<ipython-input-85-7f8af0cc37ab> in <module>
1 pc = []
2 for k in range(1, 21):
----> 3     pc = percent_correct(knn.use(Xtest, k=k), Ttest)
4     pc.append([k, pc])

<ipython-input-74-e861359179fd> in use(self, Xnew, k)
39         self.k = k
40         # Calc squared distance from all samples in Xnew with all stored in self.X
---> 41         sqdists = np.sum( (self._standardizeX(Xnew)[:,np.newaxis,:] - self.X)**2, axis=-1 )
42         # sqdist
43

MemoryError: Unable to allocate 1.43 TiB for an array with shape (10000, 50000, 784) and data type float32

Well, here is what we often face when dealing with big data sets. K-nearest-neighbors calculates squared distances between each train and test sample. That can get huge.

We can deal with this the typical way of working with batches of data.

In [86]:
knn = KNN()
n_train = 10000
knn.train(Xtrain[:n_train, :], Ttrain[:n_train, :])

batch_size = 200
n_samples = Xtest.shape[0]
batches = [batch_size] * (n_samples // batch_size)
if sum(batches) < n_samples:
batches = batches.append(n_samples - sum(batches))

results = []
for k in range(1, 21):
print(f'k={k}', end=' ')
n_correct = 0
first = 0
for this_batch in batches:
print(f'{first}', end=',')
last = first + this_batch
X = Xtest[first:last, :]
T = Ttest[first:last, :]
n_correct += np.sum(knn.use(X, k=k) == T)
first += this_batch
pc = n_correct / n_samples * 100
results.append([k, pc])
print()

k=1 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=2 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=3 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=4 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=5 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=6 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=7 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=8 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=9 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=10 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=11 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=12 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=13 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=14 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=15 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=16 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=17 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=18 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=19 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=20 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,

In [92]:
results

Out[92]:
[[1, 91.07],
[2, 89.59],
[3, 91.17],
[4, 90.95],
[5, 91.03],
[6, 90.86999999999999],
[7, 90.73],
[8, 90.66],
[9, 90.63],
[10, 90.48],
[11, 90.33],
[12, 90.32],
[13, 90.27],
[14, 90.16],
[15, 90.21000000000001],
[16, 90.09],
[17, 89.87],
[18, 89.71000000000001],
[19, 89.63],
[20, 89.55]]
In [96]:
results = np.array(results)

plt.plot(results[:, 0], results[:, 1])
plt.xlabel('$k$')
plt.ylabel('Percent Correct Test Data');


How might you change the implementation of KNN to speed up this calculation using multiple $k$ values?

How could you calculate class probabilities with KNN?

## Comparison to Neural Network Classifier¶

In [87]:
n = 20
X = np.random.multivariate_normal([5, 7], [[0.8, -0.5], [-0.5, 0.8]], n)
X = np.vstack((X,
np.random.multivariate_normal([6, 3], [[0.6, 0.5], [0.5, 0.8]], n)))
T = np.vstack((np.ones((n, 1)), 2 * np.ones((n, 1))))

plt.scatter(X[:, 0], X[:, 1], c=T, s=80);

In [88]:
plt.figure(figsize=(10, 10))

# Two Gaussian clouds of n samples each, labeled class 1 and class 2.
n = 20
X = np.random.multivariate_normal([5, 7], [[0.8, -0.5], [-0.5, 0.8]], n)
X = np.vstack((X,
               np.random.multivariate_normal([6, 3], [[0.6, 0.5], [0.5, 0.8]], n)))
T = np.vstack((np.ones((n, 1)), 2 * np.ones((n, 1))))

# Make samples as coordinates of grid points across 2-dimensional data space
m = 100
xs = np.linspace(0, 10, m)
ys = xs
Xs, Ys = np.meshgrid(xs, ys)
samples = np.vstack((Xs.ravel(), Ys.ravel())).T

knn = KNN()
knn.train(X, T)

# Classify every grid point to visualize the 1-nearest-neighbor
# decision regions.
classes = knn.use(samples, k=1)

# Shade the two decision regions and overlay the training samples
# colored by their class label.
plt.contourf(Xs, Ys, classes.reshape(Xs.shape), 1, colors=('blue','red'), alpha=0.2)
plt.scatter(X[:, 0], X[:, 1], s=60, c=T);

In [89]:
def plot_result(X, Xs, Ys, classes):
    """Shade the two-class decision regions (blue/red) defined by classes
    over the (Xs, Ys) grid, and overlay the samples X colored by label.

    NOTE(review): the scatter color uses the *global* variable T, not a
    parameter -- this relies on T being defined in the surrounding
    notebook scope and matching X row-for-row; verify at each call site.
    """
    plt.contourf(Xs, Ys, classes.reshape(Xs.shape), 1, colors=('blue','red'), alpha=0.2)
    plt.scatter(X[:, 0], X[:, 1], s=60, c=T);

In [90]:
import A4mysolution as nn

In [91]:
n = 40
X = np.random.multivariate_normal([5, 6], [[0.9, -0.2], [-0.2, 0.9]], n)
X = np.vstack((X,
np.random.multivariate_normal([6, 3], [[2, 0.4], [0.4, 2]], n)))
T = np.vstack((np.ones((n, 1)), 2 * np.ones((n, 1))))

m = 100
xs = np.linspace(1, 9, m)
ys = xs
Xs,Ys = np.meshgrid(xs, ys)
samples = np.vstack((Xs.ravel(), Ys.ravel())).T

plt.figure(figsize=(20, 30))

knn = KNN()
knn.train(X, T)

ploti = 0
for k in [1, 2, 3, 5, 10, 20]:
ploti += 1
plt.subplot(4, 3, ploti)

classes = knn.use(samples, k)
plot_result(X, Xs, Ys, classes)
plt.title(f'KNN k={k}')

for n_hiddens in [[], [1], [2], [10], [10, 10], [5, 5, 5, 5]]:
ploti += 1
plt.subplot(4, 3, ploti)

nnet = nn.NeuralNetworkClassifier(2, n_hiddens, 2)
nnet.train(X, T, n_epochs=500, verbose=False)
classes, _ = nnet.use(samples)
plot_result(X, Xs, Ys, classes)
plt.title(f'nnet {n_hiddens}')