class StratifiedShuffleSplit:
    """Randomized train/test index splitter that stratifies on class labels.

    Each split draws, per class, a number of samples proportional to that
    class's share of ``y`` so that the train and test sets approximately
    keep the class proportions of the full label array.

    Parameters
    ----------
    n_splits : int, default=10
        Number of (train, test) index pairs generated by :meth:`split`.
    train_size : float, default=0.9
        Fraction of the samples placed in the train set (rounded down).
    test_size : float, default=0.1
        Fraction of the samples placed in the test set (rounded up, then
        capped at the number of samples left after the train allocation).
    random_state : int, default=0
        Seed passed to ``np.random.RandomState``; a fresh generator is
        created on every call to :meth:`split`.
    """

    def __init__(self, n_splits=10,
                 train_size=0.9, test_size=0.1, random_state=0):
        self.n_splits = n_splits
        self.train_size = train_size
        self.test_size = test_size
        self.random_state = random_state

    def _approximate_mode(self, class_counts, n_draws, rng):
        """Allocate ``n_draws`` samples across classes proportionally.

        Parameters
        ----------
        class_counts : ndarray of int
            Number of available samples in each class.
        n_draws : int or float
            Total number of samples to allocate.
        rng : numpy.random.RandomState
            Generator used to break ties between classes whose fractional
            remainders are equal.

        Returns
        -------
        ndarray of int
            Per-class sample counts. All zeros when there is nothing to
            draw from or nothing to draw.
        """
        class_counts = np.asarray(class_counts)
        total = class_counts.sum()
        # Guard the division below: an empty pool (e.g. train_size == 1.0
        # leaves nothing for the test set) would otherwise yield NaN and
        # make the int() conversion raise.
        if total == 0 or n_draws <= 0:
            return np.zeros_like(class_counts, dtype=int)

        continuous = n_draws * class_counts / total
        floored = np.floor(continuous)
        need_to_add = int(n_draws - floored.sum())
        if need_to_add > 0:
            remainder = continuous - floored
            # np.unique returns ascending values; walk them from the largest
            # remainder down, handing out one extra sample per class.
            for value in np.unique(remainder)[::-1]:
                inds = np.where(remainder == value)[0]
                add_now = min(len(inds), need_to_add)
                # Ties are broken at random so no class is favoured.
                inds = rng.choice(inds, size=add_now, replace=False)
                floored[inds] += 1
                need_to_add -= add_now
                if need_to_add == 0:
                    break
        return floored.astype(int)

    def split(self, X, y):
        """Generate stratified, shuffled train/test index arrays.

        Parameters
        ----------
        X : array-like
            Data to split. Only its number of samples is used
            (``X.shape[0]`` when available, otherwise ``len(X)``).
        y : array-like
            Class label of each sample; stratification is based on it.

        Yields
        ------
        train : ndarray of int
            Shuffled indices of the train samples.
        test : ndarray of int
            Shuffled indices of the test samples, disjoint from ``train``.
        """
        n_samples = X.shape[0] if hasattr(X, "shape") else len(X)
        n_train = int(np.floor(self.train_size * n_samples))
        n_test = int(np.ceil(self.test_size * n_samples))
        # Never request more test samples than remain after the train set
        # (non-complementary fractions or floating-point rounding).
        n_test = max(0, min(n_test, n_samples - n_train))

        classes, y_indices = np.unique(y, return_inverse=True)
        # NumPy 2.x shapes the inverse like the input; bincount/argsort
        # below need a flat array (no-op for 1-D labels).
        y_indices = np.ravel(y_indices)
        class_counts = np.bincount(y_indices)
        # Stable sort (quicksort is not stable) keeps the original sample
        # order inside each class before splitting into per-class chunks.
        class_indices = np.split(np.argsort(y_indices, kind='mergesort'),
                                 np.cumsum(class_counts)[:-1])

        rng = np.random.RandomState(self.random_state)
        for _ in range(self.n_splits):
            # Seed with an empty integer array so that an empty selection
            # still produces a valid integer index array.
            train_parts = [np.empty(0, dtype=np.intp)]
            test_parts = [np.empty(0, dtype=np.intp)]
            n_i = self._approximate_mode(class_counts, n_train, rng)
            t_i = self._approximate_mode(class_counts - n_i, n_test, rng)
            for members, count, k_train, k_test in zip(
                    class_indices, class_counts, n_i, t_i):
                shuffled = members[rng.permutation(count)]
                train_parts.append(shuffled[:k_train])
                test_parts.append(shuffled[k_train:k_train + k_test])
            train = rng.permutation(np.concatenate(train_parts))
            test = rng.permutation(np.concatenate(test_parts))
            yield train, test