下ごしらえ
%matplotlib inline
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from sklearn import cross_validation as cv
from sklearn.utils import shuffle
from sklearn.svm import SVC
clf=SVC(probability=True)
from sklearn.datasets import load_iris
data=load_iris().data
target=load_iris().target
from sklearn import grid_search as grid
clf.get_params()
{'C': 1.0, 'cache_size': 200, 'class_weight': None, 'coef0': 0.0, 'decision_function_shape': None, 'degree': 3, 'gamma': 'auto', 'kernel': 'rbf', 'max_iter': -1, 'probability': True, 'random_state': None, 'shrinking': True, 'tol': 0.001, 'verbose': False}
parameters = {
'kernel':('linear', 'rbf'),
'C': np.linspace(1,10,5),
'gamma' : np.append(
np.logspace(-4,1,11).astype('object'),
'auto'
)
}
クロスバリデーションまでまとめて実行可能
デフォルトはcv=None
grid_clf=grid.GridSearchCV(clf,parameters,n_jobs=-1,cv=5)
中身
grid_clf.param_grid
{'C': array([ 1. , 3.25, 5.5 , 7.75, 10. ]), 'gamma': array([0.0001, 0.00031622776601683794, 0.001, 0.0031622776601683794, 0.01, 0.03162277660168379, 0.1, 0.31622776601683794, 1.0, 3.1622776601683795, 10.0, 'auto'], dtype=object), 'kernel': ('linear', 'rbf')}
フィッティング
data_shuffle,target_shuffle=shuffle(data,target)
grid_clf.fit(data_shuffle,target_shuffle)
GridSearchCV(cv=5, error_score='raise', estimator=SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, decision_function_shape=None, degree=3, gamma='auto', kernel='rbf', max_iter=-1, probability=True, random_state=None, shrinking=True, tol=0.001, verbose=False), fit_params={}, iid=True, n_jobs=-1, param_grid={'kernel': ('linear', 'rbf'), 'C': array([ 1. , 3.25, 5.5 , 7.75, 10. ]), 'gamma': array([0.0001, 0.00031622776601683794, 0.001, 0.0031622776601683794, 0.01, 0.03162277660168379, 0.1, 0.31622776601683794, 1.0, 3.1622776601683795, 10.0, 'auto'], dtype=object)}, pre_dispatch='2*n_jobs', refit=True, scoring=None, verbose=0)
探索結果はgrid_scores_
に入っている
pd.DataFrame(grid_clf.grid_scores_).head()
parameters | mean_validation_score | cv_validation_scores | |
---|---|---|---|
0 | {'kernel': 'linear', 'C': 1.0, 'gamma': 0.0001} | 0.973333 | [0.966666666667, 0.966666666667, 0.93333333333... |
1 | {'kernel': 'rbf', 'C': 1.0, 'gamma': 0.0001} | 0.920000 | [0.9, 0.9, 0.9, 0.933333333333, 0.966666666667] |
2 | {'kernel': 'linear', 'C': 1.0, 'gamma': 0.0003... | 0.973333 | [0.966666666667, 0.966666666667, 0.93333333333... |
3 | {'kernel': 'rbf', 'C': 1.0, 'gamma': 0.0003162... | 0.920000 | [0.9, 0.9, 0.9, 0.933333333333, 0.966666666667] |
4 | {'kernel': 'linear', 'C': 1.0, 'gamma': 0.001} | 0.973333 | [0.966666666667, 0.966666666667, 0.93333333333... |
ベストのモデルは別名が付いている
print(grid_clf.best_params_)
print(grid_clf.best_score_)
{'kernel': 'rbf', 'C': 7.75, 'gamma': 0.03162277660168379} 0.986666666667
iris(4次元)でRF使わないですが...
from sklearn.ensemble import RandomForestClassifier as RFC
clf_rf=RFC()
print(clf_rf.get_params())
{'oob_score': False, 'max_features': 'auto', 'criterion': 'gini', 'n_jobs': 1, 'max_depth': None, 'min_samples_split': 2, 'verbose': 0, 'min_samples_leaf': 1, 'bootstrap': True, 'max_leaf_nodes': None, 'class_weight': None, 'warm_start': False, 'min_weight_fraction_leaf': 0.0, 'random_state': None, 'n_estimators': 10}
parameters_rf={'n_estimators':[1*2**i for i in range(9)]}
grid_clf_rf=grid.GridSearchCV(clf_rf,parameters_rf,n_jobs=-1,cv=5)
grid_clf_rf.fit(data_shuffle,target_shuffle)
GridSearchCV(cv=5, error_score='raise', estimator=RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini', max_depth=None, max_features='auto', max_leaf_nodes=None, min_samples_leaf=1, min_samples_split=2, min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1, oob_score=False, random_state=None, verbose=0, warm_start=False), fit_params={}, iid=True, n_jobs=-1, param_grid={'n_estimators': [1, 2, 4, 8, 16, 32, 64, 128, 256]}, pre_dispatch='2*n_jobs', refit=True, scoring=None, verbose=0)
pd.DataFrame(grid_clf_rf.grid_scores_)
parameters | mean_validation_score | cv_validation_scores | |
---|---|---|---|
0 | {'n_estimators': 1} | 0.920000 | [0.833333333333, 0.866666666667, 0.93333333333... |
1 | {'n_estimators': 2} | 0.906667 | [0.9, 0.833333333333, 0.833333333333, 0.966666... |
2 | {'n_estimators': 4} | 0.933333 | [0.9, 0.866666666667, 0.966666666667, 0.966666... |
3 | {'n_estimators': 8} | 0.953333 | [0.933333333333, 0.866666666667, 0.96666666666... |
4 | {'n_estimators': 16} | 0.953333 | [0.933333333333, 0.866666666667, 0.96666666666... |
5 | {'n_estimators': 32} | 0.953333 | [0.933333333333, 0.866666666667, 0.96666666666... |
6 | {'n_estimators': 64} | 0.953333 | [0.933333333333, 0.866666666667, 0.96666666666... |
7 | {'n_estimators': 128} | 0.953333 | [0.933333333333, 0.866666666667, 0.96666666666... |
8 | {'n_estimators': 256} | 0.953333 | [0.933333333333, 0.866666666667, 0.96666666666... |
@y__sama