from sklearn import datasets, svm, metrics, cross_validation
from matplotlib import pyplot as pl
import numpy as np
import pickle
%matplotlib inline
%run util.py
# Load scikit-learn's bundled handwritten-digits dataset (all 10 digit classes).
digitos = datasets.load_digits(n_class=10)
With very few exceptions, scikit-learn assumes data is stored as a two-dimensional array of shape `[n_samples, n_features]` — here, 1797 samples with 64 features each.
# Shape of the feature matrix, as (n_samples, n_features).
np.shape(digitos.data)
(1797, 64)
# Feature vector of the sample at index 5 (one full row of the data matrix).
digitos.data[5, :]
array([ 0., 0., 12., 10., 0., 0., 0., 0., 0., 0., 14., 16., 16., 14., 0., 0., 0., 0., 13., 16., 15., 10., 1., 0., 0., 0., 11., 16., 16., 7., 0., 0., 0., 0., 0., 4., 7., 16., 7., 0., 0., 0., 0., 0., 4., 16., 9., 0., 0., 0., 5., 4., 12., 16., 4., 0., 0., 0., 9., 16., 16., 10., 0., 0.])
# Visualise the digits with the helpers presumably loaded from util.py via
# %run -- TODO confirm. The (0, 9) arguments look like the range of digits /
# samples to display; verify against util.py.
for _exibir in (show_digits_plots, show_digits):
    _exibir(digitos, 0, 9)
Realiza-se uma divisão aleatória (split) dos dados: 60% para treino e 40% para teste.
# Random 60/40 train/test split of the digit features and their labels.
# SEMENTE is passed as random_state. The default None keeps the original
# behaviour (a different split on every run). Set it to an int to make the
# split -- and therefore the printed report and confusion matrix -- reproducible.
# NOTE(review): sklearn.cross_validation was deprecated in 0.18 and removed in
# 0.20; newer versions expose train_test_split in sklearn.model_selection.
SEMENTE = None
X_train, X_test, y_train, y_test = cross_validation.train_test_split(
    digitos.data,
    digitos.target,
    test_size=.4,
    random_state=SEMENTE,
)
# SVM classifier with gamma=0.001 and C=100, trained on the training split.
model = svm.SVC(gamma=0.001, C=100.0)
model.fit(X_train, y_train)
# Persist the trained model. Pickle writes binary data, so the file must be
# opened in binary mode ('wb'). Text mode fails on Python 3 and can corrupt
# data on Windows. The `with` block also guarantees the handle is closed;
# the original left it open.
# NOTE: only unpickle this file from a trusted source.
with open('model', 'wb') as arquivo_modelo:
    pickle.dump(model, arquivo_modelo)
# Predict the held-out test set and print per-class precision/recall/F1.
y_predito = model.predict(X_test)
# Single-argument print(...) behaves identically on Python 2 and Python 3.
print(metrics.classification_report(y_test, y_predito))
             precision    recall  f1-score   support

          0       1.00      0.99      0.99        73
          1       0.97      1.00      0.99        68
          2       1.00      1.00      1.00        73
          3       1.00      1.00      1.00        74
          4       0.99      1.00      0.99        78
          5       1.00      0.97      0.99        69
          6       1.00      0.99      0.99        68
          7       1.00      1.00      1.00        74
          8       0.96      0.96      0.96        70
          9       0.96      0.97      0.97        72

avg / total       0.99      0.99      0.99       719
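Matriz de confusão: as linhas representam as classes reais e as colunas, as classes preditas. A primeira linha e a primeira coluna trazem os rótulos das classes (o -1 no canto é apenas um marcador).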
# Confusion matrix: rows are the true classes, columns the predicted classes.
# The labels are derived from the data instead of being hard-coded to 0..9.
# They are also passed explicitly, so the header row/column below is
# guaranteed to line up with the matrix order.
rotulos = np.unique(np.concatenate((y_test, y_predito)))
cm = metrics.confusion_matrix(y_test, y_predito, labels=rotulos)

# Build a labelled copy for printing. The first row and first column hold the
# class labels, and the top-left corner holds -1 as a placeholder.
n_rotulos = len(rotulos)
cms = np.zeros((n_rotulos + 1, n_rotulos + 1))
cms[0, 0] = -1
cms[0, 1:] = rotulos
cms[1:, 0] = rotulos
cms[1:, 1:] = cm
# Single-argument print(...) behaves identically on Python 2 and Python 3.
print(cms)
[[ -1.   0.   1.   2.   3.   4.   5.   6.   7.   8.   9.]
 [  0.  72.   0.   0.   0.   1.   0.   0.   0.   0.   0.]
 [  1.   0.  68.   0.   0.   0.   0.   0.   0.   0.   0.]
 [  2.   0.   0.  73.   0.   0.   0.   0.   0.   0.   0.]
 [  3.   0.   0.   0.  74.   0.   0.   0.   0.   0.   0.]
 [  4.   0.   0.   0.   0.  78.   0.   0.   0.   0.   0.]
 [  5.   0.   0.   0.   0.   0.  67.   0.   0.   0.   2.]
 [  6.   0.   0.   0.   0.   0.   0.  67.   0.   1.   0.]
 [  7.   0.   0.   0.   0.   0.   0.   0.  74.   0.   0.]
 [  8.   0.   2.   0.   0.   0.   0.   0.   0.  67.   1.]
 [  9.   0.   0.   0.   0.   0.   0.   0.   0.   2.  70.]]
# Display the confusion matrix `cm` computed above. show_confusion_matrix is
# presumably defined in util.py (loaded via %run) -- TODO confirm.
show_confusion_matrix(cm)