from sklearn.datasets import load_iris
iris_dataset = load_iris()
print("Keys of iris_dataset: {}".format(iris_dataset.keys()))
Keys of iris_dataset: dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names'])
print(iris_dataset['DESCR'])
Iris Plants Database
====================

Notes
-----
Data Set Characteristics:
    :Number of Instances: 150 (50 in each of three classes)
    :Number of Attributes: 4 numeric, predictive attributes and the class
    :Attribute Information:
        - sepal length in cm
        - sepal width in cm
        - petal length in cm
        - petal width in cm
        - class:
            - Iris-Setosa
            - Iris-Versicolour
            - Iris-Virginica
    :Summary Statistics:

    ============== ==== ==== ======= ===== ====================
                    Min  Max   Mean    SD   Class Correlation
    ============== ==== ==== ======= ===== ====================
    sepal length:   4.3  7.9   5.84   0.83    0.7826
    sepal width:    2.0  4.4   3.05   0.43   -0.4194
    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)
    ============== ==== ==== ======= ===== ====================

    :Missing Attribute Values: None
    :Class Distribution: 33.3% for each of 3 classes.
    :Creator: R.A. Fisher
    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
    :Date: July, 1988

This is a copy of UCI ML iris datasets.
http://archive.ics.uci.edu/ml/datasets/Iris

The famous Iris database, first used by Sir R.A Fisher

This is perhaps the best known database to be found in the pattern recognition
literature. Fisher's paper is a classic in the field and is referenced
frequently to this day. (See Duda & Hart, for example.) The data set contains
3 classes of 50 instances each, where each class refers to a type of iris
plant. One class is linearly separable from the other 2; the latter are NOT
linearly separable from each other.

References
----------
    - Fisher,R.A. "The use of multiple measurements in taxonomic problems"
      Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
      Mathematical Statistics" (John Wiley, NY, 1950).
    - Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.
      (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.
    - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
      Structure and Classification Rule for Recognition in Partially Exposed
      Environments". IEEE Transactions on Pattern Analysis and Machine
      Intelligence, Vol. PAMI-2, No. 1, 67-71.
    - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE Transactions
      on Information Theory, May 1972, 431-433.
    - See also: 1988 MLC Proceedings, 54-64. Cheeseman et al's AUTOCLASS II
      conceptual clustering system finds 3 classes in the data.
    - Many, many more ...
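# A quick look at the remaining keys (an added sanity check, not part of the
# walkthrough above): target_names holds the class labels, feature_names the
# column descriptions, and data/target the measurements and encoded classes.
print("Target names: {}".format(iris_dataset['target_names']))
print("Feature names: {}".format(iris_dataset['feature_names']))
print("Shape of data: {}".format(iris_dataset['data'].shape))
print("Shape of target: {}".format(iris_dataset['target'].shape))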
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    iris_dataset['data'],
    iris_dataset['target'],
    # test_size is left at its default of 25%, which gives 38 test samples here
    # test_size=38,
    random_state=0
)
print("X_train shape: {}".format(X_train.shape))
print("y_train shape: {}".format(y_train.shape))
X_train shape: (112, 4)
y_train shape: (112,)
print("X_test shape: {}".format(X_test.shape))
print("y_test shape: {}".format(y_test.shape))
X_test shape: (38, 4)
y_test shape: (38,)
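# Optional sanity check (an addition, not part of the original code): because
# train_test_split shuffles with random_state=0 and no stratify argument, the
# class counts in the two splits are only approximately balanced.
import numpy as np
print("Class counts in y_train: {}".format(np.bincount(y_train)))
print("Class counts in y_test: {}".format(np.bincount(y_test)))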
import pandas as pd
import mglearn
%matplotlib inline
# create dataframe from data in X_train
# label the columns using the strings in iris_dataset.feature_names
iris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)
# create a scatter matrix from the dataframe, color by y_train
pd.plotting.scatter_matrix(
    iris_dataframe,
    c=y_train,
    figsize=(15, 15),
    marker='o',
    cmap=mglearn.cm3
)
[Output: a 4x4 array of matplotlib AxesSubplot objects, i.e. the scatter matrix of the training data with points colored by y_train]
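# If mglearn is not available, the same scatter matrix can presumably be drawn
# with a standard matplotlib colormap instead of mglearn.cm3 (the colors will
# differ); this variant is an assumption-based alternative, not part of the
# original example.
pd.plotting.scatter_matrix(
    iris_dataframe,
    c=y_train,
    figsize=(15, 15),
    marker='o',
    cmap='viridis'
)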
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=1)
knn.fit(X_train, y_train)
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski', metric_params=None, n_jobs=1, n_neighbors=1, p=2, weights='uniform')
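# What n_neighbors=1 amounts to, sketched by hand for illustration (this is
# not how scikit-learn implements it internally, and the helper name
# one_nn_predict is made up here): the prediction for a query point is the
# label of the single closest training point in Euclidean distance.
import numpy as np

def one_nn_predict(x_query, X_train, y_train):
    # squared Euclidean distance from the query to every training point
    distances = np.sum((X_train - x_query) ** 2, axis=1)
    # return the label of the nearest training point
    return y_train[np.argmin(distances)]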
import numpy as np
X_new = np.array([[5, 2.9, 1, 0.2]])
print("X_new.shape: {}".format(X_new.shape))
X_new.shape: (1, 4)
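# Aside (not in the original): scikit-learn expects 2D arrays of shape
# (n_samples, n_features), which is why the single flower is wrapped in an
# extra pair of brackets; a 1D array would need reshaping before predict.
X_new_1d = np.array([5, 2.9, 1, 0.2])
print("Reshaped 1D input: {}".format(X_new_1d.reshape(1, -1).shape))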
prediction = knn.predict(X_new)
print("Prediction: {}".format(prediction))
print("Predicted target name: {}".format(
iris_dataset['target_names'][prediction]
))
Prediction: [0]
Predicted target name: ['setosa']
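# Cross-check against the hand-written one_nn_predict sketch above; with the
# same training data it should agree with knn and also return class 0.
print("Manual 1-NN prediction: {}".format(one_nn_predict(X_new[0], X_train, y_train)))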
y_pred = knn.predict(X_test)
print("Test set predictions:\n {}".format(y_pred))
Test set predictions:
 [2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0 2]
print("Test set score: {:.2f}".format(np.mean(y_pred == y_test)))
print("Test set score: {:.2f}".format(knn.score(X_test, y_test)))
Test set score: 0.97
Test set score: 0.97
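# A possible follow-up (not part of the original example): refit the model for
# a few values of n_neighbors and compare test-set accuracy; the exact numbers
# depend on the split produced by random_state=0.
for n_neighbors in [1, 3, 5, 7]:
    clf = KNeighborsClassifier(n_neighbors=n_neighbors).fit(X_train, y_train)
    print("n_neighbors={}: test set score {:.2f}".format(
        n_neighbors, clf.score(X_test, y_test)
    ))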