import urllib2
import json
from scipy import stats
from pandas import Series, DataFrame
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
# UCI Iris data set: 150 rows of 4 numeric features plus a class label.
path = 'https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data'
raw_csv = urllib2.urlopen(path)  # Python 2 urllib2; urllib.request on Python 3
feature_names = ('sepal length', 'sepal width', 'petal length', 'petal width')
all_names = feature_names + ('class',)  # the CSV file has no header row
df = pd.read_csv(raw_csv, names=all_names)
print df
sepal length sepal width petal length petal width class 0 5.1 3.5 1.4 0.2 Iris-setosa 1 4.9 3.0 1.4 0.2 Iris-setosa 2 4.7 3.2 1.3 0.2 Iris-setosa 3 4.6 3.1 1.5 0.2 Iris-setosa 4 5.0 3.6 1.4 0.2 Iris-setosa 5 5.4 3.9 1.7 0.4 Iris-setosa 6 4.6 3.4 1.4 0.3 Iris-setosa 7 5.0 3.4 1.5 0.2 Iris-setosa 8 4.4 2.9 1.4 0.2 Iris-setosa 9 4.9 3.1 1.5 0.1 Iris-setosa 10 5.4 3.7 1.5 0.2 Iris-setosa 11 4.8 3.4 1.6 0.2 Iris-setosa 12 4.8 3.0 1.4 0.1 Iris-setosa 13 4.3 3.0 1.1 0.1 Iris-setosa 14 5.8 4.0 1.2 0.2 Iris-setosa 15 5.7 4.4 1.5 0.4 Iris-setosa 16 5.4 3.9 1.3 0.4 Iris-setosa 17 5.1 3.5 1.4 0.3 Iris-setosa 18 5.7 3.8 1.7 0.3 Iris-setosa 19 5.1 3.8 1.5 0.3 Iris-setosa 20 5.4 3.4 1.7 0.2 Iris-setosa 21 5.1 3.7 1.5 0.4 Iris-setosa 22 4.6 3.6 1.0 0.2 Iris-setosa 23 5.1 3.3 1.7 0.5 Iris-setosa 24 4.8 3.4 1.9 0.2 Iris-setosa 25 5.0 3.0 1.6 0.2 Iris-setosa 26 5.0 3.4 1.6 0.4 Iris-setosa 27 5.2 3.5 1.5 0.2 Iris-setosa 28 5.2 3.4 1.4 0.2 Iris-setosa 29 4.7 3.2 1.6 0.2 Iris-setosa .. ... ... ... ... ... 120 6.9 3.2 5.7 2.3 Iris-virginica 121 5.6 2.8 4.9 2.0 Iris-virginica 122 7.7 2.8 6.7 2.0 Iris-virginica 123 6.3 2.7 4.9 1.8 Iris-virginica 124 6.7 3.3 5.7 2.1 Iris-virginica 125 7.2 3.2 6.0 1.8 Iris-virginica 126 6.2 2.8 4.8 1.8 Iris-virginica 127 6.1 3.0 4.9 1.8 Iris-virginica 128 6.4 2.8 5.6 2.1 Iris-virginica 129 7.2 3.0 5.8 1.6 Iris-virginica 130 7.4 2.8 6.1 1.9 Iris-virginica 131 7.9 3.8 6.4 2.0 Iris-virginica 132 6.4 2.8 5.6 2.2 Iris-virginica 133 6.3 2.8 5.1 1.5 Iris-virginica 134 6.1 2.6 5.6 1.4 Iris-virginica 135 7.7 3.0 6.1 2.3 Iris-virginica 136 6.3 3.4 5.6 2.4 Iris-virginica 137 6.4 3.1 5.5 1.8 Iris-virginica 138 6.0 3.0 4.8 1.8 Iris-virginica 139 6.9 3.1 5.4 2.1 Iris-virginica 140 6.7 3.1 5.6 2.4 Iris-virginica 141 6.9 3.1 5.1 2.3 Iris-virginica 142 5.8 2.7 5.1 1.9 Iris-virginica 143 6.8 3.2 5.9 2.3 Iris-virginica 144 6.7 3.3 5.7 2.5 Iris-virginica 145 6.7 3.0 5.2 2.3 Iris-virginica 146 6.3 2.5 5.0 1.9 Iris-virginica 147 6.5 3.0 5.2 2.0 Iris-virginica 148 6.2 3.4 5.4 2.3 
Iris-virginica 149 5.9 3.0 5.1 1.8 Iris-virginica [150 rows x 5 columns]
import random
def splitDataset(split, df, training_set=None, test_set=None):
    """Randomly partition the rows of df into training and test lists.

    Each row (a pandas Series, class label included) goes to the
    training side with probability `split`, otherwise to the test side.

    split        -- probability threshold in [0, 1] for the training side
    df           -- DataFrame whose rows are partitioned
    training_set -- optional list to append training rows to
    test_set     -- optional list to append test rows to

    Returns (training_set, test_set).
    """
    # Fresh lists per call: the original mutable-default arguments ([])
    # were shared across calls, so every repeated split silently
    # accumulated the rows of all previous splits.
    if training_set is None:
        training_set = []
    if test_set is None:
        test_set = []
    for i in range(len(df)):
        # .iloc is positional row access; .ix was removed from pandas.
        row = df.iloc[i]
        if random.random() < split:
            training_set.append(row)
        else:
            test_set.append(row)
    return training_set, test_set
# Roughly 66% of rows go to training, the rest to test (randomized split).
split = 0.66
training_set, test_set = splitDataset(split, df)
print 'Train: ' + str(len(training_set)) + " - ratio: " + str(float(len(training_set))/len(df))
print 'Test: ' + str(len(test_set)) + " - ratio: " + str(float(len(test_set))/len(df))
Train: 98 - ratio: 0.653333333333 Test: 52 - ratio: 0.346666666667
num_feature = len(feature_names)  # numeric columns only; the trailing class label is excluded
import math
def euclideanDistance(instance1, instance2, n=None):
    """Euclidean distance over the first n elements of two instances.

    instance1, instance2 -- indexable feature vectors (a trailing class
                            label beyond position n-1 is ignored)
    n                    -- number of leading features to compare;
                            defaults to the module-level num_feature so
                            existing two-argument callers are unchanged

    Returns the distance as a float.
    """
    if n is None:
        n = num_feature  # module-level feature count set by the caller script
    return math.sqrt(sum((instance1[x] - instance2[x]) ** 2 for x in range(n)))
# Distance sanity check on the first two rows (class column dropped).
df_feature = df.drop('class', axis=1)
print df_feature.head()
print
# NOTE(review): .ix is removed in modern pandas -- .iloc would be needed there.
distance = euclideanDistance(df_feature.ix[0], df_feature.ix[1])
print 'Distance: ' + str(distance)
sepal length sepal width petal length petal width 0 5.1 3.5 1.4 0.2 1 4.9 3.0 1.4 0.2 2 4.7 3.2 1.3 0.2 3 4.6 3.1 1.5 0.2 4 5.0 3.6 1.4 0.2 Distance: 0.538516480713
import operator
def getNeighbors(training_set, test_instance, k):
    """Return the k members of training_set nearest to test_instance.

    Nearness is euclideanDistance over the leading num_feature columns,
    so a trailing class label on each instance does not affect ranking.
    """
    # sorted() with a key computes each distance once and is stable,
    # exactly like building (instance, distance) pairs and sorting them.
    ranked = sorted(training_set,
                    key=lambda candidate: euclideanDistance(candidate, test_instance))
    return ranked[:k]
# Sanity check: the single nearest neighbor of the first test instance.
print test_set[0]
print
k = 1
neighbors = getNeighbors(training_set, test_set[0], k)
print neighbors
sepal length 4.9 sepal width 3 petal length 1.4 petal width 0.2 class Iris-setosa Name: 1, dtype: object [sepal length 4.8 sepal width 3 petal length 1.4 petal width 0.1 class Iris-setosa Name: 12, dtype: object]
# Each neighbor is a pandas Series; its last element is the class label.
print neighbors[0]
print
print type(neighbors[0])
print
print neighbors[0][-1]
sepal length 4.8 sepal width 3 petal length 1.4 petal width 0.1 class Iris-setosa Name: 12, dtype: object <class 'pandas.core.series.Series'> Iris-setosa
# Repeat the neighbor lookup with k = 3.
k = 3
neighbors = getNeighbors(training_set, test_set[0], k)
print(neighbors)
[sepal length 4.8 sepal width 3 petal length 1.4 petal width 0.1 class Iris-setosa Name: 12, dtype: object, sepal length 4.9 sepal width 3.1 petal length 1.5 petal width 0.1 class Iris-setosa Name: 9, dtype: object, sepal length 4.9 sepal width 3.1 petal length 1.5 petal width 0.1 class Iris-setosa Name: 34, dtype: object]
def classify(neighbors):
    """Majority-vote the class labels of the neighbor instances.

    neighbors -- sequence of instances whose last element ([-1]) is the
                 class label.

    Returns the most frequent label (ties broken arbitrarily, matching
    the original behavior).
    """
    class_frequency = {}
    for neighbor in neighbors:
        label = neighbor[-1]
        class_frequency[label] = class_frequency.get(label, 0) + 1
    # max() avoids building a fully sorted list and, unlike the original
    # dict.iteritems() call, works on both Python 2 and Python 3.
    return max(class_frequency, key=class_frequency.get)
# Single-instance check: majority vote over the 3 nearest neighbors.
k = 3
neighbors = getNeighbors(training_set, test_set[0], k)
classified_class_name = classify(neighbors)
print "Classified:", classified_class_name, "- Actual:", test_set[0]['class']
Classified: Iris-setosa - Actual: Iris-setosa
# Classify every test instance with k-NN (k = 3) and report accuracy.
k = 3
classified_class_names=[]
for i in range(len(test_set)):
    neighbors = getNeighbors(training_set, test_set[i], k)
    result = classify(neighbors)
    classified_class_names.append(result)
    print('Classified:' + result + ', Actual:' + test_set[i][-1])
# Count predictions that match the actual label (last element of the row).
correct = 0.0
for i in range(len(test_set)):
    if classified_class_names[i] == test_set[i][-1]:
        correct += 1.0
print
print('Accuracy: ' + str(correct / float(len(test_set)) * 100.0) + '%')
Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-setosa, Actual:Iris-setosa Classified:Iris-versicolor, Actual:Iris-versicolor Classified:Iris-versicolor, Actual:Iris-versicolor Classified:Iris-versicolor, Actual:Iris-versicolor Classified:Iris-versicolor, Actual:Iris-versicolor Classified:Iris-versicolor, Actual:Iris-versicolor Classified:Iris-versicolor, Actual:Iris-versicolor Classified:Iris-versicolor, Actual:Iris-versicolor Classified:Iris-versicolor, Actual:Iris-versicolor Classified:Iris-versicolor, Actual:Iris-versicolor Classified:Iris-versicolor, Actual:Iris-versicolor Classified:Iris-versicolor, Actual:Iris-versicolor Classified:Iris-versicolor, Actual:Iris-versicolor Classified:Iris-versicolor, Actual:Iris-versicolor Classified:Iris-versicolor, Actual:Iris-versicolor Classified:Iris-versicolor, Actual:Iris-versicolor Classified:Iris-virginica, Actual:Iris-virginica Classified:Iris-versicolor, Actual:Iris-virginica Classified:Iris-virginica, Actual:Iris-virginica Classified:Iris-virginica, Actual:Iris-virginica Classified:Iris-virginica, Actual:Iris-virginica 
Classified:Iris-virginica, Actual:Iris-virginica Classified:Iris-virginica, Actual:Iris-virginica Classified:Iris-virginica, Actual:Iris-virginica Classified:Iris-virginica, Actual:Iris-virginica Classified:Iris-virginica, Actual:Iris-virginica Classified:Iris-virginica, Actual:Iris-virginica Classified:Iris-virginica, Actual:Iris-virginica Classified:Iris-virginica, Actual:Iris-virginica Classified:Iris-versicolor, Actual:Iris-virginica Accuracy: 96.1538461538%
The intermediate test code from the steps above has been removed.
For a more reliable accuracy estimate, the whole experiment is repeated num_trials times and the mean accuracy is reported.
import urllib2
import json
from scipy import stats
from pandas import Series, DataFrame
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import random
import math
import operator
def splitDataset(split, df, training_set=None, test_set=None):
    """Randomly partition the rows of df into training and test lists.

    Each row (a pandas Series, class label included) goes to the
    training side with probability `split`, otherwise to the test side.

    split        -- probability threshold in [0, 1] for the training side
    df           -- DataFrame whose rows are partitioned
    training_set -- optional list to append training rows to
    test_set     -- optional list to append test rows to

    Returns (training_set, test_set).
    """
    # Fresh lists per call: the original mutable-default arguments ([])
    # were shared across calls, so each of the num_trials splits in the
    # main loop accumulated the rows of all previous trials, inflating
    # the measured accuracy.
    if training_set is None:
        training_set = []
    if test_set is None:
        test_set = []
    for i in range(len(df)):
        # .iloc is positional row access; .ix was removed from pandas.
        row = df.iloc[i]
        if random.random() < split:
            training_set.append(row)
        else:
            test_set.append(row)
    return training_set, test_set
def euclideanDistance(instance1, instance2, n=None):
    """Euclidean distance over the first n elements of two instances.

    instance1, instance2 -- indexable feature vectors (a trailing class
                            label beyond position n-1 is ignored)
    n                    -- number of leading features to compare;
                            defaults to the module-level num_feature so
                            existing two-argument callers are unchanged

    Returns the distance as a float.
    """
    if n is None:
        n = num_feature  # module-level feature count set in __main__
    return math.sqrt(sum((instance1[x] - instance2[x]) ** 2 for x in range(n)))
def getNeighbors(training_set, test_instance, k):
    """Return the k members of training_set nearest to test_instance.

    Nearness is euclideanDistance over the leading num_feature columns,
    so a trailing class label on each instance does not affect ranking.
    """
    # sorted() with a key computes each distance exactly once and is
    # stable, matching the original pair-and-sort implementation.
    ranked = sorted(training_set,
                    key=lambda candidate: euclideanDistance(candidate, test_instance))
    return ranked[:k]
def classify(neighbors):
    """Majority-vote the class labels of the neighbor instances.

    neighbors -- sequence of instances whose last element ([-1]) is the
                 class label.

    Returns the most frequent label (ties broken arbitrarily, matching
    the original behavior).
    """
    class_frequency = {}
    for neighbor in neighbors:
        label = neighbor[-1]
        class_frequency[label] = class_frequency.get(label, 0) + 1
    # max() avoids building a fully sorted list and, unlike the original
    # dict.iteritems() call, works on both Python 2 and Python 3.
    return max(class_frequency, key=class_frequency.get)
if __name__ == '__main__':
    # Fetch the UCI Iris data set (Python 2 urllib2).
    path = 'https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data'
    raw_csv = urllib2.urlopen(path)
    feature_names = ('sepal length', 'sepal width', 'petal length', 'petal width')
    iris_names = ('Iris-setosa', 'Iris-versicolor', 'Iris-virginica')
    all_names = feature_names + ('class',)
    df = pd.read_csv(raw_csv, names=all_names)
    df_feature = df.drop('class', axis=1)
    num_feature = len(feature_names)  # read by euclideanDistance()
    split = 0.66
    k = 3
    num_trials = 3
    accuracy_sum = 0.0
    # `trial` instead of `i`: the original reused `i` for both the outer
    # trial loop and the inner test loops, shadowing the trial counter.
    for trial in range(num_trials):
        # Pass fresh lists explicitly: splitDataset's mutable default
        # arguments would otherwise accumulate rows across trials and
        # inflate the measured accuracy.
        training_set, test_set = splitDataset(split, df, [], [])
        classified_class_names = []
        for i in range(len(test_set)):
            neighbors = getNeighbors(training_set, test_set[i], k)
            classified_class_names.append(classify(neighbors))
        correct = 0.0
        for i in range(len(test_set)):
            if test_set[i][-1] == classified_class_names[i]:
                correct += 1.0
        accuracy_sum += (correct / float(len(test_set))) * 100.0
    print('Mean Accuracy: ' + str(accuracy_sum / float(num_trials)) + '%')
Mean Accuracy: 98.6579403572%
import numpy as np
import matplotlib.pyplot as plt
from sklearn import neighbors, datasets
# scikit-learn ships a bundled copy of the Iris data set (no network needed).
iris = datasets.load_iris()
print iris.data[0:5]    # first five 4-feature sample rows
print iris.target[0:5]  # integer class labels; index 0 maps to Iris-setosa below
[[ 5.1 3.5 1.4 0.2] [ 4.9 3. 1.4 0.2] [ 4.7 3.2 1.3 0.2] [ 4.6 3.1 1.5 0.2] [ 5. 3.6 1.4 0.2]] [0 0 0 0 0]
import random
def splitDataset2(split, data, training_feature_set=None, training_target_set=None,
                  test_feature_set=None, test_target_set=None):
    """Randomly split Iris samples and targets into train/test lists.

    split -- probability a sample lands on the training side
    data  -- only its length drives the loop; the features and targets
             are read from the module-level `iris` object (callers pass
             iris.data, so the rows match)

    Returns (training_feature_set, training_target_set,
             test_feature_set, test_target_set).
    """
    # Fresh lists per call: the original mutable-default arguments were
    # shared across calls, so repeated splits accumulated samples.
    if training_feature_set is None:
        training_feature_set = []
    if training_target_set is None:
        training_target_set = []
    if test_feature_set is None:
        test_feature_set = []
    if test_target_set is None:
        test_target_set = []
    for i in range(len(data)):
        if random.random() < split:
            training_feature_set.append(iris.data[i])
            training_target_set.append(iris.target[i])
        else:
            test_feature_set.append(iris.data[i])
            test_target_set.append(iris.target[i])
    return training_feature_set, training_target_set, test_feature_set, test_target_set
# Split the sklearn arrays and report the resulting sizes and contents.
split = 0.66
training_feature_set, training_target_set, test_feature_set, test_target_set = splitDataset2(split, iris.data)
# NOTE(review): the ratios divide by len(df) from the earlier pandas
# section (notebook state); len(iris.data) would be self-contained.
print 'Train: ' + str(len(training_feature_set)) + " - ratio: " + str(float(len(training_feature_set))/len(df))
print 'Test: ' + str(len(test_feature_set)) + " - ratio: " + str(float(len(test_feature_set))/len(df))
print
print training_feature_set
print training_target_set
print
print test_feature_set
print test_target_set
Train: 99 - ratio: 0.66 Test: 51 - ratio: 0.34 [array([ 5.1, 3.5, 1.4, 0.2]), array([ 4.9, 3. , 1.4, 0.2]), array([ 4.7, 3.2, 1.3, 0.2]), array([ 5.4, 3.9, 1.7, 0.4]), array([ 4.4, 2.9, 1.4, 0.2]), array([ 4.8, 3.4, 1.6, 0.2]), array([ 4.8, 3. , 1.4, 0.1]), array([ 4.3, 3. , 1.1, 0.1]), array([ 5.8, 4. , 1.2, 0.2]), array([ 5.7, 4.4, 1.5, 0.4]), array([ 5.4, 3.9, 1.3, 0.4]), array([ 5.1, 3.5, 1.4, 0.3]), array([ 5.7, 3.8, 1.7, 0.3]), array([ 5.1, 3.8, 1.5, 0.3]), array([ 5.4, 3.4, 1.7, 0.2]), array([ 5.1, 3.7, 1.5, 0.4]), array([ 4.6, 3.6, 1. , 0.2]), array([ 5.1, 3.3, 1.7, 0.5]), array([ 4.8, 3.4, 1.9, 0.2]), array([ 5. , 3. , 1.6, 0.2]), array([ 5.2, 3.5, 1.5, 0.2]), array([ 4.7, 3.2, 1.6, 0.2]), array([ 5.4, 3.4, 1.5, 0.4]), array([ 5.5, 4.2, 1.4, 0.2]), array([ 5. , 3.2, 1.2, 0.2]), array([ 5.5, 3.5, 1.3, 0.2]), array([ 4.9, 3.1, 1.5, 0.1]), array([ 4.4, 3. , 1.3, 0.2]), array([ 5.1, 3.4, 1.5, 0.2]), array([ 4.5, 2.3, 1.3, 0.3]), array([ 5.1, 3.8, 1.6, 0.2]), array([ 5.3, 3.7, 1.5, 0.2]), array([ 6.4, 3.2, 4.5, 1.5]), array([ 5.5, 2.3, 4. , 1.3]), array([ 6.5, 2.8, 4.6, 1.5]), array([ 5.7, 2.8, 4.5, 1.3]), array([ 6.3, 3.3, 4.7, 1.6]), array([ 4.9, 2.4, 3.3, 1. ]), array([ 5.2, 2.7, 3.9, 1.4]), array([ 5. , 2. , 3.5, 1. ]), array([ 5.9, 3. , 4.2, 1.5]), array([ 6. , 2.2, 4. , 1. ]), array([ 6.1, 2.9, 4.7, 1.4]), array([ 5.6, 2.9, 3.6, 1.3]), array([ 6.7, 3.1, 4.4, 1.4]), array([ 5.8, 2.7, 4.1, 1. ]), array([ 6.2, 2.2, 4.5, 1.5]), array([ 5.6, 2.5, 3.9, 1.1]), array([ 6.1, 2.8, 4. , 1.3]), array([ 6.1, 2.8, 4.7, 1.2]), array([ 6.4, 2.9, 4.3, 1.3]), array([ 6.8, 2.8, 4.8, 1.4]), array([ 6.7, 3. , 5. , 1.7]), array([ 5.7, 2.6, 3.5, 1. ]), array([ 5.5, 2.4, 3.8, 1.1]), array([ 5.5, 2.4, 3.7, 1. ]), array([ 6. , 2.7, 5.1, 1.6]), array([ 6. , 3.4, 4.5, 1.6]), array([ 6.7, 3.1, 4.7, 1.5]), array([ 6.3, 2.3, 4.4, 1.3]), array([ 5.5, 2.5, 4. , 1.3]), array([ 5.5, 2.6, 4.4, 1.2]), array([ 5.6, 2.7, 4.2, 1.3]), array([ 5.7, 3. 
, 4.2, 1.2]), array([ 5.7, 2.9, 4.2, 1.3]), array([ 6.2, 2.9, 4.3, 1.3]), array([ 5.1, 2.5, 3. , 1.1]), array([ 5.7, 2.8, 4.1, 1.3]), array([ 5.8, 2.7, 5.1, 1.9]), array([ 6.3, 2.9, 5.6, 1.8]), array([ 7.6, 3. , 6.6, 2.1]), array([ 6.7, 2.5, 5.8, 1.8]), array([ 7.2, 3.6, 6.1, 2.5]), array([ 6.4, 2.7, 5.3, 1.9]), array([ 6.8, 3. , 5.5, 2.1]), array([ 5.8, 2.8, 5.1, 2.4]), array([ 6.5, 3. , 5.5, 1.8]), array([ 6. , 2.2, 5. , 1.5]), array([ 6.9, 3.2, 5.7, 2.3]), array([ 5.6, 2.8, 4.9, 2. ]), array([ 7.7, 2.8, 6.7, 2. ]), array([ 6.3, 2.7, 4.9, 1.8]), array([ 7.2, 3.2, 6. , 1.8]), array([ 6.2, 2.8, 4.8, 1.8]), array([ 6.4, 2.8, 5.6, 2.1]), array([ 7.2, 3. , 5.8, 1.6]), array([ 7.4, 2.8, 6.1, 1.9]), array([ 6.4, 2.8, 5.6, 2.2]), array([ 6.3, 2.8, 5.1, 1.5]), array([ 6.1, 2.6, 5.6, 1.4]), array([ 6.3, 3.4, 5.6, 2.4]), array([ 6.9, 3.1, 5.4, 2.1]), array([ 6.9, 3.1, 5.1, 2.3]), array([ 5.8, 2.7, 5.1, 1.9]), array([ 6.7, 3. , 5.2, 2.3]), array([ 6.3, 2.5, 5. , 1.9]), array([ 6.5, 3. , 5.2, 2. ]), array([ 6.2, 3.4, 5.4, 2.3]), array([ 5.9, 3. , 5.1, 1.8])] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2] [array([ 4.6, 3.1, 1.5, 0.2]), array([ 5. , 3.6, 1.4, 0.2]), array([ 4.6, 3.4, 1.4, 0.3]), array([ 5. , 3.4, 1.5, 0.2]), array([ 4.9, 3.1, 1.5, 0.1]), array([ 5.4, 3.7, 1.5, 0.2]), array([ 5. , 3.4, 1.6, 0.4]), array([ 5.2, 3.4, 1.4, 0.2]), array([ 4.8, 3.1, 1.6, 0.2]), array([ 5.2, 4.1, 1.5, 0.1]), array([ 4.9, 3.1, 1.5, 0.1]), array([ 5. , 3.5, 1.3, 0.3]), array([ 4.4, 3.2, 1.3, 0.2]), array([ 5. , 3.5, 1.6, 0.6]), array([ 5.1, 3.8, 1.9, 0.4]), array([ 4.8, 3. , 1.4, 0.3]), array([ 4.6, 3.2, 1.4, 0.2]), array([ 5. , 3.3, 1.4, 0.2]), array([ 7. 
, 3.2, 4.7, 1.4]), array([ 6.9, 3.1, 4.9, 1.5]), array([ 6.6, 2.9, 4.6, 1.3]), array([ 5.6, 3. , 4.5, 1.5]), array([ 5.9, 3.2, 4.8, 1.8]), array([ 6.3, 2.5, 4.9, 1.5]), array([ 6.6, 3. , 4.4, 1.4]), array([ 6. , 2.9, 4.5, 1.5]), array([ 5.8, 2.7, 3.9, 1.2]), array([ 5.4, 3. , 4.5, 1.5]), array([ 5.6, 3. , 4.1, 1.3]), array([ 6.1, 3. , 4.6, 1.4]), array([ 5.8, 2.6, 4. , 1.2]), array([ 5. , 2.3, 3.3, 1. ]), array([ 6.3, 3.3, 6. , 2.5]), array([ 7.1, 3. , 5.9, 2.1]), array([ 6.5, 3. , 5.8, 2.2]), array([ 4.9, 2.5, 4.5, 1.7]), array([ 7.3, 2.9, 6.3, 1.8]), array([ 6.5, 3.2, 5.1, 2. ]), array([ 5.7, 2.5, 5. , 2. ]), array([ 6.4, 3.2, 5.3, 2.3]), array([ 7.7, 3.8, 6.7, 2.2]), array([ 7.7, 2.6, 6.9, 2.3]), array([ 6.7, 3.3, 5.7, 2.1]), array([ 6.1, 3. , 4.9, 1.8]), array([ 7.9, 3.8, 6.4, 2. ]), array([ 7.7, 3. , 6.1, 2.3]), array([ 6.4, 3.1, 5.5, 1.8]), array([ 6. , 3. , 4.8, 1.8]), array([ 6.7, 3.1, 5.6, 2.4]), array([ 6.8, 3.2, 5.9, 2.3]), array([ 6.7, 3.3, 5.7, 2.5])] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
# Single-sample check of scikit-learn's k-NN classifier.
k = 3
knn = neighbors.KNeighborsClassifier(k)  # k -> n_neighbors
knn.fit(training_feature_set, training_target_set)
# predict() expects a 2-D array of samples and returns an array of
# labels, so wrap the single sample and take element 0 -- passing a
# bare 1-D sample is rejected by modern scikit-learn.
result_index = knn.predict([test_feature_set[0]])[0]
print('Classified:' + iris_names[result_index] + ', Actual:' + iris_names[test_target_set[0]])
Classified:Iris-setosa, Actual:Iris-setosa
import numpy as np
import matplotlib.pyplot as plt
from sklearn import neighbors, datasets
# Bundled copy of the Iris data set (no network access needed).
iris = datasets.load_iris()
def splitDataset2(split, data, training_feature_set=None, training_target_set=None,
                  test_feature_set=None, test_target_set=None):
    """Randomly split Iris samples and targets into train/test lists.

    split -- probability a sample lands on the training side
    data  -- only its length drives the loop; the features and targets
             are read from the module-level `iris` object (callers pass
             iris.data, so the rows match)

    Returns (training_feature_set, training_target_set,
             test_feature_set, test_target_set).
    """
    # Fresh lists per call: the original mutable-default arguments were
    # shared across calls, so each of the num_trials splits in the main
    # loop accumulated the samples of all previous trials.
    if training_feature_set is None:
        training_feature_set = []
    if training_target_set is None:
        training_target_set = []
    if test_feature_set is None:
        test_feature_set = []
    if test_target_set is None:
        test_target_set = []
    for i in range(len(data)):
        if random.random() < split:
            training_feature_set.append(iris.data[i])
            training_target_set.append(iris.target[i])
        else:
            test_feature_set.append(iris.data[i])
            test_target_set.append(iris.target[i])
    return training_feature_set, training_target_set, test_feature_set, test_target_set
if __name__ == '__main__':
    feature_names = ('sepal length', 'sepal width', 'petal length', 'petal width')
    iris_names = ('Iris-setosa', 'Iris-versicolor', 'Iris-virginica')
    all_names = feature_names + ('class',)
    num_feature = len(feature_names)
    split = 0.66
    k = 3
    num_trials = 3
    accuracy_sum = 0.0
    # `trial` instead of `i`: the original reused `i` for both the outer
    # trial loop and the inner test loops, shadowing the trial counter.
    for trial in range(num_trials):
        # Pass fresh lists explicitly: splitDataset2's mutable default
        # arguments would otherwise accumulate samples across trials.
        training_feature_set, training_target_set, test_feature_set, test_target_set = \
            splitDataset2(split, iris.data, [], [], [], [])
        knn = neighbors.KNeighborsClassifier(k)
        knn.fit(training_feature_set, training_target_set)
        classified_class_names = []
        for i in range(len(test_feature_set)):
            # predict() expects a 2-D array of samples; wrap the sample
            # and take the first (only) predicted label.
            prediction = knn.predict([test_feature_set[i]])[0]
            classified_class_names.append(iris_names[prediction])
        correct = 0.0
        for i in range(len(test_feature_set)):
            if iris_names[test_target_set[i]] == classified_class_names[i]:
                correct += 1.0
        accuracy_sum += (correct / float(len(test_feature_set))) * 100.0
    print('Mean Accuracy: ' + str(accuracy_sum / float(num_trials)) + '%')
Mean Accuracy: 97.1169784284%