class StratifiedKFold:
    """Stratified K-fold cross-validator.

    Splits samples into ``n_splits`` folds so that every class in ``y`` is
    spread as evenly as possible across the folds. Within each class, the
    first ``count % n_splits`` folds receive one extra sample.

    Parameters
    ----------
    n_splits : int
        Number of folds.
    shuffle : bool
        If True, shuffle the samples of each class before assigning them to
        folds; otherwise samples keep their original order within a class.
    random_state : int or None
        Seed passed to ``np.random.RandomState``; the generator is only
        consumed when ``shuffle`` is True.
    """

    def __init__(self, n_splits=5, shuffle=False, random_state=0):
        self.n_splits = n_splits
        self.shuffle = shuffle
        self.random_state = random_state

    @staticmethod
    def _num_samples(X):
        """Return the number of rows in ``X`` (``X.shape[0]``, else ``len(X)``)."""
        shape = getattr(X, "shape", None)
        return shape[0] if shape else len(X)

    def _kfold(self, count, rng):
        """Yield, fold by fold, the within-class positions assigned to each fold.

        Parameters
        ----------
        count : int
            Number of samples in the class.
        rng : np.random.RandomState
            Generator used to shuffle the positions when ``self.shuffle``.

        Yields
        ------
        np.ndarray
            Exactly ``self.n_splits`` integer arrays (some may be empty when
            ``count < n_splits``) that together partition ``range(count)``.
        """
        # NOTE: this is a generator, so the shuffle happens on the first
        # ``next()``; the caller advances classes in order, which fixes the
        # order in which ``rng`` is consumed.
        indices = np.arange(count)
        if self.shuffle:
            rng.shuffle(indices)
        fold_sizes = np.full(self.n_splits, count // self.n_splits)
        fold_sizes[:count % self.n_splits] += 1
        current = 0
        for fold_size in fold_sizes:
            # A contiguous slice selects the same elements as a boolean mask
            # over [current, current + fold_size) without allocating it.
            yield indices[current:current + fold_size]
            current += fold_size

    def _make_test_folds(self, X, y):
        """Return an int array mapping each sample to the fold it is tested in.

        Raises
        ------
        ValueError
            If ``X`` and ``y`` do not have the same number of samples.
        """
        n_samples = self._num_samples(X)
        # Convert so that list/tuple labels work: ``list == value`` is a
        # single bool, not an elementwise mask.
        y = np.asarray(y)
        if y.shape[0] != n_samples:
            raise ValueError(
                "X and y have inconsistent numbers of samples: "
                f"{n_samples} != {y.shape[0]}"
            )
        rng = np.random.RandomState(self.random_state)
        _, y_inversed = np.unique(y, return_inverse=True)
        y_counts = np.bincount(y_inversed)
        # Positions (in the full sample order) of each class, computed once.
        class_positions = [
            np.flatnonzero(y_inversed == cls_index)
            for cls_index in range(len(y_counts))
        ]
        test_folds = np.zeros(n_samples, dtype=int)
        per_cls_cvs = [self._kfold(count, rng) for count in y_counts]
        for fold_index, per_cls_splits in enumerate(zip(*per_cls_cvs)):
            for positions, test_split in zip(class_positions, per_cls_splits):
                test_folds[positions[test_split]] = fold_index
        return test_folds

    def _iter_test_masks(self, X, y):
        """Yield one boolean test mask over all samples per fold."""
        test_folds = self._make_test_folds(X, y)
        for i in range(self.n_splits):
            yield test_folds == i

    def split(self, X, y):
        """Generate ``(train_index, test_index)`` pairs, one per fold.

        Parameters
        ----------
        X : array-like
            Samples; only their count is used.
        y : array-like
            Class labels, one per sample.

        Yields
        ------
        tuple of np.ndarray
            Sorted integer indices of the training and test samples.
        """
        indices = np.arange(self._num_samples(X))
        for test_mask in self._iter_test_masks(X, y):
            yield indices[~test_mask], indices[test_mask]