import numpy as np
import pandas as pd
from collections import Counter
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
from IPython import display
K Nearest Neighbours (KNN) is an algorithm that classifies an item by measuring its similarity (distance) to items whose labels we already know. It finds the k closest ones and takes a majority vote of their labels.
Say we have a bunch of fruit. KNN can classify a new piece of fruit by comparing what we know about it (shape, size, weight, color, etc.) with fruit we have already identified.
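To make the idea concrete, here is a minimal sketch in plain Python. The fruit measurements are made up purely for illustration (weight in grams, diameter in cm). `euclidean` and `classify` are hypothetical helpers, not part of any library. `classify` measures the distance from a new fruit to every labelled fruit, keeps the `k` closest, and returns the most common label among them.

```python
from collections import Counter
import math

# Hypothetical labelled fruit: ((weight_g, diameter_cm), label). Made-up numbers.
FRUIT = [
    ((150, 7.0), "apple"),
    ((170, 7.5), "apple"),
    ((160, 7.2), "apple"),
    ((5, 1.8), "grape"),
    ((6, 2.0), "grape"),
    ((4, 1.6), "grape"),
]

def euclidean(a, b):
    """Straight-line distance between two equal-length feature tuples."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def classify(item, examples, k=3):
    """Return the majority label among the k examples closest to item."""
    # sort every labelled example by its distance to the new item
    nearest = sorted(examples, key=lambda ex: euclidean(ex[0], item))[:k]
    # count the labels of the k nearest and pick the most common one
    votes = Counter(label for _, label in nearest)
    return votes.most_common(1)[0][0]

print(classify((155, 7.1), FRUIT))  # apple
print(classify((7, 2.1), FRUIT))    # grape
```

In this toy data the weight in grams dominates the diameter. That is the scaling issue raised in the normalization point at the end of this post.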
Anyway, let's start with the Iris dataset, which has measurements of 150 flowers:
iris = sns.load_dataset("iris")
print(f"Iris dataset shape: {iris.shape}")
iris.head()
Iris dataset shape: (150, 5)
| | sepal_length | sepal_width | petal_length | petal_width | species |
|---|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | setosa |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | setosa |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | setosa |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | setosa |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | setosa |
Now I'm sampling 5 random flowers from this dataset to hold out as test data, so we can later use our fancy new KNN algorithm to determine what kind of flower they are:
test = iris.sample(n=5)
test
| | sepal_length | sepal_width | petal_length | petal_width | species |
|---|---|---|---|---|---|
| 97 | 6.2 | 2.9 | 4.3 | 1.3 | versicolor |
| 95 | 5.7 | 3.0 | 4.2 | 1.2 | versicolor |
| 50 | 7.0 | 3.2 | 4.7 | 1.4 | versicolor |
| 140 | 6.7 | 3.1 | 5.6 | 2.4 | virginica |
| 84 | 5.4 | 3.0 | 4.5 | 1.5 | versicolor |
And here I am deleting the sampled flowers from the iris dataset to make sure our algorithm hasn't seen the test flowers:
iris.drop(test.index, inplace=True)
print(iris.shape)
iris.head()
(145, 5)
| | sepal_length | sepal_width | petal_length | petal_width | species |
|---|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | setosa |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | setosa |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | setosa |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | setosa |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | setosa |
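The sample-then-drop approach above picks a different test set on every run, because `iris.sample` is called without a seed. Passing `random_state` to `sample` is one fix. Below is a minimal sketch of another reproducible option, assuming you only need row positions: shuffle the positions with a seeded NumPy generator and split them into disjoint train and test sets. `holdout_split` is a hypothetical helper name.

```python
import numpy as np

def holdout_split(n_rows, n_test, seed=0):
    """Return (train_positions, test_positions) as disjoint index arrays."""
    rng = np.random.default_rng(seed)   # seeded, so the split is repeatable
    perm = rng.permutation(n_rows)      # shuffled positions 0..n_rows-1
    return perm[n_test:], perm[:n_test]

train_pos, test_pos = holdout_split(150, 5)
print(len(train_pos), len(test_pos))        # 145 5
print(set(train_pos).isdisjoint(test_pos))  # True: no flower is in both sets
```

With the iris DataFrame, `iris.iloc[train_pos]` and `iris.iloc[test_pos]` would then give the training and test rows.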
Now to look at the data visually.
The species are fairly well separated, though versicolor and virginica overlap somewhat at their boundaries:
sns.pairplot(data=iris, hue="species")
Looking at petal length variation across species:
sns.boxplot(x="species", y="petal_length", data=iris);
Now to actually write the algorithm and figure out which species the flowers in the test dataset belong to.
First, a helper function to calculate the Euclidean distance between two points:
def distance(x, y):
    """Return the Euclidean distance between two points x and y."""
    assert len(x) == len(y)
    inner = 0
    for a, b in zip(x, y):
        inner += (a - b) ** 2  # sum of squared differences per feature
    return np.sqrt(inner)
distance((1,5),[5,5])
4.0
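For reference, `distance` computes the Euclidean distance between two points with $m$ features each. Here is the formula, and the call above worked out by hand:

```latex
d(\mathbf{x}, \mathbf{y}) = \sqrt{\sum_{i=1}^{m} (x_i - y_i)^2}
\qquad
d\big((1, 5), (5, 5)\big) = \sqrt{(1 - 5)^2 + (5 - 5)^2} = \sqrt{16} = 4
```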
Let's look at the values of the third flower in our test data (index 50) and see if we can figure out what it is by using KNN:
test.iloc[2]
sepal_length           7
sepal_width          3.2
petal_length         4.7
petal_width          1.4
species       versicolor
Name: 50, dtype: object
def knn(item, data, n=3):
    """Classify item (its 4 numeric features) by majority vote of its
    n nearest neighbours in data; returns the predicted species."""
    dist = []
    # distance from item to every flower in the training data
    for i, row in data.iterrows():
        dist.append((i, distance(row.iloc[:4], item)))
    # keep the n closest (index, distance) pairs
    nearest = sorted(dist, key=lambda x: x[1])[:n]
    # look up their species in the data passed in and take the majority vote
    species = [data.loc[i]["species"] for i, _ in nearest]
    return Counter(species).most_common()[0][0]
knn(test.iloc[2].iloc[:4], iris)
'versicolor'
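The loop over `iterrows` computes one distance at a time. Below is a minimal vectorized sketch that swaps the Python loop for NumPy broadcasting; it is not the code used in this post. It assumes the training features are already in an array `X` (one row per item) with a parallel list of labels `y`. `knn_predict` is a hypothetical name, and the toy arrays are made up, not taken from the iris data.

```python
from collections import Counter
import numpy as np

def knn_predict(item, X, y, k=3):
    """Majority label among the k rows of X closest to item (Euclidean)."""
    X = np.asarray(X, dtype=float)
    item = np.asarray(item, dtype=float)
    # broadcasting: subtract item from every row, then take row-wise norms
    dists = np.sqrt(((X - item) ** 2).sum(axis=1))
    nearest = np.argsort(dists)[:k]  # positions of the k smallest distances
    votes = Counter(y[i] for i in nearest)
    return votes.most_common(1)[0][0]

# Made-up 2-feature training data, for illustration only.
X = [[1.0, 1.0], [1.2, 0.8], [0.9, 1.1], [5.0, 5.0], [5.2, 4.9], [4.8, 5.1]]
y = ["a", "a", "a", "b", "b", "b"]
print(knn_predict([1.1, 1.0], X, y))  # a
print(knn_predict([5.1, 5.0], X, y))  # b
```

With the iris DataFrame, `X` could come from the four measurement columns via `.to_numpy()`, and `y` from the species column.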
knn_species = []
for i, row in test.iterrows():
    knn_species.append(knn(row.iloc[:4], iris))
knn_species
['versicolor', 'versicolor', 'versicolor', 'virginica', 'versicolor']
test["knn"] = knn_species
test
| | sepal_length | sepal_width | petal_length | petal_width | species | knn |
|---|---|---|---|---|---|---|
| 97 | 6.2 | 2.9 | 4.3 | 1.3 | versicolor | versicolor |
| 95 | 5.7 | 3.0 | 4.2 | 1.2 | versicolor | versicolor |
| 50 | 7.0 | 3.2 | 4.7 | 1.4 | versicolor | versicolor |
| 140 | 6.7 | 3.1 | 5.6 | 2.4 | virginica | virginica |
| 84 | 5.4 | 3.0 | 4.5 | 1.5 | versicolor | versicolor |
All right! Our KNN algorithm classified all 5 test flowers correctly, although 5 flowers is a very small test set.
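With more test flowers it's easier to summarise the comparison as an accuracy score: the fraction of predictions that match the true labels. Below is a minimal sketch; `accuracy` is a hypothetical helper. With the DataFrame above, the same number would come from `(test["species"] == test["knn"]).mean()`.

```python
def accuracy(true_labels, predicted_labels):
    """Fraction of positions where the prediction equals the true label."""
    if len(true_labels) != len(predicted_labels):
        raise ValueError("label lists must be the same length")
    correct = sum(t == p for t, p in zip(true_labels, predicted_labels))
    return correct / len(true_labels)

# The species and knn columns from the test table above.
species = ["versicolor", "versicolor", "versicolor", "virginica", "versicolor"]
knn_pred = ["versicolor", "versicolor", "versicolor", "virginica", "versicolor"]
print(accuracy(species, knn_pred))  # 1.0
```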
I wrote this KNN implementation specifically for the iris dataset: it assumes the first four columns are the features and the label column is called "species". It can, however, be adapted to many other datasets.
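The iris dataset is very simple (all four features are measured in centimetres), but usually I would normalize the data so that features with large numeric ranges don't dominate the distance and every attribute gets a chance to affect the result.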
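Below is a minimal sketch of min-max normalization with NumPy, which is one common choice (z-score standardization is another). It assumes a numeric feature matrix with one row per item. The scaling parameters are learned from the training rows only and then reused on the test rows, so the test data doesn't influence the fit. `fit_min_max` and `apply_min_max` are hypothetical helper names, and the numbers are made up.

```python
import numpy as np

def fit_min_max(X_train):
    """Per-column minimum and range, learned from training data only."""
    X_train = np.asarray(X_train, dtype=float)
    col_min = X_train.min(axis=0)
    col_span = X_train.max(axis=0) - col_min
    col_span[col_span == 0] = 1.0  # avoid dividing by zero for constant columns
    return col_min, col_span

def apply_min_max(X, col_min, col_span):
    """Rescale so each training column spans [0, 1]."""
    return (np.asarray(X, dtype=float) - col_min) / col_span

# Made-up data: column 0 in grams, column 1 in centimetres.
X_train = [[100.0, 5.0], [200.0, 7.0], [150.0, 6.0]]
X_test = [[175.0, 6.5]]
lo, span = fit_min_max(X_train)
print(apply_min_max(X_test, lo, span))  # [[0.75 0.75]]
```

After scaling, a 1-unit gap in either column counts the same, so grams no longer swamp centimetres in the distance.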