이 노트북의 코드에 대한 설명은 반복 교차 검증 글을 참고하세요.
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score, KFold, StratifiedKFold
from sklearn.linear_model import LogisticRegression
# Load the iris dataset; its data/target arrays feed every experiment below.
iris = load_iris()
# One shared classifier for all cross-validation runs. multi_class is left at
# its default: 'auto' is already the default value (the estimator repr printed
# after the grid-search fit omits it), and passing it explicitly triggers a
# deprecation warning on recent scikit-learn releases.
logreg = LogisticRegression(solver='liblinear', max_iter=1000)
# RepeatedKFold
# Baseline: a single pass of KFold with five splits (other options at defaults).
kfold = KFold(n_splits=5)
scores = cross_val_score(
    estimator=logreg,
    X=iris.data,
    y=iris.target,
    cv=kfold,
)
# Show the per-fold scores together with their mean.
scores, scores.mean()
(array([1. , 0.93333333, 0.43333333, 0.96666667, 0.43333333]), 0.7533333333333333)
from sklearn.model_selection import RepeatedKFold
# Five splits repeated five times with a fixed seed, so one score is produced
# per split of every repetition.
rkfold = RepeatedKFold(
    n_splits=5,
    n_repeats=5,
    random_state=42,
)
scores = cross_val_score(
    estimator=logreg,
    X=iris.data,
    y=iris.target,
    cv=rkfold,
)
# Show every individual score together with the overall mean.
scores, scores.mean()
(array([1. , 0.93333333, 0.93333333, 0.96666667, 0.96666667, 0.96666667, 0.93333333, 1. , 1. , 0.83333333, 0.93333333, 0.9 , 0.96666667, 0.9 , 0.93333333, 0.96666667, 1. , 0.93333333, 0.93333333, 0.93333333, 0.96666667, 0.9 , 1. , 0.93333333, 0.93333333]), 0.9466666666666668)
# Visualise the spread of the scores currently held in `scores` as a box plot.
plt.boxplot(x=scores)
plt.show()
# RepeatedStratifiedKFold
# A single pass of StratifiedKFold with five splits (other options at defaults).
skfold = StratifiedKFold(n_splits=5)
scores = cross_val_score(
    estimator=logreg,
    X=iris.data,
    y=iris.target,
    cv=skfold,
)
# Show the per-fold scores together with their mean.
scores, scores.mean()
(array([1. , 0.96666667, 0.93333333, 0.9 , 1. ]), 0.9600000000000002)
from sklearn.model_selection import RepeatedStratifiedKFold
# Stratified variant of the repeated splitter: five splits, five repetitions,
# fixed seed. This splitter is reused later as the cv argument of the grid search.
rskfold = RepeatedStratifiedKFold(
    n_splits=5,
    n_repeats=5,
    random_state=42,
)
scores = cross_val_score(
    estimator=logreg,
    X=iris.data,
    y=iris.target,
    cv=rskfold,
)
# Show every individual score together with the overall mean.
scores, scores.mean()
(array([0.96666667, 1. , 0.9 , 0.93333333, 1. , 0.96666667, 0.96666667, 0.96666667, 0.96666667, 0.93333333, 0.93333333, 1. , 1. , 0.93333333, 0.96666667, 0.96666667, 1. , 0.96666667, 0.9 , 0.96666667, 0.96666667, 0.9 , 0.96666667, 1. , 0.96666667]), 0.9613333333333334)
# Visualise the spread of the scores currently held in `scores` as a box plot.
plt.boxplot(x=scores)
plt.show()
from sklearn.model_selection import GridSearchCV, train_test_split
# Split off a test set (default proportions, fixed seed) before tuning.
X_train, X_test, y_train, y_test = train_test_split(
    iris.data,
    iris.target,
    random_state=42,
)
# Candidate values for the C hyperparameter of the classifier.
param_grid = dict(C=[0.001, 0.01, 0.1, 1, 10, 100, 1000])
# Search the grid using the repeated stratified splitter as the CV scheme;
# train scores are kept so they appear in cv_results_ next to the test scores.
grid_search = GridSearchCV(
    estimator=logreg,
    param_grid=param_grid,
    cv=rskfold,
    return_train_score=True,
)
grid_search.fit(X_train, y_train)
GridSearchCV(cv=RepeatedStratifiedKFold(n_repeats=5, n_splits=5, random_state=42), estimator=LogisticRegression(max_iter=1000, solver='liblinear'), param_grid={'C': [0.001, 0.01, 0.1, 1, 10, 100, 1000]}, return_train_score=True)
# Score on the held-out test split, the selected parameters, and the best
# cross-validation score found by the search.
(
    grid_search.score(X_test, y_test),
    grid_search.best_params_,
    grid_search.best_score_,
)
(1.0, {'C': 10}, 0.9640316205533597)
# Print only the per-split entries of cv_results_ (keys containing 'split'),
# i.e. the train and test scores of each individual split; each printed array
# holds one value per candidate in the parameter grid.
for k, split_scores in grid_search.cv_results_.items():
    if 'split' in k:
        print(k, split_scores)
split0_test_score [0.34782609 0.65217391 0.86956522 0.95652174 1. 1. 0.95652174] split1_test_score [0.34782609 0.65217391 0.82608696 0.95652174 0.95652174 0.95652174 0.95652174] split2_test_score [0.36363636 0.68181818 0.86363636 0.90909091 0.90909091 0.86363636 0.86363636] split3_test_score [0.31818182 0.63636364 0.72727273 0.90909091 0.95454545 0.95454545 0.95454545] split4_test_score [0.31818182 0.63636364 0.81818182 1. 1. 1. 1. ] split5_test_score [0.34782609 0.65217391 0.86956522 0.95652174 0.95652174 0.95652174 0.95652174] split6_test_score [0.34782609 0.65217391 0.73913043 0.95652174 1. 1. 1. ] split7_test_score [0.36363636 0.68181818 0.90909091 0.95454545 0.95454545 0.95454545 0.95454545] split8_test_score [0.31818182 0.63636364 0.81818182 0.95454545 1. 1. 1. ] split9_test_score [0.31818182 0.63636364 0.68181818 0.86363636 0.90909091 0.90909091 0.95454545] split10_test_score [0.34782609 0.65217391 0.69565217 0.95652174 0.95652174 0.86956522 0.86956522] split11_test_score [0.34782609 0.65217391 0.82608696 0.95652174 1. 1. 1. ] split12_test_score [0.36363636 0.68181818 0.81818182 0.95454545 0.95454545 0.90909091 0.90909091] split13_test_score [0.31818182 0.63636364 0.81818182 0.95454545 0.90909091 0.90909091 0.90909091] split14_test_score [0.31818182 0.63636364 0.86363636 0.90909091 1. 1. 1. ] split15_test_score [0.34782609 0.65217391 0.73913043 0.95652174 0.95652174 0.91304348 0.86956522] split16_test_score [0.34782609 0.65217391 0.86956522 0.95652174 0.95652174 0.95652174 0.95652174] split17_test_score [0.36363636 0.68181818 0.86363636 1. 1. 1. 1. ] split18_test_score [0.31818182 0.63636364 0.86363636 1. 1. 1. 1. ] split19_test_score [0.31818182 0.63636364 0.68181818 0.90909091 0.90909091 0.95454545 0.95454545] split20_test_score [0.34782609 0.65217391 0.7826087 1. 1. 1. 1. ] split21_test_score [0.34782609 0.65217391 0.7826087 0.95652174 1. 
0.95652174 0.91304348] split22_test_score [0.36363636 0.68181818 0.86363636 0.86363636 0.95454545 0.95454545 0.95454545] split23_test_score [0.31818182 0.63636364 0.77272727 1. 0.95454545 0.95454545 0.95454545] split24_test_score [0.31818182 0.63636364 0.86363636 0.90909091 0.90909091 0.90909091 0.90909091] split0_train_score [0.33707865 0.65168539 0.80898876 0.94382022 0.95505618 0.96629213 0.96629213] split1_train_score [0.33707865 0.65168539 0.79775281 0.96629213 0.98876404 0.98876404 0.98876404] split2_train_score [0.33333333 0.64444444 0.84444444 0.95555556 0.96666667 0.97777778 0.97777778] split3_train_score [0.34444444 0.65555556 0.82222222 0.95555556 0.96666667 0.96666667 0.96666667] split4_train_score [0.34444444 0.65555556 0.8 0.94444444 0.96666667 0.96666667 0.96666667] split5_train_score [0.33707865 0.65168539 0.80898876 0.96629213 0.97752809 0.97752809 0.97752809] split6_train_score [0.33707865 0.65168539 0.82022472 0.94382022 0.95505618 0.96629213 0.95505618] split7_train_score [0.33333333 0.64444444 0.86666667 0.94444444 0.96666667 0.97777778 0.96666667] split8_train_score [0.34444444 0.65555556 0.8 0.95555556 0.95555556 0.96666667 0.96666667] split9_train_score [0.34444444 0.65555556 0.83333333 0.96666667 0.97777778 0.98888889 0.98888889] split10_train_score [0.33707865 0.65168539 0.83146067 0.94382022 0.94382022 0.97752809 0.97752809] split11_train_score [0.33707865 0.65168539 0.79775281 0.96629213 0.96629213 0.96629213 0.96629213] split12_train_score [0.33333333 0.64444444 0.86666667 0.94444444 0.96666667 0.96666667 0.97777778] split13_train_score [0.34444444 0.65555556 0.8 0.96666667 0.98888889 0.98888889 1. 
] split14_train_score [0.34444444 0.65555556 0.78888889 0.96666667 0.95555556 0.96666667 0.96666667] split15_train_score [0.33707865 0.65168539 0.85393258 0.94382022 0.94382022 0.97752809 0.97752809] split16_train_score [0.33707865 0.65168539 0.82022472 0.95505618 0.96629213 0.97752809 0.96629213] split17_train_score [0.33333333 0.64444444 0.83333333 0.95555556 0.96666667 0.96666667 0.96666667] split18_train_score [0.34444444 0.65555556 0.78888889 0.94444444 0.96666667 0.96666667 0.96666667] split19_train_score [0.34444444 0.65555556 0.83333333 0.96666667 0.98888889 0.98888889 0.97777778] split20_train_score [0.33707865 0.65168539 0.82022472 0.94382022 0.95505618 0.96629213 0.96629213] split21_train_score [0.33707865 0.65168539 0.80898876 0.96629213 0.97752809 0.96629213 0.96629213] split22_train_score [0.33333333 0.64444444 0.84444444 0.95555556 0.95555556 0.96666667 0.95555556] split23_train_score [0.34444444 0.65555556 0.81111111 0.95555556 0.96666667 0.97777778 0.96666667] split24_train_score [0.34444444 0.65555556 0.78888889 0.95555556 0.97777778 0.98888889 1. ]