이 노트북의 코드에 대한 설명은 '반복 교차 검증' 글을 참고하세요.
%load_ext watermark
%watermark -v -p sklearn,numpy,scipy
CPython 3.5.6 IPython 6.5.0 sklearn 0.20.1 numpy 1.15.2 scipy 1.1.0
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score, KFold, StratifiedKFold
from sklearn.linear_model import LogisticRegression
# Load the iris dataset and build the logistic-regression estimator that
# the cross-validation runs in this notebook evaluate.
iris = load_iris()
logreg = LogisticRegression(
    solver='liblinear',
    multi_class='auto',
    max_iter=1000,
)
RepeatedKFold
# Baseline: a single pass of plain 5-fold cross-validation.
kfold = KFold(n_splits=5)
scores = cross_val_score(
    estimator=logreg,
    X=iris.data,
    y=iris.target,
    cv=kfold,
)
# Notebook display: the per-fold scores followed by their mean.
scores, scores.mean()
(array([1. , 0.93333333, 0.43333333, 0.96666667, 0.43333333]), 0.7533333333333333)
from sklearn.model_selection import RepeatedKFold
# Repeat 5-fold cross-validation 5 times (5 x 5 = 25 scores), seeded with
# random_state=42.
rkfold = RepeatedKFold(
    n_splits=5,
    n_repeats=5,
    random_state=42,
)
scores = cross_val_score(
    estimator=logreg,
    X=iris.data,
    y=iris.target,
    cv=rkfold,
)
# Notebook display: all 25 scores followed by their mean.
scores, scores.mean()
(array([1. , 0.93333333, 0.93333333, 0.96666667, 0.96666667, 0.96666667, 0.93333333, 1. , 1. , 0.83333333, 0.93333333, 0.9 , 0.96666667, 0.9 , 0.93333333, 0.96666667, 1. , 0.93333333, 0.93333333, 0.93333333, 0.96666667, 0.9 , 1. , 0.93333333, 0.93333333]), 0.9466666666666668)
# Box plot of the distribution of the scores held in `scores`
# (at this point, the result of the RepeatedKFold run).
plt.boxplot(scores)
plt.show()
RepeatedStratifiedKFold
# Baseline: a single pass of stratified 5-fold cross-validation.
skfold = StratifiedKFold(n_splits=5)
scores = cross_val_score(
    estimator=logreg,
    X=iris.data,
    y=iris.target,
    cv=skfold,
)
# Notebook display: the per-fold scores followed by their mean.
scores, scores.mean()
(array([1. , 0.96666667, 0.93333333, 0.9 , 1. ]), 0.9600000000000002)
from sklearn.model_selection import RepeatedStratifiedKFold
# Repeat stratified 5-fold cross-validation 5 times (5 x 5 = 25 scores),
# seeded with random_state=42.
rskfold = RepeatedStratifiedKFold(
    n_splits=5,
    n_repeats=5,
    random_state=42,
)
scores = cross_val_score(
    estimator=logreg,
    X=iris.data,
    y=iris.target,
    cv=rskfold,
)
# Notebook display: all 25 scores followed by their mean.
scores, scores.mean()
(array([0.96666667, 0.96666667, 0.96666667, 0.93333333, 0.96666667, 0.86666667, 0.96666667, 0.96666667, 0.93333333, 0.96666667, 1. , 1. , 0.93333333, 0.93333333, 0.93333333, 1. , 0.96666667, 0.96666667, 0.9 , 0.96666667, 0.96666667, 0.96666667, 1. , 0.9 , 0.96666667]), 0.9559999999999998)
# Box plot of the distribution of the scores held in `scores`
# (at this point, the result of the RepeatedStratifiedKFold run).
plt.boxplot(scores)
plt.show()
from sklearn.model_selection import GridSearchCV, train_test_split
# Candidate values for the regularization parameter C: powers of ten from
# 1e-3 to 1e3.
param_grid = {'C': [0.001, 0.01, 0.1, 1, 10, 100, 1000]}

# Hold out a test split; the grid search only ever sees the training part.
X_train, X_test, y_train, y_test = train_test_split(
    iris.data,
    iris.target,
    random_state=42,
)

# Use the repeated stratified splitter as the CV strategy of the search.
# return_train_score=True makes cv_results_ carry the per-split training
# scores alongside the test scores.
grid_search = GridSearchCV(
    estimator=logreg,
    param_grid=param_grid,
    cv=rskfold,
    return_train_score=True,
    iid=False,
)
# Kept as the final expression so the notebook shows the fitted search.
grid_search.fit(X_train, y_train)
GridSearchCV(cv=<sklearn.model_selection._split.RepeatedStratifiedKFold object at 0x7f5eddb72cc0>, error_score='raise-deprecating', estimator=LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True, intercept_scaling=1, max_iter=1000, multi_class='auto', n_jobs=None, penalty='l2', random_state=None, solver='liblinear', tol=0.0001, verbose=0, warm_start=False), fit_params=None, iid=False, n_jobs=None, param_grid={'C': [0.001, 0.01, 0.1, 1, 10, 100, 1000]}, pre_dispatch='2*n_jobs', refit=True, return_train_score=True, scoring=None, verbose=0)
# Notebook display: the fitted search's score on the held-out test split,
# the selected parameters, and the best cross-validated score.
grid_search.score(X_test, y_test), grid_search.best_params_, grid_search.best_score_
(1.0, {'C': 100}, 0.9605947675512894)
# Print every per-split entry of cv_results_ (the keys containing 'split',
# e.g. split0_test_score / split0_train_score). Each value holds one score
# per candidate in the parameter grid.
for key, values in grid_search.cv_results_.items():
    if 'split' in key:
        print(key, values)
split15_test_score [0.34782609 0.65217391 0.7826087 0.91304348 0.95652174 0.95652174 0.95652174] split5_train_score [0.33707865 0.65168539 0.80898876 0.96629213 0.97752809 0.97752809 0.97752809] split14_test_score [0.33333333 0.66666667 0.80952381 0.9047619 0.95238095 0.9047619 0.9047619 ] split16_test_score [0.34782609 0.65217391 0.91304348 0.95652174 0.95652174 1. 0.95652174] split11_train_score [0.33707865 0.65168539 0.84269663 0.94382022 0.94382022 0.96629213 0.96629213] split18_train_score [0.34444444 0.65555556 0.82222222 0.95555556 0.96666667 0.96666667 0.96666667] split2_train_score [0.33707865 0.65168539 0.7752809 0.96629213 0.98876404 0.98876404 1. ] split13_test_score [0.31818182 0.63636364 0.86363636 1. 0.95454545 0.90909091 0.90909091] split21_test_score [0.34782609 0.65217391 0.82608696 0.91304348 0.86956522 0.91304348 0.86956522] split23_train_score [0.34444444 0.65555556 0.81111111 0.96666667 0.98888889 0.98888889 0.98888889] split17_train_score [0.33707865 0.65168539 0.82022472 0.94382022 0.95505618 0.96629213 0.96629213] split16_train_score [0.33707865 0.65168539 0.80898876 0.95505618 0.96629213 0.96629213 0.96629213] split1_test_score [0.34782609 0.65217391 0.86956522 0.82608696 0.91304348 0.95652174 1. ] split18_test_score [0.31818182 0.63636364 0.72727273 1. 1. 1. 0.90909091] split4_test_score [0.33333333 0.66666667 0.85714286 0.95238095 0.95238095 0.95238095 0.95238095] split17_test_score [0.34782609 0.65217391 0.82608696 0.95652174 1. 1. 1. ] split13_train_score [0.34444444 0.65555556 0.78888889 0.95555556 0.96666667 0.96666667 0.98888889] split11_test_score [0.34782609 0.65217391 0.7826087 0.95652174 0.95652174 1. 1. 
] split10_train_score [0.33707865 0.65168539 0.83146067 0.96629213 0.96629213 0.96629213 0.97752809] split7_test_score [0.34782609 0.65217391 0.7826087 0.91304348 0.91304348 0.91304348 0.91304348] split8_train_score [0.34444444 0.65555556 0.76666667 0.94444444 0.97777778 0.98888889 0.97777778] split9_test_score [0.33333333 0.66666667 0.85714286 0.95238095 0.95238095 1. 1. ] split15_train_score [0.33707865 0.65168539 0.80898876 0.97752809 0.97752809 0.97752809 0.97752809] split7_train_score [0.33707865 0.65168539 0.84269663 0.95505618 0.96629213 0.96629213 0.96629213] split24_train_score [0.34065934 0.64835165 0.83516484 0.93406593 0.96703297 0.96703297 0.96703297] split20_test_score [0.34782609 0.65217391 0.82608696 0.95652174 1. 1. 1. ] split22_train_score [0.33707865 0.65168539 0.80898876 0.95505618 0.96629213 0.96629213 0.96629213] split19_test_score [0.33333333 0.66666667 0.80952381 0.95238095 0.9047619 0.9047619 0.9047619 ] split0_test_score [0.34782609 0.65217391 0.73913043 1. 0.95652174 0.95652174 0.95652174] split8_test_score [0.31818182 0.63636364 0.86363636 0.95454545 0.95454545 0.95454545 0.95454545] split4_train_score [0.34065934 0.64835165 0.79120879 0.96703297 0.96703297 0.97802198 0.97802198] split23_test_score [0.31818182 0.63636364 0.77272727 0.95454545 0.95454545 0.95454545 0.95454545] split10_test_score [0.34782609 0.65217391 0.82608696 0.91304348 0.95652174 1. 1. 
] split9_train_score [0.34065934 0.64835165 0.83516484 0.93406593 0.95604396 0.95604396 0.96703297] split14_train_score [0.34065934 0.64835165 0.8021978 0.96703297 0.95604396 0.97802198 0.96703297] split6_train_score [0.33707865 0.65168539 0.83146067 0.95505618 0.95505618 0.97752809 0.96629213] split3_train_score [0.34444444 0.65555556 0.82222222 0.95555556 0.96666667 0.96666667 0.96666667] split19_train_score [0.34065934 0.64835165 0.81318681 0.95604396 0.96703297 0.98901099 0.97802198] split6_test_score [0.34782609 0.65217391 0.7826087 0.91304348 0.95652174 0.95652174 0.95652174] split24_test_score [0.33333333 0.66666667 0.76190476 1. 1. 1. 1. ] split2_test_score [0.34782609 0.65217391 0.86956522 0.91304348 0.91304348 0.91304348 0.91304348] split12_train_score [0.33707865 0.65168539 0.80898876 0.95505618 0.98876404 0.98876404 0.98876404] split20_train_score [0.33707865 0.65168539 0.79775281 0.96629213 0.96629213 0.96629213 0.96629213] split0_train_score [0.33707865 0.65168539 0.85393258 0.93258427 0.95505618 0.97752809 0.96629213] split5_test_score [0.34782609 0.65217391 0.7826087 0.95652174 0.95652174 0.95652174 0.95652174] split3_test_score [0.31818182 0.63636364 0.72727273 1. 1. 1. 1. ] split12_test_score [0.34782609 0.65217391 0.7826087 0.95652174 0.95652174 0.95652174 0.95652174] split1_train_score [0.33707865 0.65168539 0.82022472 0.97752809 0.96629213 0.96629213 0.96629213] split22_test_score [0.34782609 0.65217391 0.91304348 0.95652174 0.95652174 0.95652174 1. ] split21_train_score [0.33707865 0.65168539 0.82022472 0.95505618 0.96629213 0.96629213 0.97752809]