K-Means Clustering

Let's check out the Stone Flakes data set. Find and download the StoneFlakes.dat file from the UCI Machine Learning Repository, or fetch it by running wget as shown below.

In [1]:
!wget http://archive.ics.uci.edu/ml/machine-learning-databases/00299/StoneFlakes.dat
--2019-04-25 14:59:09--  http://archive.ics.uci.edu/ml/machine-learning-databases/00299/StoneFlakes.dat
Resolving archive.ics.uci.edu (archive.ics.uci.edu)... 128.195.10.252
Connecting to archive.ics.uci.edu (archive.ics.uci.edu)|128.195.10.252|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3523 (3.4K) [application/x-httpd-php]
Saving to: ‘StoneFlakes.dat’

StoneFlakes.dat     100%[===================>]   3.44K  --.-KB/s    in 0s      

2019-04-25 14:59:09 (180 MB/s) - ‘StoneFlakes.dat’ saved [3523/3523]

Let's look at the first few lines.

In [2]:
!head StoneFlakes.dat

Read about the column names and the meaning of the ID values at the data set's web site.

Notice that the values are separated by spaces, except for the first column, which is followed by a comma. Also notice that there are question marks where data is missing. How can we read this? Well, the usual answer is to "google" for the answer. Try searching for "read data set numpy".

Here we will use the pandas package to deal with these issues.

In [3]:
import pandas
In [4]:
d = pandas.read_csv(open('StoneFlakes.dat'))
In [5]:
d[:5]
Out[5]:
ID LBI RTI WDI FLA PSF FSF ZDF1 PROZD
0 ar ? 35.3 2.60 ? 42.4 24.2 47.1 69
1 arn 1.23 27.0 3.59 122 0.0 40.0 40.0 30
2 be 1.24 26.5 2.90 121 16.0 20.7 29.7 72
3 bi1 1.07 29.1 3.10 114 44.0 2.6 26.3 68
4 bi2 1.08 43.7 2.40 105 32.6 5.8 10.7 42
In [6]:
d = pandas.read_csv(open('StoneFlakes.dat'), sep=',')
d[:5]
Out[6]:
ID LBI RTI WDI FLA PSF FSF ZDF1 PROZD
0 ar ? 35.3 2.60 ? 42.4 24.2 47.1 69
1 arn 1.23 27.0 3.59 122 0.0 40.0 40.0 30
2 be 1.24 26.5 2.90 121 16.0 20.7 29.7 72
3 bi1 1.07 29.1 3.10 114 44.0 2.6 26.3 68
4 bi2 1.08 43.7 2.40 105 32.6 5.8 10.7 42

That didn't work; everything after the ID value ended up in a single column, because those values are separated by spaces, not commas. Let's just replace the spaces with commas, using unix. Read about the tr unix command at Linux TR Command Examples.

In [7]:
! tr -s ' ' ',' < StoneFlakes.dat > StoneFlakes2.dat
! head StoneFlakes2.dat

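If you'd rather stay in Python, the same substitution can be done with the re module; a minimal sketch equivalent to the tr command above:

import re

with open('StoneFlakes.dat') as f:
    text = f.read()
with open('StoneFlakes2.dat', 'w') as f:
    f.write(re.sub(' +', ',', text))  # each run of spaces becomes one comma
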
In [8]:
d = pandas.read_csv(open('StoneFlakes2.dat'))
d[:5]
Out[8]:
ID LBI RTI WDI FLA PSF FSF ZDF1 PROZD
0 ar ? 35.3 2.60 ? 42.4 24.2 47.1 69
1 arn 1.23 27.0 3.59 122 0.0 40.0 40.0 30
2 be 1.24 26.5 2.90 121 16.0 20.7 29.7 72
3 bi1 1.07 29.1 3.10 114 44.0 2.6 26.3 68
4 bi2 1.08 43.7 2.40 105 32.6 5.8 10.7 42
In [9]:
d = pandas.read_csv(open('StoneFlakes2.dat'), na_values='?')
d[:5]
Out[9]:
ID LBI RTI WDI FLA PSF FSF ZDF1 PROZD
0 ar NaN 35.3 2.60 NaN 42.4 24.2 47.1 69
1 arn 1.23 27.0 3.59 122.0 0.0 40.0 40.0 30
2 be 1.24 26.5 2.90 121.0 16.0 20.7 29.7 72
3 bi1 1.07 29.1 3.10 114.0 44.0 2.6 26.3 68
4 bi2 1.08 43.7 2.40 105.0 32.6 5.8 10.7 42
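
In fact, pandas can deal with the mixed separators and the question marks in a single call, making the tr step unnecessary; a minimal sketch (a regular-expression separator requires the slower engine='python'):

d = pandas.read_csv('StoneFlakes.dat', sep='[ ,]+', engine='python', na_values='?')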
In [10]:
d = pandas.read_csv(open('StoneFlakes2.dat'), na_values='?', error_bad_lines=False)
d[:5]
Out[10]:
ID LBI RTI WDI FLA PSF FSF ZDF1 PROZD
0 ar NaN 35.3 2.60 NaN 42.4 24.2 47.1 69
1 arn 1.23 27.0 3.59 122.0 0.0 40.0 40.0 30
2 be 1.24 26.5 2.90 121.0 16.0 20.7 29.7 72
3 bi1 1.07 29.1 3.10 114.0 44.0 2.6 26.3 68
4 bi2 1.08 43.7 2.40 105.0 32.6 5.8 10.7 42
In [11]:
d[:5].isnull()
Out[11]:
ID LBI RTI WDI FLA PSF FSF ZDF1 PROZD
0 False True False False True False False False False
1 False False False False False False False False False
2 False False False False False False False False False
3 False False False False False False False False False
4 False False False False False False False False False
In [12]:
d[:5].isnull().any(axis=1)
Out[12]:
0     True
1    False
2    False
3    False
4    False
dtype: bool
In [13]:
d[:5].isnull().any(axis=1) == False
Out[13]:
0    False
1     True
2     True
3     True
4     True
dtype: bool
In [14]:
print(d.shape)
d = d[d.isnull().any(axis=1)==False]
print(d.shape)
(79, 9)
(73, 9)
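
As an aside, pandas provides a shortcut for exactly this operation, and the ~ operator is the idiomatic way to negate a boolean Series; both of the following should produce the same frame as the cell above:

d.dropna()                   # drop every row that contains a NaN
d[~d.isnull().any(axis=1)]   # the same thing, with an explicit mask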
In [15]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
In [16]:
data = d.iloc[:,1:].values
data.shape
Out[16]:
(73, 8)
In [17]:
data[:5,:]
Out[17]:
array([[  1.23,  27.  ,   3.59, 122.  ,   0.  ,  40.  ,  40.  ,  30.  ],
       [  1.24,  26.5 ,   2.9 , 121.  ,  16.  ,  20.7 ,  29.7 ,  72.  ],
       [  1.07,  29.1 ,   3.1 , 114.  ,  44.  ,   2.6 ,  26.3 ,  68.  ],
       [  1.08,  43.7 ,   2.4 , 105.  ,  32.6 ,   5.8 ,  10.7 ,  42.  ],
       [  1.39,  29.5 ,   2.78, 126.  ,  14.  ,   0.  ,  50.  ,  78.  ]])

To see this data, let's try plotting each column as a separate curve on the same axes.

In [18]:
plt.plot(data);

Each sample has 8 attributes, so each sample is a point in 8-dimensional space. I wonder how well the samples "clump" in those 8 dimensions. Let's try clustering them with the k-means algorithm.

First, let's try to find two clusters, so $k=2$. We must initialize the means of the two clusters. Let's just pick two samples at random.

In [19]:
np.random.choice(range(data.shape[0]),2, replace=False) # data.shape[0] is number of rows, or samples
Out[19]:
array([18, 36])
In [20]:
np.random.choice(range(data.shape[0]),2, replace=False)
Out[20]:
array([70, 49])
In [21]:
centersIndex = np.random.choice(range(data.shape[0]),2, replace=False)
centersIndex
Out[21]:
array([31,  2])
In [22]:
centers = data[centersIndex,:]
centers
Out[22]:
array([[  1.21,  23.3 ,   3.9 , 110.  ,   1.6 ,  35.5 ,  94.1 ,  95.  ],
       [  1.07,  29.1 ,   3.1 , 114.  ,  44.  ,   2.6 ,  26.3 ,  68.  ]])

Now we must find all of the samples that are closest to the first center, and all of those that are closest to the second center. Doing this efficiently calls for a short detour into numpy arithmetic and broadcasting.

In [23]:
a = np.array([1,2,3])
b = np.array([10,20,30])
a, b
Out[23]:
(array([1, 2, 3]), array([10, 20, 30]))
In [24]:
a-b
Out[24]:
array([ -9, -18, -27])

But what if we want to subtract every element of b from every element of a, forming all of the pairwise differences?

In [25]:
np.resize(a,(3,3))
Out[25]:
array([[1, 2, 3],
       [1, 2, 3],
       [1, 2, 3]])
In [26]:
np.resize(b, (3,3))
Out[26]:
array([[10, 20, 30],
       [10, 20, 30],
       [10, 20, 30]])
In [27]:
np.resize(a,(3,3)).T
Out[27]:
array([[1, 1, 1],
       [2, 2, 2],
       [3, 3, 3]])
In [28]:
np.resize(a,(3,3)).T - np.resize(b,(3,3))
Out[28]:
array([[ -9, -19, -29],
       [ -8, -18, -28],
       [ -7, -17, -27]])

However, we can ask numpy to do this duplication for us if we reshape a to be a column vector and leave b as a row vector.

$$ \begin{pmatrix} 1\\ 2\\ 3 \end{pmatrix} - \begin{pmatrix} 10 & 20 & 30 \end{pmatrix} \;\; = \;\; \begin{pmatrix} 1 & 1 & 1\\ 2 & 2 & 2\\ 3 & 3 & 3 \end{pmatrix} - \begin{pmatrix} 10 & 20 & 30\\ 10 & 20 & 30\\ 10 & 20 & 30 \end{pmatrix} $$
In [29]:
a = a.reshape((-1,1))
a
Out[29]:
array([[1],
       [2],
       [3]])
In [30]:
a - b
Out[30]:
array([[ -9, -19, -29],
       [ -8, -18, -28],
       [ -7, -17, -27]])
In [31]:
a = np.array([1,2,3])
b = np.array([[10,20,30],[40,50,60]])
print(a)
print(b)
[1 2 3]
[[10 20 30]
 [40 50 60]]
In [32]:
b-a
Out[32]:
array([[ 9, 18, 27],
       [39, 48, 57]])

The single row vector a is duplicated for as many rows as there are in b! We can use this to calculate the squared distance between a center and every sample.

In [33]:
centers[0,:]
Out[33]:
array([  1.21,  23.3 ,   3.9 , 110.  ,   1.6 ,  35.5 ,  94.1 ,  95.  ])
In [34]:
np.sum((centers[0,:] - data)**2, axis=1)
Out[34]:
array([ 7332.4065,  5235.0009,  8256.3096, 12051.0769,  3943.5468,
        5270.79  ,  4587.1536,  8387.04  ,  8514.841 ,  9991.3525,
        6339.5253,  1478.7909,  5474.2501,  4741.9944, 15881.0021,
        1562.1049,  1159.3125,  2328.4601,  1469.5104,  2961.5116,
        1525.7004,   183.24  ,  1158.7321,  2103.8749, 10376.8525,
        7679.526 ,  5631.84  ,  6094.28  ,  6037.1525,   723.8609,
         163.45  ,     0.    ,   624.9065,  4490.7444,  8351.9097,
        8590.9589,  6329.6436, 12702.2969,  6968.2309,  6479.5601,
        2349.2509,  1289.6581,   810.1481,   700.5381,  5537.8345,
        7283.9726,   628.5424,   428.3436,   378.7449,   469.2709,
        1373.7622,  1941.9369,  1618.1716, 13650.1561,   665.4716,
         676.2249,  1289.6901,  1790.7189,  2833.7069,   551.1824,
        8330.0989, 11030.5916,  4378.158 ,  2769.4949,  3300.6525,
       11240.3225,  3018.4445,  6242.8469,  7228.7389, 11270.9338,
        5532.2509,  7979.2461,  5996.028 ])
In [35]:
np.sum((centers[1,:] - data)**2, axis=1) > np.sum((centers[0,:] - data)**2, axis=1)
Out[35]:
array([False, False, False, False, False, False, False, False, False,
       False, False,  True, False, False, False,  True,  True,  True,
        True,  True,  True,  True,  True,  True, False, False, False,
       False, False,  True,  True,  True,  True, False, False, False,
       False, False, False, False, False,  True,  True,  True, False,
       False,  True,  True,  True,  True,  True,  True,  True, False,
        True,  True,  True,  True, False,  True, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False])
In [36]:
centers
Out[36]:
array([[  1.21,  23.3 ,   3.9 , 110.  ,   1.6 ,  35.5 ,  94.1 ,  95.  ],
       [  1.07,  29.1 ,   3.1 , 114.  ,  44.  ,   2.6 ,  26.3 ,  68.  ]])
In [37]:
centers[:,np.newaxis,:].shape, data.shape
Out[37]:
((2, 1, 8), (73, 8))
In [38]:
(centers[:,np.newaxis,:] - data).shape
Out[38]:
(2, 73, 8)
In [39]:
np.sum((centers[:,np.newaxis,:] - data)**2, axis=2).shape
Out[39]:
(2, 73)
In [40]:
np.argmin(np.sum((centers[:,np.newaxis,:] - data)**2, axis=2), axis=0)
Out[40]:
array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
       1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1])
In [41]:
cluster = np.argmin(np.sum((centers[:,np.newaxis,:] - data)**2, axis=2), axis=0)
cluster
Out[41]:
array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
       1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1])
In [42]:
data[cluster==0,:].mean(axis=0)
Out[42]:
array([  1.22241379,  24.63758621,   3.40344828, 114.17241379,
         5.17241379,  27.41724138,  68.14482759,  88.        ])
In [43]:
data[cluster==1,:].mean(axis=0)
Out[43]:
array([  1.19818182,  30.32954545,   2.87340909, 119.52272727,
        23.14772727,  11.83636364,  27.325     ,  65.86363636])
In [44]:
k = 2
for i in range(k):
    centers[i,:] = data[cluster==i,:].mean(axis=0)
In [45]:
centers
Out[45]:
array([[  1.22241379,  24.63758621,   3.40344828, 114.17241379,
          5.17241379,  27.41724138,  68.14482759,  88.        ],
       [  1.19818182,  30.32954545,   2.87340909, 119.52272727,
         23.14772727,  11.83636364,  27.325     ,  65.86363636]])
In [46]:
def kmeans(data, k = 2, n = 5):
    # Initial centers
    centers = data[np.random.choice(range(data.shape[0]), k, replace=False), :]
    # Repeat n times
    for iteration in range(n):
        # Which center is each sample closest to?
        closest = np.argmin(np.sum((centers[:, np.newaxis, :] - data)**2, axis=2), axis=0)
        # Update cluster centers
        for i in range(k):
            centers[i, :] = data[closest==i, :].mean(axis=0)
    return centers
In [47]:
kmeans(data,2)
Out[47]:
array([[  1.23194444,  25.36083333,   3.36833333, 114.47222222,
          7.20833333,  23.78333333,  64.91944444,  86.58333333],
       [  1.18432432,  30.7027027 ,   2.8072973 , 120.24324324,
         24.56756757,  12.42432432,  22.74054054,  63.05405405]])
In [48]:
kmeans(data,2)
Out[48]:
array([[  1.18432432,  30.7027027 ,   2.8072973 , 120.24324324,
         24.56756757,  12.42432432,  22.74054054,  63.05405405],
       [  1.23194444,  25.36083333,   3.36833333, 114.47222222,
          7.20833333,  23.78333333,  64.91944444,  86.58333333]])

Notice that the two runs above found the same two centers, just in a different order; the numbering of the clusters is arbitrary. Let's define $J$ as the performance measure being minimized by k-means. $$ J = \sum_{n=1}^N \sum_{k=1}^K r_{nk} ||\mathbf{x}_n - \boldsymbol{\mu}_k||^2 $$ where $N$ is the number of samples, $K$ is the number of cluster centers, $r_{nk}$ is 1 if sample $n$ is assigned to cluster $k$ and 0 otherwise, $\mathbf{x}_n$ is the $n^{th}$ sample, and $\boldsymbol{\mu}_k$ is the $k^{th}$ center, each being an element of $\mathbf{R}^p$ where $p$ is the dimensionality of the data.
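
Spelled out with explicit Python loops, $J$ looks like the following sketch, in which $r_{nk}$ is handled implicitly by keeping only the distance to the nearest center:

def calcJ_loops(data, centers):
    J = 0.0
    for x in data:                      # sum over samples n
        J += min(np.sum((x - mu)**2)    # ||x_n - mu_k||^2, nearest center only
                 for mu in centers)
    return J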

Sums like these can be computed using Python for loops, as in the sketch above, but for loops are much slower than numpy's matrix operations, as the following cells show.

In [49]:
a = np.linspace(0,10,30).reshape(3,10)
a
Out[49]:
array([[ 0.        ,  0.34482759,  0.68965517,  1.03448276,  1.37931034,
         1.72413793,  2.06896552,  2.4137931 ,  2.75862069,  3.10344828],
       [ 3.44827586,  3.79310345,  4.13793103,  4.48275862,  4.82758621,
         5.17241379,  5.51724138,  5.86206897,  6.20689655,  6.55172414],
       [ 6.89655172,  7.24137931,  7.5862069 ,  7.93103448,  8.27586207,
         8.62068966,  8.96551724,  9.31034483,  9.65517241, 10.        ]])
In [50]:
b = np.arange(30).reshape(3,10)
b
Out[50]:
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]])
In [51]:
result = np.zeros((3,10))
for i in range(3):
    for j in range(10):
        result[i,j] = a[i,j] + b[i,j]
result
Out[51]:
array([[ 0.        ,  1.34482759,  2.68965517,  4.03448276,  5.37931034,
         6.72413793,  8.06896552,  9.4137931 , 10.75862069, 12.10344828],
       [13.44827586, 14.79310345, 16.13793103, 17.48275862, 18.82758621,
        20.17241379, 21.51724138, 22.86206897, 24.20689655, 25.55172414],
       [26.89655172, 28.24137931, 29.5862069 , 30.93103448, 32.27586207,
        33.62068966, 34.96551724, 36.31034483, 37.65517241, 39.        ]])
In [52]:
a.shape, a[:,np.newaxis].shape
Out[52]:
((3, 10), (3, 1, 10))
In [53]:
%%timeit
result = np.zeros((3,10))
for i in range(3):
    for j in range(10):
        result[i,j] = a[i,j] + b[i,j]
17.8 µs ± 234 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [54]:
result = a + b
result
Out[54]:
array([[ 0.        ,  1.34482759,  2.68965517,  4.03448276,  5.37931034,
         6.72413793,  8.06896552,  9.4137931 , 10.75862069, 12.10344828],
       [13.44827586, 14.79310345, 16.13793103, 17.48275862, 18.82758621,
        20.17241379, 21.51724138, 22.86206897, 24.20689655, 25.55172414],
       [26.89655172, 28.24137931, 29.5862069 , 30.93103448, 32.27586207,
        33.62068966, 34.96551724, 36.31034483, 37.65517241, 39.        ]])
In [55]:
%%timeit
result = a + b
1.09 µs ± 16.5 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

So, the matrix form is roughly 16 times faster!
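
And the arrays here are tiny. The Python loop pays interpreter overhead on every element, so the gap widens as the arrays grow; a sketch you could time yourself (big_a, big_b, and loop_add are made-up names for illustration, and timings vary by machine):

big_a = np.random.rand(300, 1000)
big_b = np.random.rand(300, 1000)

def loop_add(a, b):
    # The same elementwise addition as above, with explicit loops.
    result = np.zeros(a.shape)
    for i in range(a.shape[0]):
        for j in range(a.shape[1]):
            result[i, j] = a[i, j] + b[i, j]
    return result

# %timeit loop_add(big_a, big_b)   # per-element Python overhead
# %timeit big_a + big_b            # one compiled pass over the data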

Now, back to our problem. How can we use matrix operations to calculate the squared distance between two centers and, say, five data samples? Let's say both are two-dimensional.

In [56]:
centers = np.array([[1,2],[5,4]])
centers
Out[56]:
array([[1, 2],
       [5, 4]])
In [57]:
data = np.array([[3,2],[4,6],[7,3],[4,6],[1,8]])
data
Out[57]:
array([[3, 2],
       [4, 6],
       [7, 3],
       [4, 6],
       [1, 8]])

This will look a little strange at first, but by adding an empty dimension (a new axis) to the centers array, numpy broadcasting does all the work for us.

In [58]:
centers[:,np.newaxis,:]
Out[58]:
array([[[1, 2]],

       [[5, 4]]])
In [59]:
centers[:,np.newaxis,:].shape
Out[59]:
(2, 1, 2)
In [60]:
data.shape
Out[60]:
(5, 2)
In [61]:
diffsq = (centers[:,np.newaxis,:] - data)**2
diffsq
Out[61]:
array([[[ 4,  0],
        [ 9, 16],
        [36,  1],
        [ 9, 16],
        [ 0, 36]],

       [[ 4,  4],
        [ 1,  4],
        [ 4,  1],
        [ 1,  4],
        [16, 16]]])
In [62]:
diffsq.shape
Out[62]:
(2, 5, 2)
In [63]:
np.sum(diffsq,axis=2)
Out[63]:
array([[ 4, 25, 37, 25, 36],
       [ 8,  5,  5,  5, 32]])

Now we have a 2 x 5 array with the first row containing the squared distance from the first center to each of the five data samples, and the second row containing the squared distances from the second center to each of the five data samples.

Now we just have to find the smallest distance in each column and sum them up.

In [64]:
np.min(np.sum(diffsq,axis=2), axis=0)
Out[64]:
array([ 4,  5,  5,  5, 32])
In [65]:
np.sum(np.min(np.sum(diffsq,axis=2), axis=0))
Out[65]:
51

Let's define a function named calcJ to do this calculation.

In [66]:
def calcJ(data,centers):
    # Squared differences between every center and every sample: (k, n, p)
    diffsq = (centers[:,np.newaxis,:] - data)**2
    # Sum over attributes, take the min over centers, then sum over samples.
    return np.sum(np.min(np.sum(diffsq,axis=2), axis=0))
In [67]:
calcJ(data,centers)
Out[67]:
51
In [68]:
def kmeans(data, k = 2, n = 5):
    # Initialize centers and list J to track performance metric
    centers = data[np.random.choice(range(data.shape[0]),k,replace=False), :]
    J = []
    
    # Repeat n times
    for iteration in range(n):
        
        # Which center is each sample closest to?
        sqdistances = np.sum((centers[:,np.newaxis,:] - data)**2, axis=2)
        closest = np.argmin(sqdistances, axis=0)
        
        # Calculate J and append to list J
        J.append(calcJ(data,centers))
        
        # Update cluster centers
        for i in range(k):
            centers[i,:] = data[closest==i,:].mean(axis=0)
            
    # Calculate J one final time and return results
    J.append(calcJ(data,centers))
    return centers, J, closest
In [69]:
centers,J,closest = kmeans(data,2)
In [70]:
J
Out[70]:
[30, 19, 19, 19, 19, 19]
In [71]:
plt.plot(J);
In [73]:
centers,J,closest = kmeans(data, 2, 10)
plt.plot(J);
In [75]:
centers,J,closest = kmeans(data, 3, 10)
plt.plot(J);
/s/parsons/e/fac/anderson/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:18: RuntimeWarning: Mean of empty slice.

This warning tells us that with $k=3$, one of the clusters ended up with no samples closest to its center, so the mean over that empty cluster is undefined and the center becomes NaN.
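One common fix, sketched below using the calcJ function defined above, is to move a center only when at least one sample is assigned to it; re-seeding an empty cluster at a random sample is another reasonable choice.

def kmeans_guarded(data, k=2, n=5):
    # Same as kmeans above, but an empty cluster keeps its previous center
    # instead of becoming NaN.
    centers = data[np.random.choice(range(data.shape[0]), k, replace=False), :]
    J = []
    for iteration in range(n):
        sqdistances = np.sum((centers[:,np.newaxis,:] - data)**2, axis=2)
        closest = np.argmin(sqdistances, axis=0)
        J.append(calcJ(data, centers))
        for i in range(k):
            members = data[closest==i, :]
            if members.shape[0] > 0:   # guard against empty clusters
                centers[i,:] = members.mean(axis=0)
    J.append(calcJ(data, centers))
    return centers, J, closest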
In [76]:
small = np.array([[8,7],[7,6.6],[9.2,8.3],[6.8,9.2], [1.2,3.2],[4.8,2.3],[3.4,3.2],[3.2,5.6],[1,4],[2,2.2]])
In [77]:
plt.scatter(small[:,0],small[:,1]);
In [78]:
c,J,closest = kmeans(small,2,n=2)
In [79]:
c
Out[79]:
array([[2.48, 2.98],
       [6.84, 7.34]])
In [80]:
closest
Out[80]:
array([1, 1, 1, 1, 0, 0, 0, 1, 0, 0])
In [81]:
plt.scatter(small[:,0], small[:,1], s=80, c=closest, alpha=0.5);
plt.scatter(c[:,0],c[:,1],s=80,c="green",alpha=0.5);
In [82]:
c,J,closest = kmeans(small,2,n=2)
plt.scatter(small[:,0], small[:,1], s=80, c=closest, alpha=0.5);
plt.scatter(c[:,0],c[:,1],s=80,c="green",alpha=0.5);
In [83]:
J
Out[83]:
[48.12, 26.565833333333334, 26.565833333333334]
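
Because the initial centers are chosen at random, different runs can converge to different clusterings with different final values of $J$. A common remedy is to run k-means several times and keep the best run; a minimal sketch (scikit-learn's KMeans automates this with its n_init parameter):

def kmeans_restarts(data, k=2, n=5, restarts=10):
    best = None
    for _ in range(restarts):
        centers, J, closest = kmeans(data, k, n)
        if best is None or J[-1] < best[1][-1]:
            best = (centers, J, closest)
    return best   # the run with the lowest final J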
In [84]:
import gzip
import pickle

with gzip.open('mnist.pkl.gz', 'rb') as f:
    train_set, valid_set, test_set = pickle.load(f, encoding='latin1')
    # zero = train_set[0][1,:].reshape((28,28,1))
    # one = train_set[0][3,:].reshape((28,28,1))
    # two = train_set[0][5,:].reshape((28,28,1))
    # four = train_set[0][20,:].reshape((28,28,1))

X = train_set[0]
T = train_set[1].reshape((-1,1))

Xtest = test_set[0]
Ttest = test_set[1].reshape((-1,1))

X.shape, T.shape, Xtest.shape, Ttest.shape
Out[84]:
((50000, 784), (50000, 1), (10000, 784), (10000, 1))
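
The cell above assumes you already have mnist.pkl.gz, the pickled train/validation/test split of MNIST that was distributed with the deeplearning.net tutorials. If you don't have it, a similar array of digit images can be fetched with scikit-learn; a sketch of an alternative, not what was used here:

from sklearn.datasets import fetch_openml

mnist = fetch_openml('mnist_784', version=1, as_frame=False)
Xalt = mnist.data / 255.0   # scale pixels to [0, 1], like the pickle file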
In [85]:
c,J,closest = kmeans(X, k=10, n=20)
In [86]:
plt.plot(J)
Out[86]:
[<matplotlib.lines.Line2D at 0x7fb74869d710>]
In [87]:
c.shape
Out[87]:
(10, 784)
In [88]:
for i in range(10):
    plt.subplot(2,5,i+1)
    plt.imshow(-c[i,:].reshape((28,28)), interpolation='nearest', cmap='gray')
    plt.axis('off')
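
Each center is a point in 784-dimensional space, so it can be reshaped into a 28 x 28 image. Every center is the average of all the digit images assigned to it, which is why they look like blurry digits. Also remember that k-means never sees the labels in T, so nothing forces the ten centers to line up with the ten digit classes; two digits may share a center while another digit is split across several.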
In [89]:
c, J, closest = kmeans(X, k=10, n=20)
plt.plot(J)
plt.figure()
for i in range(10):
    plt.subplot(2, 5, i+1)
    plt.imshow(-c[i, :].reshape((28, 28)), interpolation='nearest', cmap='gray')
    plt.axis('off')
In [90]:
c,J,closest = kmeans(X, k=20, n=20)
plt.plot(J)
plt.figure()
for i in range(20):
    plt.subplot(4,5,i+1)
    plt.imshow(-c[i,:].reshape((28,28)), interpolation='nearest', cmap='gray')
    plt.axis('off')
In [91]:
c,J,closest = kmeans(X, k=20, n=30)
plt.plot(J)
plt.figure()
for i in range(20):
    plt.subplot(4,5,i+1)
    plt.imshow(-c[i,:].reshape((28,28)), interpolation='nearest', cmap='gray')
    plt.axis('off')