K-Means Clustering

So far this semester we have been working with supervised and reinforcement learning algorithms. Another family of machine learning algorithms is unsupervised learning: algorithms designed to find patterns or groupings in a data set. No targets, or desired outputs, are involved.

Old Faithful Dataset

For example, take a look at this data set of eruption durations and the waiting times in between eruptions of the Old Faithful Geyser in Yellowstone National Park.

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
In [2]:
!head faithful.csv
"","eruptions","waiting"
"1",3.6,79
"2",1.8,54
"3",3.333,74
"4",2.283,62
"5",4.533,85
"6",2.883,55
"7",4.7,88
"8",3.6,85
"9",1.95,51
In [3]:
datadf = pd.read_csv('faithful.csv', usecols=(1, 2))
datadf
Out[3]:
eruptions waiting
0 3.600 79
1 1.800 54
2 3.333 74
3 2.283 62
4 4.533 85
... ... ...
267 4.117 81
268 2.150 46
269 4.417 90
270 1.817 46
271 4.467 74

272 rows × 2 columns

In [4]:
data = datadf.values
data
Out[4]:
array([[ 3.6  , 79.   ],
       [ 1.8  , 54.   ],
       [ 3.333, 74.   ],
       [ 2.283, 62.   ],
       [ 4.533, 85.   ],
       [ 2.883, 55.   ],
       [ 4.7  , 88.   ],
       [ 3.6  , 85.   ],
       [ 1.95 , 51.   ],
       [ 4.35 , 85.   ],
       [ 1.833, 54.   ],
       [ 3.917, 84.   ],
       [ 4.2  , 78.   ],
       [ 1.75 , 47.   ],
       [ 4.7  , 83.   ],
       [ 2.167, 52.   ],
       [ 1.75 , 62.   ],
       [ 4.8  , 84.   ],
       [ 1.6  , 52.   ],
       [ 4.25 , 79.   ],
       [ 1.8  , 51.   ],
       [ 1.75 , 47.   ],
       [ 3.45 , 78.   ],
       [ 3.067, 69.   ],
       [ 4.533, 74.   ],
       [ 3.6  , 83.   ],
       [ 1.967, 55.   ],
       [ 4.083, 76.   ],
       [ 3.85 , 78.   ],
       [ 4.433, 79.   ],
       [ 4.3  , 73.   ],
       [ 4.467, 77.   ],
       [ 3.367, 66.   ],
       [ 4.033, 80.   ],
       [ 3.833, 74.   ],
       [ 2.017, 52.   ],
       [ 1.867, 48.   ],
       [ 4.833, 80.   ],
       [ 1.833, 59.   ],
       [ 4.783, 90.   ],
       [ 4.35 , 80.   ],
       [ 1.883, 58.   ],
       [ 4.567, 84.   ],
       [ 1.75 , 58.   ],
       [ 4.533, 73.   ],
       [ 3.317, 83.   ],
       [ 3.833, 64.   ],
       [ 2.1  , 53.   ],
       [ 4.633, 82.   ],
       [ 2.   , 59.   ],
       [ 4.8  , 75.   ],
       [ 4.716, 90.   ],
       [ 1.833, 54.   ],
       [ 4.833, 80.   ],
       [ 1.733, 54.   ],
       [ 4.883, 83.   ],
       [ 3.717, 71.   ],
       [ 1.667, 64.   ],
       [ 4.567, 77.   ],
       [ 4.317, 81.   ],
       [ 2.233, 59.   ],
       [ 4.5  , 84.   ],
       [ 1.75 , 48.   ],
       [ 4.8  , 82.   ],
       [ 1.817, 60.   ],
       [ 4.4  , 92.   ],
       [ 4.167, 78.   ],
       [ 4.7  , 78.   ],
       [ 2.067, 65.   ],
       [ 4.7  , 73.   ],
       [ 4.033, 82.   ],
       [ 1.967, 56.   ],
       [ 4.5  , 79.   ],
       [ 4.   , 71.   ],
       [ 1.983, 62.   ],
       [ 5.067, 76.   ],
       [ 2.017, 60.   ],
       [ 4.567, 78.   ],
       [ 3.883, 76.   ],
       [ 3.6  , 83.   ],
       [ 4.133, 75.   ],
       [ 4.333, 82.   ],
       [ 4.1  , 70.   ],
       [ 2.633, 65.   ],
       [ 4.067, 73.   ],
       [ 4.933, 88.   ],
       [ 3.95 , 76.   ],
       [ 4.517, 80.   ],
       [ 2.167, 48.   ],
       [ 4.   , 86.   ],
       [ 2.2  , 60.   ],
       [ 4.333, 90.   ],
       [ 1.867, 50.   ],
       [ 4.817, 78.   ],
       [ 1.833, 63.   ],
       [ 4.3  , 72.   ],
       [ 4.667, 84.   ],
       [ 3.75 , 75.   ],
       [ 1.867, 51.   ],
       [ 4.9  , 82.   ],
       [ 2.483, 62.   ],
       [ 4.367, 88.   ],
       [ 2.1  , 49.   ],
       [ 4.5  , 83.   ],
       [ 4.05 , 81.   ],
       [ 1.867, 47.   ],
       [ 4.7  , 84.   ],
       [ 1.783, 52.   ],
       [ 4.85 , 86.   ],
       [ 3.683, 81.   ],
       [ 4.733, 75.   ],
       [ 2.3  , 59.   ],
       [ 4.9  , 89.   ],
       [ 4.417, 79.   ],
       [ 1.7  , 59.   ],
       [ 4.633, 81.   ],
       [ 2.317, 50.   ],
       [ 4.6  , 85.   ],
       [ 1.817, 59.   ],
       [ 4.417, 87.   ],
       [ 2.617, 53.   ],
       [ 4.067, 69.   ],
       [ 4.25 , 77.   ],
       [ 1.967, 56.   ],
       [ 4.6  , 88.   ],
       [ 3.767, 81.   ],
       [ 1.917, 45.   ],
       [ 4.5  , 82.   ],
       [ 2.267, 55.   ],
       [ 4.65 , 90.   ],
       [ 1.867, 45.   ],
       [ 4.167, 83.   ],
       [ 2.8  , 56.   ],
       [ 4.333, 89.   ],
       [ 1.833, 46.   ],
       [ 4.383, 82.   ],
       [ 1.883, 51.   ],
       [ 4.933, 86.   ],
       [ 2.033, 53.   ],
       [ 3.733, 79.   ],
       [ 4.233, 81.   ],
       [ 2.233, 60.   ],
       [ 4.533, 82.   ],
       [ 4.817, 77.   ],
       [ 4.333, 76.   ],
       [ 1.983, 59.   ],
       [ 4.633, 80.   ],
       [ 2.017, 49.   ],
       [ 5.1  , 96.   ],
       [ 1.8  , 53.   ],
       [ 5.033, 77.   ],
       [ 4.   , 77.   ],
       [ 2.4  , 65.   ],
       [ 4.6  , 81.   ],
       [ 3.567, 71.   ],
       [ 4.   , 70.   ],
       [ 4.5  , 81.   ],
       [ 4.083, 93.   ],
       [ 1.8  , 53.   ],
       [ 3.967, 89.   ],
       [ 2.2  , 45.   ],
       [ 4.15 , 86.   ],
       [ 2.   , 58.   ],
       [ 3.833, 78.   ],
       [ 3.5  , 66.   ],
       [ 4.583, 76.   ],
       [ 2.367, 63.   ],
       [ 5.   , 88.   ],
       [ 1.933, 52.   ],
       [ 4.617, 93.   ],
       [ 1.917, 49.   ],
       [ 2.083, 57.   ],
       [ 4.583, 77.   ],
       [ 3.333, 68.   ],
       [ 4.167, 81.   ],
       [ 4.333, 81.   ],
       [ 4.5  , 73.   ],
       [ 2.417, 50.   ],
       [ 4.   , 85.   ],
       [ 4.167, 74.   ],
       [ 1.883, 55.   ],
       [ 4.583, 77.   ],
       [ 4.25 , 83.   ],
       [ 3.767, 83.   ],
       [ 2.033, 51.   ],
       [ 4.433, 78.   ],
       [ 4.083, 84.   ],
       [ 1.833, 46.   ],
       [ 4.417, 83.   ],
       [ 2.183, 55.   ],
       [ 4.8  , 81.   ],
       [ 1.833, 57.   ],
       [ 4.8  , 76.   ],
       [ 4.1  , 84.   ],
       [ 3.966, 77.   ],
       [ 4.233, 81.   ],
       [ 3.5  , 87.   ],
       [ 4.366, 77.   ],
       [ 2.25 , 51.   ],
       [ 4.667, 78.   ],
       [ 2.1  , 60.   ],
       [ 4.35 , 82.   ],
       [ 4.133, 91.   ],
       [ 1.867, 53.   ],
       [ 4.6  , 78.   ],
       [ 1.783, 46.   ],
       [ 4.367, 77.   ],
       [ 3.85 , 84.   ],
       [ 1.933, 49.   ],
       [ 4.5  , 83.   ],
       [ 2.383, 71.   ],
       [ 4.7  , 80.   ],
       [ 1.867, 49.   ],
       [ 3.833, 75.   ],
       [ 3.417, 64.   ],
       [ 4.233, 76.   ],
       [ 2.4  , 53.   ],
       [ 4.8  , 94.   ],
       [ 2.   , 55.   ],
       [ 4.15 , 76.   ],
       [ 1.867, 50.   ],
       [ 4.267, 82.   ],
       [ 1.75 , 54.   ],
       [ 4.483, 75.   ],
       [ 4.   , 78.   ],
       [ 4.117, 79.   ],
       [ 4.083, 78.   ],
       [ 4.267, 78.   ],
       [ 3.917, 70.   ],
       [ 4.55 , 79.   ],
       [ 4.083, 70.   ],
       [ 2.417, 54.   ],
       [ 4.183, 86.   ],
       [ 2.217, 50.   ],
       [ 4.45 , 90.   ],
       [ 1.883, 54.   ],
       [ 1.85 , 54.   ],
       [ 4.283, 77.   ],
       [ 3.95 , 79.   ],
       [ 2.333, 64.   ],
       [ 4.15 , 75.   ],
       [ 2.35 , 47.   ],
       [ 4.933, 86.   ],
       [ 2.9  , 63.   ],
       [ 4.583, 85.   ],
       [ 3.833, 82.   ],
       [ 2.083, 57.   ],
       [ 4.367, 82.   ],
       [ 2.133, 67.   ],
       [ 4.35 , 74.   ],
       [ 2.2  , 54.   ],
       [ 4.45 , 83.   ],
       [ 3.567, 73.   ],
       [ 4.5  , 73.   ],
       [ 4.15 , 88.   ],
       [ 3.817, 80.   ],
       [ 3.917, 71.   ],
       [ 4.45 , 83.   ],
       [ 2.   , 56.   ],
       [ 4.283, 79.   ],
       [ 4.767, 78.   ],
       [ 4.533, 84.   ],
       [ 1.85 , 58.   ],
       [ 4.25 , 83.   ],
       [ 1.983, 43.   ],
       [ 2.25 , 60.   ],
       [ 4.75 , 75.   ],
       [ 4.117, 81.   ],
       [ 2.15 , 46.   ],
       [ 4.417, 90.   ],
       [ 1.817, 46.   ],
       [ 4.467, 74.   ]])
In [5]:
plt.plot(data[:, 0], data[:, 1], '.')
plt.xlabel('duration')
plt.ylabel('interval')
Out[5]:
Text(0, 0.5, 'interval')

We can clearly see two clusters here. For higher dimensional data, we cannot directly visualize the data to see the clusters. We need a mathematical way to detect clusters. This gives rise to the class of unsupervised learning methods called clustering algorithms.

A simple example of a clustering algorithm is the k-means algorithm, which identifies $k$ cluster centers. It is an iterative algorithm that starts with an initial choice of $k$ centers. It then determines which center each data sample is closest to, adjusts each center to be the mean of the samples assigned to it, and repeats.
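
In symbols, each iteration alternates two steps. Using $r_{nk} = 1$ if sample $\mathbf{x}_n$ is currently closest to center $\mathbf{\mu}_k$ and $r_{nk} = 0$ otherwise (the same notation used for $J$ below), the assignment and update steps are

$$ r_{nk} = \begin{cases} 1 & \text{if } k = \arg\min_j ||\mathbf{x}_n - \mathbf{\mu}_j||^2 \\ 0 & \text{otherwise,} \end{cases} \qquad\qquad \mathbf{\mu}_k = \frac{\sum_{n=1}^N r_{nk}\, \mathbf{x}_n}{\sum_{n=1}^N r_{nk}}. $$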

Let's develop this algorithm one step at a time.

Each sample in the Old Faithful data has 2 attributes, so each sample is in 2-dimensional space. We know by looking at the above plot that our data nicely falls into two clusters, so we will start with $k=2$. We will initialize the two cluster centers by randomly choosing two of the data samples.

In [6]:
n_samples = data.shape[0]
np.random.choice(range(n_samples), 2, replace=False)
Out[6]:
array([226,  66])
In [7]:
centers = data[np.random.choice(range(n_samples), 2, replace=False), :]
centers
Out[7]:
array([[ 2.2, 60. ],
       [ 4.7, 83. ]])

Now we must find all samples that are closest to the first center, and those that are closest to the second center.

Check out the wonders of numpy broadcasting.

In [8]:
a = np.array([1, 2, 3])
b = np.array([10, 20, 30])
a, b
Out[8]:
(array([1, 2, 3]), array([10, 20, 30]))
In [9]:
a - b
Out[9]:
array([ -9, -18, -27])

But what if we want to subtract every element of b from every element of a?

In [10]:
np.resize(a, (3, 3))
Out[10]:
array([[1, 2, 3],
       [1, 2, 3],
       [1, 2, 3]])
In [11]:
np.resize(b, (3, 3))
Out[11]:
array([[10, 20, 30],
       [10, 20, 30],
       [10, 20, 30]])
In [12]:
np.resize(a, (3, 3)).T
Out[12]:
array([[1, 1, 1],
       [2, 2, 2],
       [3, 3, 3]])
In [13]:
np.resize(a, (3, 3)).T - np.resize(b, (3, 3))
Out[13]:
array([[ -9, -19, -29],
       [ -8, -18, -28],
       [ -7, -17, -27]])

However, we can ask numpy to do this duplication for us if we reshape a to be a column vector and leave b as a row vector.

$$ \begin{pmatrix} 1\\ 2\\ 3 \end{pmatrix} - \begin{pmatrix} 10 & 20 & 30 \end{pmatrix} \;\; = \;\; \begin{pmatrix} 1 & 1 & 1\\ 2 & 2 & 2\\ 3 & 3 & 3 \end{pmatrix} - \begin{pmatrix} 10 & 20 & 30\\ 10 & 20 & 30\\ 10 & 20 & 30 \end{pmatrix} $$
In [14]:
a[:, np.newaxis]
Out[14]:
array([[1],
       [2],
       [3]])
In [15]:
a[:, np.newaxis] - b
Out[15]:
array([[ -9, -19, -29],
       [ -8, -18, -28],
       [ -7, -17, -27]])

Now imagine that a is a cluster center and b contains data samples, one per row. The first step of calculating the distance from a to all samples in b is to subtract them component-wise.

In [16]:
a = np.array([1, 2, 3])
b = np.array([[10, 20, 30], [40, 50, 60]])
print(a)
print(b)
[1 2 3]
[[10 20 30]
 [40 50 60]]
In [17]:
b - a
Out[17]:
array([[ 9, 18, 27],
       [39, 48, 57]])

The single row vector a is duplicated for as many rows as there are in b! We can use this to calculate the squared distance between a center and every sample.

In [18]:
centers[0,:]
Out[18]:
array([ 2.2, 60. ])
In [19]:
sqdists_to_center_0 = np.sum((centers[0, :] - data)**2, axis=1)
sqdists_to_center_0
Out[19]:
array([3.62960000e+02, 3.61600000e+01, 1.97283689e+02, 4.00688900e+00,
       6.30442889e+02, 2.54664890e+01, 7.90250000e+02, 6.26960000e+02,
       8.10625000e+01, 6.29622500e+02, 3.61346890e+01, 5.78948089e+02,
       3.28000000e+02, 1.69202500e+02, 5.35250000e+02, 6.40010890e+01,
       4.20250000e+00, 5.82760000e+02, 6.43600000e+01, 3.65202500e+02,
       8.11600000e+01, 1.69202500e+02, 3.25562500e+02, 8.17516890e+01,
       2.01442889e+02, 5.30960000e+02, 2.50542890e+01, 2.59545689e+02,
       3.26722500e+02, 3.65986289e+02, 1.73410000e+02, 2.94139289e+02,
       3.73618890e+01, 4.03359889e+02, 1.98666689e+02, 6.40334890e+01,
       1.44110889e+02, 4.06932689e+02, 1.13468900e+00, 9.06671889e+02,
       4.04622500e+02, 4.10048900e+00, 5.81602689e+02, 4.20250000e+00,
       1.74442889e+02, 5.30247689e+02, 1.86666890e+01, 4.90100000e+01,
       4.89919489e+02, 1.04000000e+00, 2.31760000e+02, 9.06330256e+02,
       3.61346890e+01, 4.06932689e+02, 3.62180890e+01, 5.36198489e+02,
       1.23301289e+02, 1.62840890e+01, 2.94602689e+02, 4.45481689e+02,
       1.00108900e+00, 5.81290000e+02, 1.44202500e+02, 4.90760000e+02,
       1.46689000e-01, 1.02884000e+03, 3.27869089e+02, 3.30250000e+02,
       2.50176890e+01, 1.75250000e+02, 4.87359889e+02, 1.60542890e+01,
       3.66290000e+02, 1.24240000e+02, 4.04708900e+00, 2.64219689e+02,
       3.34890000e-02, 3.29602689e+02, 2.58832489e+02, 5.30960000e+02,
       2.28736489e+02, 4.88549689e+02, 1.03610000e+02, 2.51874890e+01,
       1.72485689e+02, 7.91469289e+02, 2.59062500e+02, 4.05368489e+02,
       1.44001089e+02, 6.79240000e+02, 0.00000000e+00, 9.04549689e+02,
       1.00110889e+02, 3.30848689e+02, 9.13468900e+00, 1.48410000e+02,
       5.82086089e+02, 2.27402500e+02, 8.11108890e+01, 4.91290000e+02,
       4.08008900e+00, 7.88695889e+02, 1.21010000e+02, 5.34290000e+02,
       4.44422500e+02, 1.69110889e+02, 5.82250000e+02, 6.41738890e+01,
       6.83022500e+02, 4.43199289e+02, 2.31416089e+02, 1.01000000e+00,
       8.48290000e+02, 3.65915089e+02, 1.25000000e+00, 4.46919489e+02,
       1.00013689e+02, 6.30760000e+02, 1.14668900e+00, 7.33915089e+02,
       4.91738890e+01, 8.44856890e+01, 2.93202500e+02, 1.60542890e+01,
       7.89760000e+02, 4.43455489e+02, 2.25080089e+02, 4.89290000e+02,
       2.50044890e+01, 9.06002500e+02, 2.25110889e+02, 5.32869089e+02,
       1.63600000e+01, 8.45549689e+02, 1.96134689e+02, 4.88765489e+02,
       8.11004890e+01, 6.83469289e+02, 4.90278890e+01, 3.63350089e+02,
       4.45133089e+02, 1.08900000e-03, 4.89442889e+02, 2.95848689e+02,
       2.60549689e+02, 1.04708900e+00, 4.05919489e+02, 1.21033489e+02,
       1.30441000e+03, 4.91600000e+01, 2.97025889e+02, 2.92240000e+02,
       2.50400000e+01, 4.46760000e+02, 1.22868689e+02, 1.03240000e+02,
       4.46290000e+02, 1.09254569e+03, 4.91600000e+01, 8.44122289e+02,
       2.25000000e+02, 6.79802500e+02, 4.04000000e+00, 3.26666689e+02,
       3.76900000e+01, 2.61678689e+02, 9.02788900e+00, 7.91840000e+02,
       6.40712890e+01, 1.09484189e+03, 1.21080089e+02, 9.01368900e+00,
       2.94678689e+02, 6.52836890e+01, 4.44869089e+02, 4.45549689e+02,
       1.74290000e+02, 1.00047089e+02, 6.28240000e+02, 1.99869089e+02,
       2.51004890e+01, 2.94678689e+02, 5.33202500e+02, 5.31455489e+02,
       8.10278890e+01, 3.28986289e+02, 5.79545689e+02, 1.96134689e+02,
       5.33915089e+02, 2.50002890e+01, 4.47760000e+02, 9.13468900e+00,
       2.62760000e+02, 5.79610000e+02, 2.92118756e+02, 4.45133089e+02,
       7.30690000e+02, 2.93691556e+02, 8.10025000e+01, 3.30086089e+02,
       1.00000000e-02, 4.88622500e+02, 9.64736489e+02, 4.91108890e+01,
       3.29760000e+02, 1.96173889e+02, 2.93695889e+02, 5.78722500e+02,
       1.21071289e+02, 5.34290000e+02, 1.21033489e+02, 4.06250000e+02,
       1.21110889e+02, 2.27666689e+02, 1.74810890e+01, 2.60133089e+02,
       4.90400000e+01, 1.16276000e+03, 2.50400000e+01, 2.59802500e+02,
       1.00110889e+02, 4.88272489e+02, 3.62025000e+01, 2.30212089e+02,
       3.27240000e+02, 3.64674889e+02, 3.27545689e+02, 3.28272489e+02,
       1.02948089e+02, 3.66522500e+02, 1.03545689e+02, 3.60470890e+01,
       6.79932289e+02, 1.00000289e+02, 9.05062500e+02, 3.61004890e+01,
       3.61225000e+01, 2.93338889e+02, 3.64062500e+02, 1.60176890e+01,
       2.28802500e+02, 1.69022500e+02, 6.83469289e+02, 9.49000000e+00,
       6.30678689e+02, 4.86666689e+02, 9.01368900e+00, 4.88695889e+02,
       4.90044890e+01, 2.00622500e+02, 3.60000000e+01, 5.34062500e+02,
       1.70868689e+02, 1.74290000e+02, 7.87802500e+02, 4.02614689e+02,
       1.23948089e+02, 5.34062500e+02, 1.60400000e+01, 3.65338889e+02,
       3.30589489e+02, 5.81442889e+02, 4.12250000e+00, 5.33202500e+02,
       2.89047089e+02, 2.50000000e-03, 2.31502500e+02, 4.44674889e+02,
       1.96002500e+02, 9.04915089e+02, 1.96146689e+02, 2.01139289e+02])
In [20]:
sqdists_to_center_1 = np.sum((centers[1, :] - data)**2, axis=1)
sqdists_to_center_1
Out[20]:
array([1.72100000e+01, 8.49410000e+02, 8.28686890e+01, 4.46841889e+02,
       4.02788900e+00, 7.87301489e+02, 2.50000000e+01, 5.21000000e+00,
       1.03156250e+03, 4.12250000e+00, 8.49219689e+02, 1.61308900e+00,
       2.52500000e+01, 1.30470250e+03, 0.00000000e+00, 9.67416089e+02,
       4.49702500e+02, 1.01000000e+00, 9.70610000e+02, 1.62025000e+01,
       1.03241000e+03, 1.30470250e+03, 2.65625000e+01, 1.98666689e+02,
       8.10278890e+01, 1.21000000e+00, 7.91469289e+02, 4.93806890e+01,
       2.57225000e+01, 1.60712890e+01, 1.00160000e+02, 3.60542890e+01,
       2.90776889e+02, 9.44488900e+00, 8.17516890e+01, 9.68198489e+02,
       1.23302589e+03, 9.01768900e+00, 5.84219689e+02, 4.90068890e+01,
       9.12250000e+00, 6.32935489e+02, 1.01768900e+00, 6.33702500e+02,
       1.00027889e+02, 1.91268900e+00, 3.61751689e+02, 9.06760000e+02,
       1.00448900e+00, 5.83290000e+02, 6.40100000e+01, 4.90002560e+01,
       8.49219689e+02, 9.01768900e+00, 8.49803089e+02, 3.34890000e-02,
       1.44966289e+02, 3.70199089e+02, 3.60176890e+01, 4.14668900e+00,
       5.82086089e+02, 1.04000000e+00, 1.23370250e+03, 1.01000000e+00,
       5.37311689e+02, 8.10900000e+01, 2.52840890e+01, 2.50000000e+01,
       3.30932689e+02, 1.00000000e+02, 1.44488900e+00, 7.36469289e+02,
       1.60400000e+01, 1.44490000e+02, 4.48382089e+02, 4.91346890e+01,
       5.36198489e+02, 2.50176890e+01, 4.96674890e+01, 1.21000000e+00,
       6.43214890e+01, 1.13468900e+00, 1.69360000e+02, 3.28272489e+02,
       1.00400689e+02, 2.50542890e+01, 4.95625000e+01, 9.03348900e+00,
       1.23141609e+03, 9.49000000e+00, 5.35250000e+02, 4.91346890e+01,
       1.09702589e+03, 2.50136890e+01, 4.08219689e+02, 1.21160000e+02,
       1.00108900e+00, 6.49025000e+01, 1.03202589e+03, 1.04000000e+00,
       4.45915089e+02, 2.51108890e+01, 1.16276000e+03, 4.00000000e-02,
       4.42250000e+00, 1.30402589e+03, 1.00000000e+00, 9.69508889e+02,
       9.02250000e+00, 5.03428900e+00, 6.40010890e+01, 5.81760000e+02,
       3.60400000e+01, 1.60800890e+01, 5.85000000e+02, 4.00448900e+00,
       1.09467869e+03, 4.01000000e+00, 5.84311689e+02, 1.60800890e+01,
       9.04338889e+02, 1.96400689e+02, 3.62025000e+01, 7.36469289e+02,
       2.50100000e+01, 4.87048900e+00, 1.45174509e+03, 1.04000000e+00,
       7.89919489e+02, 4.90025000e+01, 1.45202589e+03, 2.84089000e-01,
       7.32610000e+02, 3.61346890e+01, 1.37721969e+03, 1.10048900e+00,
       1.03193549e+03, 9.05428900e+00, 9.07112889e+02, 1.69350890e+01,
       4.21808900e+00, 5.35086089e+02, 1.02788900e+00, 3.60136890e+01,
       4.91346890e+01, 5.83382089e+02, 9.00448900e+00, 1.16319849e+03,
       1.69160000e+02, 9.08410000e+02, 3.61108890e+01, 3.64900000e+01,
       3.29290000e+02, 4.01000000e+00, 1.45283689e+02, 1.69490000e+02,
       4.04000000e+00, 1.00380689e+02, 9.08410000e+02, 3.65372890e+01,
       1.45025000e+03, 9.30250000e+00, 6.32290000e+02, 2.57516890e+01,
       2.90440000e+02, 4.90136890e+01, 4.05442889e+02, 2.50900000e+01,
       9.68656289e+02, 1.00006889e+02, 1.16374509e+03, 6.82848689e+02,
       3.60136890e+01, 2.26868689e+02, 4.28408900e+00, 4.13468900e+00,
       1.00040000e+02, 1.09421209e+03, 4.49000000e+00, 8.12840890e+01,
       7.91935489e+02, 3.60136890e+01, 2.02500000e-01, 8.70489000e-01,
       1.03111289e+03, 2.50712890e+01, 1.38068900e+00, 1.37721969e+03,
       8.00890000e-02, 7.90335289e+02, 4.01000000e+00, 6.84219689e+02,
       4.90100000e+01, 1.36000000e+00, 3.65387560e+01, 4.21808900e+00,
       1.74400000e+01, 3.61115560e+01, 1.03000250e+03, 2.50010890e+01,
       5.35760000e+02, 1.12250000e+00, 6.43214890e+01, 9.08025889e+02,
       2.50100000e+01, 1.37750889e+03, 3.61108890e+01, 1.72250000e+00,
       1.16365629e+03, 4.00000000e-02, 1.49368489e+02, 9.00000000e+00,
       1.16402589e+03, 6.47516890e+01, 3.62646089e+02, 4.92180890e+01,
       9.05290000e+02, 1.21010000e+02, 7.91290000e+02, 4.93025000e+01,
       1.09702589e+03, 1.18748900e+00, 8.49702500e+02, 6.40470890e+01,
       2.54900000e+01, 1.63398890e+01, 2.53806890e+01, 2.51874890e+01,
       1.69613089e+02, 1.60225000e+01, 1.69380689e+02, 8.46212089e+02,
       9.26728900e+00, 1.09516529e+03, 4.90625000e+01, 8.48935489e+02,
       8.49122500e+02, 3.61738890e+01, 1.65625000e+01, 3.66602689e+02,
       6.43025000e+01, 1.30152250e+03, 9.05428900e+00, 4.03240000e+02,
       4.01368900e+00, 1.75168900e+00, 6.82848689e+02, 1.11088900e+00,
       2.62589489e+02, 8.11225000e+01, 8.47250000e+02, 6.25000000e-02,
       1.01283689e+02, 1.00040000e+02, 2.53025000e+01, 9.77968900e+00,
       1.44613089e+02, 6.25000000e-02, 7.36290000e+02, 1.61738890e+01,
       2.50044890e+01, 1.02788900e+00, 6.33122500e+02, 2.02500000e-01,
       1.60738209e+03, 5.35002500e+02, 6.40025000e+01, 4.33988900e+00,
       1.37550250e+03, 4.90800890e+01, 1.37731169e+03, 8.10542890e+01])

And, which samples are closest to the first center?

In [21]:
sqdists_to_center_0 < sqdists_to_center_1
Out[21]:
array([False,  True, False,  True, False,  True, False, False,  True,
       False,  True, False, False,  True, False,  True,  True, False,
        True, False,  True,  True, False,  True, False, False,  True,
       False, False, False, False, False,  True, False, False,  True,
        True, False,  True, False, False,  True, False,  True, False,
       False,  True,  True, False,  True, False, False,  True, False,
        True, False,  True,  True, False, False,  True, False,  True,
       False,  True, False, False, False,  True, False, False,  True,
       False,  True,  True, False,  True, False, False, False, False,
       False,  True,  True, False, False, False, False,  True, False,
        True, False,  True, False,  True, False, False, False,  True,
       False,  True, False,  True, False, False,  True, False,  True,
       False, False, False,  True, False, False,  True, False,  True,
       False,  True, False,  True,  True, False,  True, False, False,
        True, False,  True, False,  True, False,  True, False,  True,
       False,  True, False,  True, False, False,  True, False, False,
       False,  True, False,  True, False,  True, False, False,  True,
       False,  True,  True, False, False,  True, False,  True, False,
        True, False,  True, False,  True, False,  True, False,  True,
        True, False,  True, False, False, False,  True, False, False,
        True, False, False, False,  True, False, False,  True, False,
        True, False,  True, False, False, False, False, False, False,
        True, False,  True, False, False,  True, False,  True, False,
       False,  True, False,  True, False,  True, False,  True, False,
        True, False,  True, False,  True, False,  True, False, False,
       False, False, False,  True, False,  True,  True, False,  True,
       False,  True,  True, False, False,  True, False,  True, False,
        True, False, False,  True, False,  True, False,  True, False,
       False, False, False, False,  True, False,  True, False, False,
       False,  True, False,  True,  True, False, False,  True, False,
        True, False])

This approach is easy for $k=2$, but what if $k$ is larger? Can we calculate all of the needed distances in one numpy expression? I bet we can!

In [22]:
centers[:,np.newaxis,:].shape, data.shape
Out[22]:
((2, 1, 2), (272, 2))
In [23]:
(centers[:,np.newaxis,:] - data).shape
Out[23]:
(2, 272, 2)
In [24]:
np.sum((centers[:,np.newaxis,:] - data)**2, axis=2).shape
Out[24]:
(2, 272)

These are the square distances between each of our two centers and each of the 272 samples. If we take the argmin across the two rows, we will have the index of the closest center for each of the 272 samples.

In [25]:
clusters = np.argmin(np.sum((centers[:,np.newaxis,:] - data)**2, axis=2), axis=0)
clusters
Out[25]:
array([1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0,
       1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0,
       1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1,
       1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1,
       0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1,
       1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1,
       0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1,
       0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1,
       1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1,
       0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
       0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0,
       1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1,
       0, 0, 1, 1, 0, 1, 0, 1])

Now, to calculate the new values of our two centers, we just calculate the mean of the appropriate samples.

In [26]:
data[clusters == 0, :].mean(axis=0)
Out[26]:
array([ 2.26414286, 56.39285714])
In [27]:
data[clusters == 1, :].mean(axis=0)
Out[27]:
array([ 4.34433125, 81.05      ])

We can do both in a for loop.

In [28]:
k = 2
for i in range(k):
    centers[i, :] = data[clusters == i, :].mean(axis=0)
In [29]:
centers
Out[29]:
array([[ 2.26414286, 56.39285714],
       [ 4.34433125, 81.05      ]])

Now, we can wrap these steps in our first version of a kmeans function.

In [30]:
def kmeans(data, k = 2, n_iterations = 5):
    
    # Initial centers
    centers = data[np.random.choice(range(data.shape[0]), k, replace=False), :]
    
    # Repeat n times
    for iteration in range(n_iterations):
        
        # Which center is each sample closest to?
        closest = np.argmin(np.sum((centers[:, np.newaxis, :] - data)**2, axis=2), axis=0)
        
        # Update cluster centers
        for i in range(k):
            centers[i, :] = data[closest == i, :].mean(axis=0)
            
    return centers
In [31]:
kmeans(data, 2)
Out[31]:
array([[ 2.09433   , 54.75      ],
       [ 4.29793023, 80.28488372]])
In [32]:
kmeans(data, 2)
Out[32]:
array([[ 4.29793023, 80.28488372],
       [ 2.09433   , 54.75      ]])

We need a measure of the quality of our clustering. For this, we define $J$, which is a performance measure being minimized by k-means. It is defined as $$ J = \sum_{n=1}^N \sum_{k=1}^K r_{nk} ||\mathbf{x}_n - \mathbf{\mu}_k||^2 $$ where $N$ is the number of samples, $K$ is the number of cluster centers, $\mathbf{x}_n$ is the $n^{th}$ sample and $\mathbf{\mu}_k$ is the $k^{th}$ center, each being an element of $\mathbf{R}^p$ where $p$ is the dimensionality of the data. $r_{nk}$ is 1 if $\mathbf{x}_n$ is closest to center $\mathbf{\mu}_k$, and 0 otherwise.

The sums can be computed using python for loops, but, as you know, for loops are much slower than matrix operations in python, so let's do the matrix magic. We already know how to calculate the difference between all samples and all centers.

In [33]:
sqdists = np.sum((centers[:,np.newaxis,:] - data)**2, axis=2)
sqdists.shape
Out[33]:
(2, 272)

The calculation of $J$ requires us to multiply the squared distance between each sample and each center by $r_{nk}$. Since we already have all of the squared distances, let's just sum up the minimum distance for each sample.
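
This works because, for each sample, $r_{nk}$ is nonzero only for its closest center, so the inner sum over $k$ collapses to a minimum:

$$ \sum_{k=1}^K r_{nk} ||\mathbf{x}_n - \mathbf{\mu}_k||^2 = \min_k ||\mathbf{x}_n - \mathbf{\mu}_k||^2. $$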

In [34]:
np.min(sqdists, axis=0)
Out[34]:
array([4.75652901e+00, 5.94119390e+00, 5.07252909e+01, 3.14404066e+01,
       1.56380959e+01, 2.32303518e+00, 4.84290003e+01, 1.61565290e+01,
       2.91815939e+01, 1.56025321e+01, 5.91164947e+00, 8.88511200e+00,
       9.32333151e+00, 8.84901082e+01, 3.92900026e+00, 1.93066306e+01,
       3.17043939e+01, 8.91013401e+00, 1.97382796e+01, 4.21139838e+00,
       2.92983368e+01, 8.84901082e+01, 1.01023284e+01, 1.46834075e+02,
       4.97380959e+01, 4.35652901e+00, 2.02834490e+00, 2.55707940e+01,
       9.54686338e+00, 4.21036215e+00, 6.48044653e+01, 1.64175476e+01,
       9.35134878e+01, 1.19942715e+00, 4.99639596e+01, 1.93582735e+01,
       7.05977735e+01, 1.34129715e+00, 6.98307804e+00, 8.02949303e+01,
       1.10253213e+00, 2.72817804e+00, 8.75208137e+00, 2.84725104e+00,
       6.48380959e+01, 4.85790950e+00, 6.03299352e+01, 1.15384225e+01,
       9.85829647e-01, 6.86696533e+00, 3.68101340e+01, 8.02406377e+01,
       5.91164947e+00, 1.34129715e+00, 6.00787804e+00, 4.09266402e+00,
       1.01396044e+02, 5.82252020e+01, 1.64520814e+01, 3.24699723e-03,
       6.79816376e+00, 8.72673276e+00, 7.07043939e+01, 1.11013401e+00,
       1.32114163e+01, 1.19905599e+02, 9.33394637e+00, 9.42900026e+00,
       7.41217735e+01, 6.49290003e+01, 9.99427147e-01, 2.42630612e-01,
       4.22673276e+00, 1.01121064e+02, 3.15190923e+01, 2.60247501e+01,
       1.30725592e+01, 9.35208137e+00, 2.57153265e+01, 4.35652901e+00,
       3.66471609e+01, 9.02628397e-01, 1.22162198e+02, 7.42189638e+01,
       6.48794126e+01, 4.86490309e+01, 2.56579971e+01, 1.13231450e+00,
       7.04494878e+01, 2.46210640e+01, 1.30155939e+01, 8.01026284e+01,
       4.10263449e+01, 9.52591575e+00, 4.38402209e+01, 8.19044653e+01,
       8.80661512e+00, 3.69557296e+01, 2.92406306e+01, 1.21126776e+00,
       3.14879495e+01, 4.83030139e+01, 5.46812796e+01, 3.82673276e+00,
       8.91308847e-02, 8.83834878e+01, 8.82900026e+00, 1.95286923e+01,
       2.47582009e+01, 4.39859022e-01, 3.67535634e+01, 6.79847961e+00,
       6.35112678e+01, 4.20778075e+00, 7.11545104e+00, 8.58296472e-02,
       4.08714163e+01, 1.56678665e+01, 6.99713061e+00, 3.54077807e+01,
       1.16359878e+01, 1.45279413e+02, 1.64113984e+01, 2.42630612e-01,
       4.83678665e+01, 3.35811372e-01, 1.29917702e+02, 9.26732760e-01,
       1.94005918e+00, 8.01959334e+01, 1.29954916e+02, 3.83394637e+00,
       4.41479612e-01, 6.32026284e+01, 1.08197364e+02, 9.03995272e-01,
       2.92281780e+01, 2.48490309e+01, 1.15649066e+01, 4.57622590e+00,
       1.48946472e-02, 1.30124495e+01, 9.38095897e-01, 1.66259157e+01,
       2.55026284e+01, 6.87623518e+00, 1.18582965e+00, 5.47154163e+01,
       2.24073535e+02, 1.17269082e+01, 1.68767646e+01, 1.65210640e+01,
       7.41013653e+01, 6.78665097e-02, 1.01606744e+02, 1.22221064e+02,
       2.67327597e-02, 1.42870794e+02, 1.17269082e+01, 6.33448789e+01,
       1.29801308e+02, 2.45402646e+01, 2.65267961e+00, 9.56395965e+00,
       9.38245368e+01, 2.55594628e+01, 4.36649163e+01, 4.87324015e+01,
       1.94068495e+01, 1.42876848e+02, 5.47748449e+01, 4.01435184e-01,
       1.64594628e+01, 1.35868221e+02, 3.39463722e-02, 2.62839723e-03,
       6.48267328e+01, 4.08919878e+01, 1.57210640e+01, 4.97339464e+01,
       2.08532090e+00, 1.64594628e+01, 3.81139838e+00, 4.13581137e+00,
       2.91363352e+01, 9.31036215e+00, 8.77079402e+00, 1.08197364e+02,
       3.80778075e+00, 1.94663518e+00, 2.10134010e-01, 5.54506612e-01,
       2.57101340e+01, 8.76219776e+00, 1.65456345e+01, 1.48946472e-02,
       3.61153953e+01, 1.64029695e+01, 2.90831082e+01, 9.40661512e+00,
       1.30384225e+01, 9.02532135e-01, 9.90471609e+01, 1.16692020e+01,
       9.36786651e+00, 1.08242978e+02, 1.64030139e+01, 8.94686338e+00,
       5.47639923e+01, 3.82673276e+00, 1.04849320e+02, 1.22900026e+00,
       5.48120592e+01, 3.68639596e+01, 5.91977020e+01, 2.55148946e+01,
       1.15299368e+01, 1.67910134e+02, 2.00982247e+00, 2.55402646e+01,
       4.10263449e+01, 9.08480122e-01, 5.99010818e+00, 3.66217290e+01,
       9.42106401e+00, 4.25417950e+00, 9.37079402e+00, 9.30848012e+00,
       1.22285112e+02, 4.24479963e+00, 1.22170794e+02, 5.74913061e+00,
       2.45285278e+01, 4.08708449e+01, 8.01136659e+01, 5.87103518e+00,
       5.89727961e+00, 1.64062615e+01, 4.35799713e+00, 5.78733638e+01,
       3.66402646e+01, 8.82331368e+01, 2.48490309e+01, 4.40586510e+01,
       1.56594628e+01, 1.16395965e+00, 4.01435184e-01, 9.03013872e-01,
       1.12528678e+02, 4.97025321e+01, 5.72987961e+00, 3.81366588e+00,
       6.54067439e+01, 6.48267328e+01, 4.83402646e+01, 1.38057825e+00,
       1.01185112e+02, 3.81366588e+00, 2.24108184e-01, 4.20626152e+00,
       9.48114887e+00, 8.73809590e+00, 2.75442247e+00, 3.81139838e+00,
       1.79447664e+02, 1.30116796e+01, 3.67670671e+01, 5.41794972e-02,
       1.08024508e+02, 8.01077807e+01, 1.08211416e+02, 4.97175476e+01])
In [35]:
np.sum(np.min(sqdists, axis=0))
Out[35]:
9240.152878462477

Let's define a function named calcJ to do this calculation.

In [36]:
def calcJ(data, centers):
    sqdists = np.sum((centers[:,np.newaxis,:] - data)**2, axis=2)
    return np.sum(np.min(sqdists, axis=0))
In [37]:
calcJ(data, centers)
Out[37]:
9240.152878462477

Now we can add this calculation to track the value of $J$ for each iteration as a kind of learning curve. $J$ measures the "spread" within the clusters (the sum of squared distances from each sample to its closest center), so the smaller it is, the better.

In [38]:
def kmeans(data, k = 2, n_iterations = 5):
    
    # Initialize centers and list J to track performance metric
    centers = data[np.random.choice(range(data.shape[0]), k, replace=False), :]
    J = []
    
    for iteration in range(n_iterations):
        
        # Which center is each sample closest to?
        sqdistances = np.sum((centers[:, np.newaxis, :] - data)**2, axis=2)
        closest = np.argmin(sqdistances, axis=0)
        
        # Calculate J and append to list J
        J.append(calcJ(data, centers))
        
        # Update cluster centers
        for i in range(k):
            centers[i, :] = data[closest == i,:].mean(axis=0)
            
    # Calculate J one final time and return results
    J.append(calcJ(data, centers))
    return centers, J, closest
In [39]:
centers, J, closest = kmeans(data, 2)
In [40]:
J
Out[40]:
[53997.652259999995,
 12341.837671828205,
 9134.49185083075,
 8904.34103114802,
 8901.76872094721,
 8901.76872094721]
In [41]:
plt.plot(J);
In [42]:
centers, J, closest = kmeans(data, 2, 10)
plt.plot(J);
In [43]:
centers
Out[43]:
array([[ 2.09433   , 54.75      ],
       [ 4.29793023, 80.28488372]])
In [44]:
closest
Out[44]:
array([1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0,
       1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0,
       1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1,
       1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,
       0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1,
       1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1,
       0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1,
       1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1,
       1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1,
       0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1,
       0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0,
       1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1,
       0, 0, 1, 1, 0, 1, 0, 1])
In [45]:
centers, J, closest = kmeans(data, 2, 2)

plt.figure(figsize=(15, 8))

plt.subplot(1, 2, 1)
plt.scatter(data[:, 0], data[:, 1], s=80, c=closest, alpha=0.5)
plt.scatter(centers[:, 0], centers[:, 1], s=80, c="green")

plt.subplot(1, 2, 2)
plt.plot(J)

centers
Out[45]:
array([[ 2.09393939, 54.62626263],
       [ 4.28541618, 80.20809249]])
In [46]:
centers, J, closest = kmeans(data, 2, 10)

plt.figure(figsize=(15, 8))

plt.subplot(1, 2, 1)
plt.scatter(data[:, 0], data[:, 1], s=80, c=closest, alpha=0.5)
plt.scatter(centers[:, 0], centers[:, 1], s=80, c="green")

plt.subplot(1, 2, 2)
plt.plot(J)

centers
Out[46]:
array([[ 2.09433   , 54.75      ],
       [ 4.29793023, 80.28488372]])
In [47]:
centers, J, closest = kmeans(data, 3, 10)

plt.figure(figsize=(15, 8))

plt.subplot(1, 2, 1)
plt.scatter(data[:, 0], data[:, 1], s=80, c=closest, alpha=0.5)
plt.scatter(centers[:, 0], centers[:, 1], s=80, c="green")

plt.subplot(1, 2, 2)
plt.plot(J)

centers
Out[47]:
array([[ 2.00784848, 51.25757576],
       [ 2.6673617 , 63.93617021],
       [ 4.34461006, 81.10691824]])
In [48]:
centers, J, closest = kmeans(data, 4, 10)

plt.figure(figsize=(15, 8))

plt.subplot(1, 2, 1)
plt.scatter(data[:, 0], data[:, 1], s=80, c=closest, alpha=0.5)
plt.scatter(centers[:, 0], centers[:, 1], s=80, c="green")

plt.subplot(1, 2, 2)
plt.plot(J)

centers
Out[48]:
array([[ 4.3690119 , 84.91666667],
       [ 4.2403908 , 75.95402299],
       [ 2.26965789, 61.34210526],
       [ 2.0082381 , 50.98412698]])
In [49]:
centers, J, closest = kmeans(data, 6, 20)

plt.figure(figsize=(15, 8))

plt.subplot(1, 2, 1)
plt.scatter(data[:, 0], data[:, 1], s=80, c=closest, alpha=0.5)
plt.scatter(centers[:, 0], centers[:, 1], s=80, c="green")

plt.subplot(1, 2, 2)
plt.plot(J)

centers
Out[49]:
array([[ 3.00344828, 66.48275862],
       [ 2.02545098, 55.82352941],
       [ 1.9745625 , 48.15625   ],
       [ 4.29970588, 76.39705882],
       [ 4.49343478, 89.91304348],
       [ 4.3386087 , 82.68115942]])
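
How should we pick $k$ when we cannot see the clusters? One common heuristic, sketched below (this code is not part of the original runs), is to run kmeans for several values of $k$ and compare the final values of $J$, looking for an "elbow" where adding more centers stops helping much. Keep in mind that this simple kmeans can produce a nan mean if a cluster ends up empty, and that the result depends on the random initial centers.

final_J = []
ks = range(1, 8)
for k in ks:
    # Run our kmeans for each k and record the final value of J.
    centers_k, J_k, closest_k = kmeans(data, k, n_iterations=20)
    final_J.append(J_k[-1])

plt.plot(ks, final_J, 'o-')
plt.xlabel('$k$')
plt.ylabel('final $J$');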

MNIST Dataset

So, clustering two-dimensional data is not all that exciting. How about 784-dimensional data, such as our good buddy the MNIST data set?

In [50]:
import gzip
import pickle

with gzip.open('mnist.pkl.gz', 'rb') as f:
    train_set, valid_set, test_set = pickle.load(f, encoding='latin1')

Xtrain = train_set[0]
Ttrain = train_set[1].reshape((-1,1))

Xtest = test_set[0]
Ttest = test_set[1].reshape((-1,1))

Xtrain.shape, Ttrain.shape, Xtest.shape, Ttest.shape
Out[50]:
((50000, 784), (50000, 1), (10000, 784), (10000, 1))

How many clusters shall we use?

In [51]:
centers, J, closest = kmeans(Xtrain, k=10, n_iterations=10)
In [52]:
plt.plot(J);
In [53]:
centers.shape
Out[53]:
(10, 784)
In [54]:
for i in range(10):
    plt.subplot(2, 5, i + 1)
    plt.imshow(-centers[i, :].reshape((28, 28)), cmap='gray')
    plt.axis('off')
In [55]:
centers, J, closest = kmeans(Xtrain, k=10, n_iterations=20)
plt.plot(J)
plt.figure()
for i in range(10):
    plt.subplot(2, 5, i + 1)
    plt.imshow(-centers[i, :].reshape((28, 28)), cmap='gray')
    plt.axis('off')
In [56]:
centers, J, closest = kmeans(Xtrain, k=20, n_iterations=20)
plt.plot(J)
plt.figure()
for i in range(20):
    plt.subplot(4, 5, i + 1)
    plt.imshow(-centers[i, :].reshape((28, 28)), interpolation='nearest', cmap='gray')
    plt.axis('off')
In [57]:
centers, J, closest = kmeans(Xtrain, k=20, n_iterations=20)
plt.plot(J)
plt.figure()
for i in range(20):
    plt.subplot(4, 5, i + 1)
    plt.imshow(-centers[i, :].reshape((28, 28)), interpolation='nearest', cmap='gray')
    plt.axis('off')
In [58]:
centers, J, closest = kmeans(Xtrain, k=40, n_iterations=20)
plt.plot(J)
plt.figure()
for i in range(40):
    plt.subplot(4, 10, i + 1)
    plt.imshow(-centers[i, :].reshape((28, 28)), interpolation='nearest', cmap='gray')
    plt.axis('off')

How could you use the results of the kmeans clustering algorithm as the first step in a classification algorithm?
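
One possible answer, sketched here with made-up names (cluster_labels and classify_by_center are not from the notebook), is to give each cluster the majority Ttrain label of the samples assigned to it, then classify a new sample by the label of its nearest center. Using the centers and closest values from the last kmeans run above:

# Majority training label for each cluster (assumes no cluster is empty).
cluster_labels = np.array([np.bincount(Ttrain[closest == i, 0]).argmax()
                           for i in range(centers.shape[0])])

def classify_by_center(Xnew):
    # Index of the nearest center for each new sample, then look up that cluster's label.
    nearest = np.argmin(np.sum((centers[:, np.newaxis, :] - Xnew)**2, axis=2), axis=0)
    return cluster_labels[nearest].reshape(-1, 1)

# For example: classify_by_center(Xtest[:10, :])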

K-Nearest-Neighbor Classification

Now that we have some experience in calculating distances between samples, we are a short step away from an implementation of a common classification algorithm called k-nearest-neighbor. This is a non-parametric algorithm, meaning that it does not learn parameters, such as weights, to make its decisions. Instead, we could call it a memory-based method. The algorithm classifies a sample by finding the $k$ closest samples in the training set and returning the most common class label among those $k$ nearest samples.

Training is terribly simple. We just have to store the training samples. Classification is also trivial to code. We just calculate squared distances between training samples and the samples being classified and return the most common class label among the $k$ closest training samples.

Let's create a class named KNN to implement this algorithm.

First, let's practice our numpy-foo to see how to pick the most common class, with a minimum amount of code.

Remember that sqdists from above is n_centers x n_samples.

Let's pretend we wish to classify the first three MNIST test samples.

In [59]:
sqdists = np.sum((Xtest[:3,np.newaxis,:] - Xtrain)**2, axis=2)
sqdists.shape
Out[59]:
(3, 50000)

Okay. Now all we have to do is find the $k$ smallest distances in each row. Let's use $k=5$.

In [60]:
k = 5
np.sort(sqdists[0, :])[:k]
Out[60]:
array([ 9.6193695, 11.355759 , 11.403915 , 12.214478 , 12.627594 ],
      dtype=float32)

But, we need the indices of these values so we can look up their class labels in Ttrain.

In [61]:
k = 5
np.argsort(sqdists[0, :])[:k]
Out[61]:
array([38620, 16186, 27059, 47003, 14563])

Now we have to do this for each row in sqdists. Or do we? Wouldn't it be nice if np.argsort sorted each row independently so we could do this in one function call?

In [62]:
np.sort(sqdists, axis=1)
Out[62]:
array([[  9.6193695,  11.355759 ,  11.403915 , ..., 211.29517  ,
        211.7034   , 235.06609  ],
       [ 20.636139 ,  22.408554 ,  25.232117 , ..., 215.19962  ,
        216.1485   , 221.52444  ],
       [  1.6865845,   1.7748108,   2.063202 , ..., 227.82173  ,
        245.48021  , 251.6908   ]], dtype=float32)

Yippee!

In [63]:
np.argsort(sqdists, axis=1)
Out[63]:
array([[38620, 16186, 27059, ..., 10259, 25321, 41358],
       [28882, 49160, 24612, ..., 43452, 10237, 13650],
       [46512, 15224, 47333, ..., 25321, 25285, 41358]])
In [64]:
indices = np.argsort(sqdists, axis=1)
indices
Out[64]:
array([[38620, 16186, 27059, ..., 10259, 25321, 41358],
       [28882, 49160, 24612, ..., 43452, 10237, 13650],
       [46512, 15224, 47333, ..., 25321, 25285, 41358]])
In [65]:
Ttrain[indices, :]
Out[65]:
array([[[7],
        [7],
        [7],
        ...,
        [0],
        [8],
        [0]],

       [[2],
        [2],
        [2],
        ...,
        [4],
        [0],
        [4]],

       [[1],
        [1],
        [1],
        ...,
        [8],
        [0],
        [0]]])
In [66]:
Ttrain[indices, :][:, :, 0]
Out[66]:
array([[7, 7, 7, ..., 0, 8, 0],
       [2, 2, 2, ..., 4, 0, 4],
       [1, 1, 1, ..., 8, 0, 0]])

Cool! Now we just have to take the first $k$ columns of these and determine the most common label across the columns, for each row. We can use scipy.stats.mode for this!

In [67]:
import scipy.stats as ss
ss.mode([1, 2, 3, 4, 2, 2, 2])
Out[67]:
ModeResult(mode=array([2]), count=array([4]))
In [68]:
ss.mode([1, 2, 3, 4, 2, 2, 2])[0]
Out[68]:
array([2])
In [70]:
ss.mode(Ttrain[indices, :][:, :, 0], axis=1)[0]
Out[70]:
array([[1],
       [1],
       [1]])
In [71]:
Ttest[:3]
Out[71]:
array([[7],
       [2],
       [1]])

Only the third one is correct. That is because we took the mode across all 50,000 sorted labels instead of just the first $k$ columns (evidently 1 is the most common digit among the training labels). The KNN class below slices out the $k$ nearest neighbors before taking the mode.

Finally, we can now define our KNN class.

In [72]:
import numpy as np
import scipy.stats as ss  # for ss.mode

class KNN():
    
    def __init__(self):
        
        self.X = None  # data will be stored here
        self.T = None  # class labels will be stored here
        self.Xmeans = None
        self.Xstds = None
    
    def train(self, X, T):
        
        if self.Xmeans is None:
            self.Xmeans = X.mean(axis=0)
            self.Xstds = X.std(axis=0)
            self.Xstds[self.Xstds == 0] = 1
            
        self.X = self._standardizeX(X)
        self.T = T
        
    def _standardizeX(self, X):
        return (X - self.Xmeans) / self.Xstds

    def use(self, Xnew, k = 1):
        self.k = k
        # Calculate squared distances from all samples in Xnew to all samples stored in self.X
        sqdists = np.sum( (self._standardizeX(Xnew)[:,np.newaxis,:] - self.X)**2, axis=-1 )

        # sqdists is now n_new_samples x n_train_samples
        # Sort each row of squared distances from smallest to largest and select the first k.
        indices = np.argsort(sqdists, axis=1)[:, :k]

        # Determine the most common class label in each row.
        classes = ss.mode(self.T[indices,:][:,:,0], axis=1)[0]
        
        return classes
In [73]:
knn = KNN()
knn
Out[73]:
<__main__.KNN at 0x7fed7e246910>

Oh, can't have that!!

In [74]:
import numpy as np
import scipy.stats as ss  # for ss.mode

class KNN():
    
    def __init__(self):
        
        self.X = None  # data will be stored here
        self.T = None  # class labels will be stored here
        self.Xmeans = None
        self.Xstds = None
        
    def __repr__(self):
        if self.X is None:
            return 'KNN() has not been trained.'
        else:
            return f'KNN(), trained with {self.X.shape[0]} samples having class labels {np.unique(self.T)}.'
    
    def train(self, X, T):
        
        if self.Xmeans is None:
            self.Xmeans = X.mean(axis=0)
            self.Xstds = X.std(axis=0)
            self.Xstds[self.Xstds == 0] = 1
            
        self.X = self._standardizeX(X)
        self.T = T
        
        return self
        
    def _standardizeX(self, X):
        return (X - self.Xmeans) / self.Xstds

    def use(self, Xnew, k = 1):
        
        if self.X is None:
            raise Exception('KNN object has not been trained yet.')
            
        self.k = k
        # Calculate squared distances from all samples in Xnew to all samples stored in self.X
        sqdists = np.sum( (self._standardizeX(Xnew)[:,np.newaxis,:] - self.X)**2, axis=-1 )

        # sqdists is now n_new_samples x n_train_samples
        # Sort each row of squared distances from smallest to largest and select the first k.
        indices = np.argsort(sqdists, axis=1)[:, :k]

        # Determine the most common class label in each row.
        classes = ss.mode(self.T[indices,:][:,:,0], axis=1)[0]
        
        return classes
In [75]:
knn = KNN()
knn
Out[75]:
KNN() has not been trained.
In [76]:
knn.train(Xtrain, Ttrain)
Out[76]:
KNN(), trained with 50000 samples having class labels [0 1 2 3 4 5 6 7 8 9].

Boy, that took a long time to train! :) 200 ms.

Let's test it. First, use the default value for $k$ of 1.

In [77]:
knn.use(Xtest[:3, :])
Out[77]:
array([[7],
       [2],
       [1]])
In [78]:
Ttest[:3]
Out[78]:
array([[7],
       [2],
       [1]])

Well, that worked perfectly. Let's try more test samples.

In [79]:
knn.use(Xtest[:10, :])
Out[79]:
array([[7],
       [2],
       [1],
       [0],
       [4],
       [1],
       [4],
       [4],
       [4],
       [9]])
In [80]:
Ttest[:10]
Out[80]:
array([[7],
       [2],
       [1],
       [0],
       [4],
       [1],
       [4],
       [9],
       [5],
       [9]])

There are some mistakes. How about using more neighbors?

In [81]:
knn.use(Xtest[:10, :], k=5)
Out[81]:
array([[7],
       [2],
       [1],
       [0],
       [4],
       [1],
       [4],
       [9],
       [4],
       [9]])
In [82]:
Ttest[:10]
Out[82]:
array([[7],
       [2],
       [1],
       [0],
       [4],
       [1],
       [4],
       [9],
       [5],
       [9]])
In [83]:
def percent_correct(Predicted, T):
    return 100 * np.mean(Predicted == T)
In [84]:
percent_correct(knn.use(Xtest[:10, :], k=5), Ttest[:10])
Out[84]:
90.0

Now we can try multiple values of $k$ with a for loop, and test all test samples.

In [85]:
results = []
for k in range(1, 21):
    pc = percent_correct(knn.use(Xtest, k=k), Ttest)
    results.append([k, pc])
---------------------------------------------------------------------------
MemoryError                               Traceback (most recent call last)
<ipython-input-85-7f8af0cc37ab> in <module>
      1 results = []
      2 for k in range(1, 21):
----> 3     pc = percent_correct(knn.use(Xtest, k=k), Ttest)
      4     results.append([k, pc])

<ipython-input-74-e861359179fd> in use(self, Xnew, k)
     39         self.k = k
     40         # Calc squared distance from all samples in Xnew with all stored in self.X
---> 41         sqdists = np.sum( (self._standardizeX(Xnew)[:,np.newaxis,:] - self.X)**2, axis=-1 )
     42         # sqdist
     43 

MemoryError: Unable to allocate 1.43 TiB for an array with shape (10000, 50000, 784) and data type float32

Well, here is what we often face when dealing with big data sets. K-nearest-neighbors calculates squared distances between each train and test sample. That can get huge.

We can deal with this in the typical way, by working with batches of data.
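
As a sanity check on that error message: the broadcasted difference array has shape (10000, 50000, 784), and at 4 bytes per float32 element that is $10000 \times 50000 \times 784 \times 4 \approx 1.57 \times 10^{12}$ bytes, or about 1.43 TiB. The batched version below also stores only the first 10,000 training samples, so each temporary array has shape (200, 10000, 784), roughly 6 GB.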

In [86]:
knn = KNN()
n_train = 10000
knn.train(Xtrain[:n_train, :], Ttrain[:n_train, :])

batch_size = 200
n_samples = Xtest.shape[0]
batches = [batch_size] * (n_samples // batch_size)
if sum(batches) < n_samples:
    batches.append(n_samples - sum(batches))

results = []
for k in range(1, 21):
    print(f'k={k}', end=' ')
    n_correct = 0
    first = 0
    for this_batch in batches:
        print(f'{first}', end=',')
        last = first + this_batch
        X = Xtest[first:last, :]
        T = Ttest[first:last, :]
        n_correct += np.sum(knn.use(X, k=k) == T)
        first += this_batch
    pc = n_correct / n_samples * 100
    results.append([k, pc])
    print()
k=1 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=2 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=3 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=4 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=5 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=6 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=7 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=8 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=9 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=10 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=11 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=12 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=13 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=14 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=15 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=16 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=17 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=18 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=19 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
k=20 0,200,400,600,800,1000,1200,1400,1600,1800,2000,2200,2400,2600,2800,3000,3200,3400,3600,3800,4000,4200,4400,4600,4800,5000,5200,5400,5600,5800,6000,6200,6400,6600,6800,7000,7200,7400,7600,7800,8000,8200,8400,8600,8800,9000,9200,9400,9600,9800,
In [92]:
results
Out[92]:
[[1, 91.07],
 [2, 89.59],
 [3, 91.17],
 [4, 90.95],
 [5, 91.03],
 [6, 90.86999999999999],
 [7, 90.73],
 [8, 90.66],
 [9, 90.63],
 [10, 90.48],
 [11, 90.33],
 [12, 90.32],
 [13, 90.27],
 [14, 90.16],
 [15, 90.21000000000001],
 [16, 90.09],
 [17, 89.87],
 [18, 89.71000000000001],
 [19, 89.63],
 [20, 89.55]]
In [96]:
results = np.array(results)

plt.plot(results[:, 0], results[:, 1])
plt.xlabel('$k$')
plt.ylabel('Percent Correct Test Data');

How might you change the implementation of KNN to speed up this calculation using multiple $k$ values?

How could you calculate class probabilities with KNN?
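
As a hint for both questions, here is one possible direction, sketched only roughly (use_all_k is a made-up name, and the sketch assumes an already-trained knn and the scipy.stats.mode behavior shown above). The squared distances and the argsort do not depend on $k$, so they can be computed once and the first $k$ columns sliced out for each $k$; the fraction of the $k$ nearest labels belonging to each class gives a simple estimate of the class probabilities.

def use_all_k(knn, Xnew, k_values, n_classes=10):
    # Compute squared distances and sort once; for large Xnew this would
    # still need to be done in batches, as above.
    sqdists = np.sum((knn._standardizeX(Xnew)[:, np.newaxis, :] - knn.X)**2, axis=-1)
    order = np.argsort(sqdists, axis=1)
    predictions, probabilities = {}, {}
    for k in k_values:
        neighbor_labels = knn.T[order[:, :k], 0]   # labels of the k nearest training samples
        predictions[k] = ss.mode(neighbor_labels, axis=1)[0]
        # Estimated class probabilities: fraction of the k neighbors in each class.
        probabilities[k] = np.stack([np.mean(neighbor_labels == c, axis=1)
                                     for c in range(n_classes)], axis=1)
    return predictions, probabilities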

Comparison to Neural Network Classifier

In [87]:
n = 20
X = np.random.multivariate_normal([5, 7], [[0.8, -0.5], [-0.5, 0.8]], n)
X = np.vstack((X,
               np.random.multivariate_normal([6, 3], [[0.6, 0.5], [0.5, 0.8]], n)))
T = np.vstack((np.ones((n, 1)), 2 * np.ones((n, 1))))

plt.scatter(X[:, 0], X[:, 1], c=T, s=80);
In [88]:
plt.figure(figsize=(10, 10))

n = 20
X = np.random.multivariate_normal([5, 7], [[0.8, -0.5], [-0.5, 0.8]], n)
X = np.vstack((X,
               np.random.multivariate_normal([6, 3], [[0.6, 0.5], [0.5, 0.8]], n)))
T = np.vstack((np.ones((n, 1)), 2 * np.ones((n, 1))))

# Make samples as coordinates of grid points across 2-dimensional data space
m = 100
xs = np.linspace(0, 10, m)
ys = xs
Xs, Ys = np.meshgrid(xs, ys)
samples = np.vstack((Xs.ravel(), Ys.ravel())).T

knn = KNN()
knn.train(X, T)

classes = knn.use(samples, k=1)
              
plt.contourf(Xs, Ys, classes.reshape(Xs.shape), 1, colors=('blue','red'), alpha=0.2)
plt.scatter(X[:, 0], X[:, 1], s=60, c=T);
In [89]:
def plot_result(X, T, Xs, Ys, classes):
    # Pass T explicitly rather than relying on the global variable.
    plt.contourf(Xs, Ys, classes.reshape(Xs.shape), 1, colors=('blue','red'), alpha=0.2)
    plt.scatter(X[:, 0], X[:, 1], s=60, c=T);
In [90]:
import A4mysolution as nn
In [91]:
n = 40
X = np.random.multivariate_normal([5, 6], [[0.9, -0.2], [-0.2, 0.9]], n)
X = np.vstack((X,
               np.random.multivariate_normal([6, 3], [[2, 0.4], [0.4, 2]], n)))
T = np.vstack((np.ones((n, 1)), 2 * np.ones((n, 1))))


m = 100
xs = np.linspace(1, 9, m)           
ys = xs
Xs,Ys = np.meshgrid(xs, ys)
samples = np.vstack((Xs.ravel(), Ys.ravel())).T

plt.figure(figsize=(20, 30))

knn = KNN()
knn.train(X, T)
    
ploti = 0
for k in [1, 2, 3, 5, 10, 20]:
    ploti += 1
    plt.subplot(4, 3, ploti)

    classes = knn.use(samples, k)
    plot_result(X, T, Xs, Ys, classes)
    plt.title(f'KNN k={k}')

for n_hiddens in [[], [1], [2], [10], [10, 10], [5, 5, 5, 5]]:
    ploti += 1
    plt.subplot(4, 3, ploti)

    nnet = nn.NeuralNetworkClassifier(2, n_hiddens, 2)
    nnet.train(X, T, n_epochs=500, verbose=False)
    classes, _ = nnet.use(samples)
    plot_result(X, T, Xs, Ys, classes)
    plt.title(f'nnet {n_hiddens}')