k-Nearest Neighbors (kNN)

1. kNN 이란?

2. 데이터 다루기 (Data Handling)

1) 데이터 로딩하기

  • Iris (붓꽃) 데이터 로딩
In [1]:
import urllib2
import json
from scipy import stats
from pandas import Series, DataFrame
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

path = 'https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data'
raw_csv = urllib2.urlopen(path)
feature_names = ('sepal length', 'sepal width', 'petal length', 'petal width')
all_names = feature_names + ('class',)
df = pd.read_csv(raw_csv, names=all_names)
print df
     sepal length  sepal width  petal length  petal width           class
0             5.1          3.5           1.4          0.2     Iris-setosa
1             4.9          3.0           1.4          0.2     Iris-setosa
2             4.7          3.2           1.3          0.2     Iris-setosa
3             4.6          3.1           1.5          0.2     Iris-setosa
4             5.0          3.6           1.4          0.2     Iris-setosa
5             5.4          3.9           1.7          0.4     Iris-setosa
6             4.6          3.4           1.4          0.3     Iris-setosa
7             5.0          3.4           1.5          0.2     Iris-setosa
8             4.4          2.9           1.4          0.2     Iris-setosa
9             4.9          3.1           1.5          0.1     Iris-setosa
10            5.4          3.7           1.5          0.2     Iris-setosa
11            4.8          3.4           1.6          0.2     Iris-setosa
12            4.8          3.0           1.4          0.1     Iris-setosa
13            4.3          3.0           1.1          0.1     Iris-setosa
14            5.8          4.0           1.2          0.2     Iris-setosa
15            5.7          4.4           1.5          0.4     Iris-setosa
16            5.4          3.9           1.3          0.4     Iris-setosa
17            5.1          3.5           1.4          0.3     Iris-setosa
18            5.7          3.8           1.7          0.3     Iris-setosa
19            5.1          3.8           1.5          0.3     Iris-setosa
20            5.4          3.4           1.7          0.2     Iris-setosa
21            5.1          3.7           1.5          0.4     Iris-setosa
22            4.6          3.6           1.0          0.2     Iris-setosa
23            5.1          3.3           1.7          0.5     Iris-setosa
24            4.8          3.4           1.9          0.2     Iris-setosa
25            5.0          3.0           1.6          0.2     Iris-setosa
26            5.0          3.4           1.6          0.4     Iris-setosa
27            5.2          3.5           1.5          0.2     Iris-setosa
28            5.2          3.4           1.4          0.2     Iris-setosa
29            4.7          3.2           1.6          0.2     Iris-setosa
..            ...          ...           ...          ...             ...
120           6.9          3.2           5.7          2.3  Iris-virginica
121           5.6          2.8           4.9          2.0  Iris-virginica
122           7.7          2.8           6.7          2.0  Iris-virginica
123           6.3          2.7           4.9          1.8  Iris-virginica
124           6.7          3.3           5.7          2.1  Iris-virginica
125           7.2          3.2           6.0          1.8  Iris-virginica
126           6.2          2.8           4.8          1.8  Iris-virginica
127           6.1          3.0           4.9          1.8  Iris-virginica
128           6.4          2.8           5.6          2.1  Iris-virginica
129           7.2          3.0           5.8          1.6  Iris-virginica
130           7.4          2.8           6.1          1.9  Iris-virginica
131           7.9          3.8           6.4          2.0  Iris-virginica
132           6.4          2.8           5.6          2.2  Iris-virginica
133           6.3          2.8           5.1          1.5  Iris-virginica
134           6.1          2.6           5.6          1.4  Iris-virginica
135           7.7          3.0           6.1          2.3  Iris-virginica
136           6.3          3.4           5.6          2.4  Iris-virginica
137           6.4          3.1           5.5          1.8  Iris-virginica
138           6.0          3.0           4.8          1.8  Iris-virginica
139           6.9          3.1           5.4          2.1  Iris-virginica
140           6.7          3.1           5.6          2.4  Iris-virginica
141           6.9          3.1           5.1          2.3  Iris-virginica
142           5.8          2.7           5.1          1.9  Iris-virginica
143           6.8          3.2           5.9          2.3  Iris-virginica
144           6.7          3.3           5.7          2.5  Iris-virginica
145           6.7          3.0           5.2          2.3  Iris-virginica
146           6.3          2.5           5.0          1.9  Iris-virginica
147           6.5          3.0           5.2          2.0  Iris-virginica
148           6.2          3.4           5.4          2.3  Iris-virginica
149           5.9          3.0           5.1          1.8  Iris-virginica

[150 rows x 5 columns]

2) 학습 데이터와 검증 데이터 분리하기

In [10]:
import random

def splitDataset(split, df, training_set=[], test_set=[]):
    for i in range(len(df)):
        if random.random() < split:
            training_set.append(df.ix[i])
        else:
            test_set.append(df.ix[i])
    return training_set, test_set        

split = 0.66
training_set, test_set = splitDataset(split, df)
print 'Train: ' + str(len(training_set)) + " - ratio: " + str(float(len(training_set))/len(df))
print 'Test: ' + str(len(test_set)) + " - ratio: " + str(float(len(test_set))/len(df))
Train: 98 - ratio: 0.653333333333
Test: 52 - ratio: 0.346666666667

2. 유사도 (Similarity) 정의

In [11]:
num_feature = len(feature_names)

import math
def euclideanDistance(instance1, instance2):
    distance = 0
    for x in range(num_feature):
        distance += pow((instance1[x] - instance2[x]), 2)
    return math.sqrt(distance)

df_feature = df.drop('class', axis=1)
print df_feature.head()
print 

distance = euclideanDistance(df_feature.ix[0], df_feature.ix[1])
print 'Distance: ' + str(distance)
   sepal length  sepal width  petal length  petal width
0           5.1          3.5           1.4          0.2
1           4.9          3.0           1.4          0.2
2           4.7          3.2           1.3          0.2
3           4.6          3.1           1.5          0.2
4           5.0          3.6           1.4          0.2

Distance: 0.538516480713

3. k-이웃 (k-Neighbors) 찾기

  • 테스트 집합내 임의의 인스턴스(test_instance)에 대하여 훈련 데이터 집합(training_set)내에서 유사도가 높은 k개의 인스턴스 찾기
In [14]:
import operator 
def getNeighbors(training_set, test_instance, k):
    distances = []
    for i in range(len(training_set)):
        dist = euclideanDistance(training_set[i], test_instance)
        distances.append((training_set[i], dist))
    distances.sort(key=operator.itemgetter(1))
    neighbors = []
    for i in range(k):
        neighbors.append(distances[i][0])
    return neighbors

print test_set[0]
print 

k = 1
neighbors = getNeighbors(training_set, test_set[0], k)
print neighbors
sepal length            4.9
sepal width               3
petal length            1.4
petal width             0.2
class           Iris-setosa
Name: 1, dtype: object

[sepal length            4.8
sepal width               3
petal length            1.4
petal width             0.1
class           Iris-setosa
Name: 12, dtype: object]
In [19]:
print neighbors[0]
print
print type(neighbors[0])
print
print neighbors[0][-1]
sepal length            4.8
sepal width               3
petal length            1.4
petal width             0.1
class           Iris-setosa
Name: 12, dtype: object

<class 'pandas.core.series.Series'>

Iris-setosa
In [20]:
k = 3
neighbors = getNeighbors(training_set, test_set[0], k)
print(neighbors)
[sepal length            4.8
sepal width               3
petal length            1.4
petal width             0.1
class           Iris-setosa
Name: 12, dtype: object, sepal length            4.9
sepal width             3.1
petal length            1.5
petal width             0.1
class           Iris-setosa
Name: 9, dtype: object, sepal length            4.9
sepal width             3.1
petal length            1.5
petal width             0.1
class           Iris-setosa
Name: 34, dtype: object]

4. 분류하기 (Classify)

  • 테스트 집합(test_set)내 임의의 인스턴스(test_instance)에 대하여...
  • 훈련 데이터 집합(training_set)내에서 유사도가 높은 k개의 인스턴스의 분류 중 가장 빈도수가 높은 분류를 해당 인스턴스(test_instance)의 분류로 정하기
In [19]:
def classify(neighbors):
    class_frequency = {}
    for i in range(len(neighbors)):
        class_name = neighbors[i][-1]
        if class_name in class_frequency:
            class_frequency[class_name] += 1
        else:
            class_frequency[class_name] = 1
    sorted_class_frequency = sorted(class_frequency.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sorted_class_frequency[0][0]

k = 3
neighbors = getNeighbors(training_set, test_set[0], k)

classified_class_name = classify(neighbors)
print "Classified:", classified_class_name, "- Actual:", test_set[0]['class']
Classified: Iris-setosa - Actual: Iris-setosa

5. 전체 테스트 집합에 대해 분류 및 정확도 (Accuracy) 평가

In [21]:
k = 3
classified_class_names=[]
for i in range(len(test_set)):
    neighbors = getNeighbors(training_set, test_set[i], k)
    result = classify(neighbors)
    classified_class_names.append(result)
    print('Classified:' + result + ', Actual:' + test_set[i][-1])

correct = 0.0
for i in range(len(test_set)):
    if classified_class_names[i] == test_set[i][-1]:
        correct += 1.0

print
print('Accuracy: ' + str(correct / float(len(test_set)) * 100.0) + '%')
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-setosa, Actual:Iris-setosa
Classified:Iris-versicolor, Actual:Iris-versicolor
Classified:Iris-versicolor, Actual:Iris-versicolor
Classified:Iris-versicolor, Actual:Iris-versicolor
Classified:Iris-versicolor, Actual:Iris-versicolor
Classified:Iris-versicolor, Actual:Iris-versicolor
Classified:Iris-versicolor, Actual:Iris-versicolor
Classified:Iris-versicolor, Actual:Iris-versicolor
Classified:Iris-versicolor, Actual:Iris-versicolor
Classified:Iris-versicolor, Actual:Iris-versicolor
Classified:Iris-versicolor, Actual:Iris-versicolor
Classified:Iris-versicolor, Actual:Iris-versicolor
Classified:Iris-versicolor, Actual:Iris-versicolor
Classified:Iris-versicolor, Actual:Iris-versicolor
Classified:Iris-versicolor, Actual:Iris-versicolor
Classified:Iris-versicolor, Actual:Iris-versicolor
Classified:Iris-virginica, Actual:Iris-virginica
Classified:Iris-versicolor, Actual:Iris-virginica
Classified:Iris-virginica, Actual:Iris-virginica
Classified:Iris-virginica, Actual:Iris-virginica
Classified:Iris-virginica, Actual:Iris-virginica
Classified:Iris-virginica, Actual:Iris-virginica
Classified:Iris-virginica, Actual:Iris-virginica
Classified:Iris-virginica, Actual:Iris-virginica
Classified:Iris-virginica, Actual:Iris-virginica
Classified:Iris-virginica, Actual:Iris-virginica
Classified:Iris-virginica, Actual:Iris-virginica
Classified:Iris-virginica, Actual:Iris-virginica
Classified:Iris-virginica, Actual:Iris-virginica
Classified:Iris-versicolor, Actual:Iris-virginica

Accuracy: 96.1538461538%

6. kNN 분류 전체 코드

  • 중간 과정의 테스트 코드 삭제

  • 보다 정확한 정확도 측정을 위하여 전체적으로 num_trials번의 테스트 후 평균 정확도 산출

In [27]:
import urllib2
import json
from scipy import stats
from pandas import Series, DataFrame
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import random
import math
import operator

def splitDataset(split, df, training_set=[] , test_set=[]):
    for i in range(len(df)):
        if random.random() < split:
            training_set.append(df.ix[i])
        else:
            test_set.append(df.ix[i])
    return training_set, test_set        

def euclideanDistance(instance1, instance2):
    distance = 0
    for x in range(num_feature):
        distance += pow((instance1[x] - instance2[x]), 2)
    return math.sqrt(distance)

def getNeighbors(training_set, test_instance, k):
    distances = []
    for i in range(len(training_set)):
        dist = euclideanDistance(training_set[i], test_instance)
        distances.append((training_set[i], dist))
    distances.sort(key=operator.itemgetter(1))
    neighbors = []
    for i in range(k):
        neighbors.append(distances[i][0])
    return neighbors

def classify(neighbors):
    class_frequency = {}
    for i in range(len(neighbors)):
        class_name = neighbors[i][-1]
        if class_name in class_frequency:
            class_frequency[class_name] += 1
        else:
            class_frequency[class_name] = 1
    sorted_class_frequency = sorted(class_frequency.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sorted_class_frequency[0][0]

if __name__ == '__main__':
    path = 'https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data'
    raw_csv = urllib2.urlopen(path)
    feature_names = ('sepal length', 'sepal width', 'petal length', 'petal width')
    iris_names = ('Iris-setosa', 'Iris-versicolor', 'Iris-virginica')
    all_names = feature_names + ('class',)
    df = pd.read_csv(raw_csv, names=all_names)
    df_feature = df.drop('class', axis=1)
    num_feature = len(feature_names)
    split = 0.66
    k = 3
    num_trials = 3
    accuracy_sum = 0.0

    for i in range(num_trials):
        training_set, test_set = splitDataset(split, df)
        classified_class_names=[]
        for i in range(len(test_set)):
            neighbors = getNeighbors(training_set, test_set[i], k)
            result = classify(neighbors)
            classified_class_names.append(result)

        correct = 0.0
        for i in range(len(test_set)):
            if test_set[i][-1] == classified_class_names[i]:
                correct += 1.0

        accuracy_sum += (correct / float(len(test_set))) * 100.0

    print('Mean Accuracy: ' + str(accuracy_sum / float(num_trials)) + '%')
Mean Accuracy: 98.6579403572%

7. scikit-learn을 활용한 kNN 수행

  • sklearn.datasets.load_iris()를 제공하여 iris 데이터를 편하게 로드할 수 있음
In [23]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import neighbors, datasets

iris = datasets.load_iris()
print iris.data[0:5]
print iris.target[0:5]
[[ 5.1  3.5  1.4  0.2]
 [ 4.9  3.   1.4  0.2]
 [ 4.7  3.2  1.3  0.2]
 [ 4.6  3.1  1.5  0.2]
 [ 5.   3.6  1.4  0.2]]
[0 0 0 0 0]
  • datasets.load_iris()가 반환한 데이터 형태에 맞게 훈련 데이터와 테스트 데이터 구분
In [29]:
import random

def splitDataset2(split, data, training_feature_set=[], training_target_set=[], test_feature_set=[], test_target_set=[]):
    for i in range(len(data)):
        if random.random() < split:
            training_feature_set.append(iris.data[i])
            training_target_set.append(iris.target[i])
        else:
            test_feature_set.append(iris.data[i])
            test_target_set.append(iris.target[i])
    return training_feature_set, training_target_set, test_feature_set, test_target_set

split = 0.66
training_feature_set, training_target_set, test_feature_set, test_target_set = splitDataset2(split, iris.data)
print 'Train: ' + str(len(training_feature_set)) + " - ratio: " + str(float(len(training_feature_set))/len(df))
print 'Test: ' + str(len(test_feature_set)) + " - ratio: " + str(float(len(test_feature_set))/len(df))
print
print training_feature_set
print training_target_set
print
print test_feature_set
print test_target_set
Train: 99 - ratio: 0.66
Test: 51 - ratio: 0.34

[array([ 5.1,  3.5,  1.4,  0.2]), array([ 4.9,  3. ,  1.4,  0.2]), array([ 4.7,  3.2,  1.3,  0.2]), array([ 5.4,  3.9,  1.7,  0.4]), array([ 4.4,  2.9,  1.4,  0.2]), array([ 4.8,  3.4,  1.6,  0.2]), array([ 4.8,  3. ,  1.4,  0.1]), array([ 4.3,  3. ,  1.1,  0.1]), array([ 5.8,  4. ,  1.2,  0.2]), array([ 5.7,  4.4,  1.5,  0.4]), array([ 5.4,  3.9,  1.3,  0.4]), array([ 5.1,  3.5,  1.4,  0.3]), array([ 5.7,  3.8,  1.7,  0.3]), array([ 5.1,  3.8,  1.5,  0.3]), array([ 5.4,  3.4,  1.7,  0.2]), array([ 5.1,  3.7,  1.5,  0.4]), array([ 4.6,  3.6,  1. ,  0.2]), array([ 5.1,  3.3,  1.7,  0.5]), array([ 4.8,  3.4,  1.9,  0.2]), array([ 5. ,  3. ,  1.6,  0.2]), array([ 5.2,  3.5,  1.5,  0.2]), array([ 4.7,  3.2,  1.6,  0.2]), array([ 5.4,  3.4,  1.5,  0.4]), array([ 5.5,  4.2,  1.4,  0.2]), array([ 5. ,  3.2,  1.2,  0.2]), array([ 5.5,  3.5,  1.3,  0.2]), array([ 4.9,  3.1,  1.5,  0.1]), array([ 4.4,  3. ,  1.3,  0.2]), array([ 5.1,  3.4,  1.5,  0.2]), array([ 4.5,  2.3,  1.3,  0.3]), array([ 5.1,  3.8,  1.6,  0.2]), array([ 5.3,  3.7,  1.5,  0.2]), array([ 6.4,  3.2,  4.5,  1.5]), array([ 5.5,  2.3,  4. ,  1.3]), array([ 6.5,  2.8,  4.6,  1.5]), array([ 5.7,  2.8,  4.5,  1.3]), array([ 6.3,  3.3,  4.7,  1.6]), array([ 4.9,  2.4,  3.3,  1. ]), array([ 5.2,  2.7,  3.9,  1.4]), array([ 5. ,  2. ,  3.5,  1. ]), array([ 5.9,  3. ,  4.2,  1.5]), array([ 6. ,  2.2,  4. ,  1. ]), array([ 6.1,  2.9,  4.7,  1.4]), array([ 5.6,  2.9,  3.6,  1.3]), array([ 6.7,  3.1,  4.4,  1.4]), array([ 5.8,  2.7,  4.1,  1. ]), array([ 6.2,  2.2,  4.5,  1.5]), array([ 5.6,  2.5,  3.9,  1.1]), array([ 6.1,  2.8,  4. ,  1.3]), array([ 6.1,  2.8,  4.7,  1.2]), array([ 6.4,  2.9,  4.3,  1.3]), array([ 6.8,  2.8,  4.8,  1.4]), array([ 6.7,  3. ,  5. ,  1.7]), array([ 5.7,  2.6,  3.5,  1. ]), array([ 5.5,  2.4,  3.8,  1.1]), array([ 5.5,  2.4,  3.7,  1. ]), array([ 6. ,  2.7,  5.1,  1.6]), array([ 6. ,  3.4,  4.5,  1.6]), array([ 6.7,  3.1,  4.7,  1.5]), array([ 6.3,  2.3,  4.4,  1.3]), array([ 5.5,  2.5,  4. ,  1.3]), array([ 5.5,  2.6,  4.4,  1.2]), array([ 5.6,  2.7,  4.2,  1.3]), array([ 5.7,  3. ,  4.2,  1.2]), array([ 5.7,  2.9,  4.2,  1.3]), array([ 6.2,  2.9,  4.3,  1.3]), array([ 5.1,  2.5,  3. ,  1.1]), array([ 5.7,  2.8,  4.1,  1.3]), array([ 5.8,  2.7,  5.1,  1.9]), array([ 6.3,  2.9,  5.6,  1.8]), array([ 7.6,  3. ,  6.6,  2.1]), array([ 6.7,  2.5,  5.8,  1.8]), array([ 7.2,  3.6,  6.1,  2.5]), array([ 6.4,  2.7,  5.3,  1.9]), array([ 6.8,  3. ,  5.5,  2.1]), array([ 5.8,  2.8,  5.1,  2.4]), array([ 6.5,  3. ,  5.5,  1.8]), array([ 6. ,  2.2,  5. ,  1.5]), array([ 6.9,  3.2,  5.7,  2.3]), array([ 5.6,  2.8,  4.9,  2. ]), array([ 7.7,  2.8,  6.7,  2. ]), array([ 6.3,  2.7,  4.9,  1.8]), array([ 7.2,  3.2,  6. ,  1.8]), array([ 6.2,  2.8,  4.8,  1.8]), array([ 6.4,  2.8,  5.6,  2.1]), array([ 7.2,  3. ,  5.8,  1.6]), array([ 7.4,  2.8,  6.1,  1.9]), array([ 6.4,  2.8,  5.6,  2.2]), array([ 6.3,  2.8,  5.1,  1.5]), array([ 6.1,  2.6,  5.6,  1.4]), array([ 6.3,  3.4,  5.6,  2.4]), array([ 6.9,  3.1,  5.4,  2.1]), array([ 6.9,  3.1,  5.1,  2.3]), array([ 5.8,  2.7,  5.1,  1.9]), array([ 6.7,  3. ,  5.2,  2.3]), array([ 6.3,  2.5,  5. ,  1.9]), array([ 6.5,  3. ,  5.2,  2. ]), array([ 6.2,  3.4,  5.4,  2.3]), array([ 5.9,  3. ,  5.1,  1.8])]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]

[array([ 4.6,  3.1,  1.5,  0.2]), array([ 5. ,  3.6,  1.4,  0.2]), array([ 4.6,  3.4,  1.4,  0.3]), array([ 5. ,  3.4,  1.5,  0.2]), array([ 4.9,  3.1,  1.5,  0.1]), array([ 5.4,  3.7,  1.5,  0.2]), array([ 5. ,  3.4,  1.6,  0.4]), array([ 5.2,  3.4,  1.4,  0.2]), array([ 4.8,  3.1,  1.6,  0.2]), array([ 5.2,  4.1,  1.5,  0.1]), array([ 4.9,  3.1,  1.5,  0.1]), array([ 5. ,  3.5,  1.3,  0.3]), array([ 4.4,  3.2,  1.3,  0.2]), array([ 5. ,  3.5,  1.6,  0.6]), array([ 5.1,  3.8,  1.9,  0.4]), array([ 4.8,  3. ,  1.4,  0.3]), array([ 4.6,  3.2,  1.4,  0.2]), array([ 5. ,  3.3,  1.4,  0.2]), array([ 7. ,  3.2,  4.7,  1.4]), array([ 6.9,  3.1,  4.9,  1.5]), array([ 6.6,  2.9,  4.6,  1.3]), array([ 5.6,  3. ,  4.5,  1.5]), array([ 5.9,  3.2,  4.8,  1.8]), array([ 6.3,  2.5,  4.9,  1.5]), array([ 6.6,  3. ,  4.4,  1.4]), array([ 6. ,  2.9,  4.5,  1.5]), array([ 5.8,  2.7,  3.9,  1.2]), array([ 5.4,  3. ,  4.5,  1.5]), array([ 5.6,  3. ,  4.1,  1.3]), array([ 6.1,  3. ,  4.6,  1.4]), array([ 5.8,  2.6,  4. ,  1.2]), array([ 5. ,  2.3,  3.3,  1. ]), array([ 6.3,  3.3,  6. ,  2.5]), array([ 7.1,  3. ,  5.9,  2.1]), array([ 6.5,  3. ,  5.8,  2.2]), array([ 4.9,  2.5,  4.5,  1.7]), array([ 7.3,  2.9,  6.3,  1.8]), array([ 6.5,  3.2,  5.1,  2. ]), array([ 5.7,  2.5,  5. ,  2. ]), array([ 6.4,  3.2,  5.3,  2.3]), array([ 7.7,  3.8,  6.7,  2.2]), array([ 7.7,  2.6,  6.9,  2.3]), array([ 6.7,  3.3,  5.7,  2.1]), array([ 6.1,  3. ,  4.9,  1.8]), array([ 7.9,  3.8,  6.4,  2. ]), array([ 7.7,  3. ,  6.1,  2.3]), array([ 6.4,  3.1,  5.5,  1.8]), array([ 6. ,  3. ,  4.8,  1.8]), array([ 6.7,  3.1,  5.6,  2.4]), array([ 6.8,  3.2,  5.9,  2.3]), array([ 6.7,  3.3,  5.7,  2.5])]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
  • knn.fit(훈련 속성 데이터 집합, 훈련 분류 데이터 집합)을 통하여 knn 알고리즘 훈련 및 모델 형성
  • knn.predict(테스트 속성 데이터)를 통해 테스트 데이터의 분류 결과를 얻어옴
In [30]:
k = 3
knn = neighbors.KNeighborsClassifier(k)
knn.fit(training_feature_set, training_target_set)
result_index = knn.predict(test_feature_set[0])
print('Classified:' + iris_names[result_index] + ', Actual:' + iris_names[test_target_set[0]])
Classified:Iris-setosa, Actual:Iris-setosa
  • 전체 코드
In [31]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import neighbors, datasets

iris = datasets.load_iris()
    
def splitDataset2(split, data, training_feature_set=[], training_target_set=[], test_feature_set=[], test_target_set=[]):
    for i in range(len(data)):
        if random.random() < split:
            training_feature_set.append(iris.data[i])
            training_target_set.append(iris.target[i])
        else:
            test_feature_set.append(iris.data[i])
            test_target_set.append(iris.target[i])
    return training_feature_set, training_target_set, test_feature_set, test_target_set

if __name__ == '__main__':
    feature_names = ('sepal length', 'sepal width', 'petal length', 'petal width')
    iris_names = ('Iris-setosa', 'Iris-versicolor', 'Iris-virginica')
    all_names = feature_names + ('class',)
    num_feature = len(feature_names)
    split = 0.66
    k = 3
    num_trials = 3
    accuracy_sum = 0.0

    for i in range(num_trials):
        training_feature_set, training_target_set, test_feature_set, test_target_set = splitDataset2(split, iris.data)
        knn = neighbors.KNeighborsClassifier(k)
        knn.fit(training_feature_set, training_target_set)
        
        classified_class_names=[]
        for i in range(len(test_feature_set)):
            result_index = knn.predict(test_feature_set[i])
            classified_class_names.append(iris_names[result_index])
        
        correct = 0.0
        for i in range(len(test_feature_set)):
            if iris_names[test_target_set[i]] == classified_class_names[i]:
                correct += 1.0

        accuracy_sum += (correct / float(len(test_feature_set))) * 100.0

    print('Mean Accuracy: ' + str(accuracy_sum / float(num_trials)) + '%')
Mean Accuracy: 97.1169784284%

8. Refererence