import numpy as np
from math import ceil, floor
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split as sktrain_test_split
class ShuffleSplit():
    """Yield repeated random train/test index splits of a data set.

    Each split draws a fresh permutation of the sample indices from a
    single ``np.random.RandomState`` seeded with ``random_state``. The
    leading ``n_test`` entries of the permutation become the test
    indices and the following ``n_train`` entries the train indices.
    """

    def __init__(self, n_splits=10,
                 train_size=0.9, test_size=0.1, random_state=0):
        """Store the splitter configuration.

        Parameters
        ----------
        n_splits : number of (train, test) pairs produced by ``split``.
        train_size : fraction of the samples used for training; the
            resulting count is rounded down.
        test_size : fraction of the samples used for testing; the
            resulting count is rounded up.
        random_state : seed handed to ``np.random.RandomState``.
        """
        self.n_splits = n_splits
        self.train_size = train_size
        self.test_size = test_size
        self.random_state = random_state

    def split(self, X, y=None):
        """Generate ``(train_indices, test_indices)`` pairs.

        Parameters
        ----------
        X : object exposing ``shape[0]`` as the number of samples.
        y : accepted but not used by this splitter.

        Yields
        ------
        tuple of two index arrays, train first and test second.
        """
        n_samples = X.shape[0]
        train_count = floor(self.train_size * n_samples)
        test_count = ceil(self.test_size * n_samples)
        # One generator is shared by all splits, so successive splits differ
        # while the whole sequence stays reproducible for a given seed.
        generator = np.random.RandomState(self.random_state)
        for _split in range(self.n_splits):
            shuffled = generator.permutation(n_samples)
            test_idx = shuffled[:test_count]
            train_idx = shuffled[test_count:test_count + train_count]
            yield train_idx, test_idx
class StratifiedShuffleSplit():
    """Yield random train/test index splits that follow class proportions.

    The per-class sample counts for the train and test parts are chosen
    by ``_approximate_mode``; inside each class the members are then
    shuffled and dealt out to train and test.
    """

    def __init__(self, n_splits=10,
                 train_size=0.9, test_size=0.1, random_state=0):
        """Store the splitter configuration.

        Parameters
        ----------
        n_splits : number of (train, test) pairs produced by ``split``.
        train_size : fraction of the samples used for training; the
            resulting count is rounded down.
        test_size : fraction of the samples used for testing; the
            resulting count is rounded up.
        random_state : seed handed to ``np.random.RandomState``.
        """
        self.n_splits = n_splits
        self.train_size = train_size
        self.test_size = test_size
        self.random_state = random_state

    def _approximate_mode(self, class_counts, n_draws, rng):
        """Distribute ``n_draws`` over classes proportionally to their size.

        Parameters
        ----------
        class_counts : array with the number of available samples per class.
        n_draws : total number of samples to allocate.
        rng : random generator used to break ties between classes whose
            fractional parts are equal.

        Returns
        -------
        Integer array holding the number of draws assigned to each class.
        """
        exact = n_draws * class_counts / class_counts.sum()
        allocation = np.floor(exact)
        missing = int(n_draws - allocation.sum())
        if missing <= 0:
            return allocation.astype(int)

        fractional = exact - allocation
        # np.unique returns its values in ascending order; reversing visits
        # the classes with the largest leftover fraction first.
        for level in np.unique(fractional)[::-1]:
            (tied,) = np.where(fractional == level)
            take = min(len(tied), missing)
            # Pick randomly among tied classes so none is favoured.
            chosen = rng.choice(tied, size=take, replace=False)
            allocation[chosen] += 1
            missing -= take
            if missing == 0:
                break
        return allocation.astype(int)

    def split(self, X, y):
        """Generate stratified ``(train_indices, test_indices)`` pairs.

        Parameters
        ----------
        X : object exposing ``shape[0]`` as the number of samples.
        y : labels that define the classes to stratify on.

        Yields
        ------
        tuple of two shuffled index arrays, train first and test second.
        """
        n_samples = X.shape[0]
        n_train = np.floor(self.train_size * n_samples)
        n_test = np.ceil(self.test_size * n_samples)
        classes, y_indices = np.unique(y, return_inverse=True)
        class_counts = np.bincount(y_indices)
        # A stable sort (mergesort) keeps the original sample order inside
        # every class; cutting at the cumulative counts then gives one index
        # array per class.
        order = np.argsort(y_indices, kind='mergesort')
        boundaries = np.cumsum(class_counts)[:-1]
        class_indices = np.split(order, boundaries)
        rng = np.random.RandomState(self.random_state)
        for _split in range(self.n_splits):
            train, test = [], []
            n_i = self._approximate_mode(class_counts, n_train, rng)
            # Test draws come from whatever the train allocation left over.
            t_i = self._approximate_mode(class_counts - n_i, n_test, rng)
            for members, count, n_tr, n_te in zip(class_indices, class_counts,
                                                  n_i, t_i):
                shuffled = members[rng.permutation(count)]
                train.extend(shuffled[:n_tr])
                test.extend(shuffled[n_tr:n_tr + n_te])
            train = rng.permutation(train)
            test = rng.permutation(test)
            yield train, test
def train_test_split(X, y, train_size=0.75, test_size=0.25,
                     random_state=0, stratify=None):
    """Split ``X`` and ``y`` into a single train part and a single test part.

    Parameters
    ----------
    X : samples; indexed with integer index arrays and passed to the
        splitter, which reads ``X.shape[0]``.
    y : targets; indexed with the same index arrays as ``X``.
    train_size : fraction of the samples placed in the train part.
    test_size : fraction of the samples placed in the test part.
    random_state : seed forwarded to the splitter so that the caller
        controls the shuffle.
    stratify : when not ``None``, labels handed to
        ``StratifiedShuffleSplit`` to stratify on; otherwise a plain
        ``ShuffleSplit`` is used.

    Returns
    -------
    tuple ``(X_train, X_test, y_train, y_test)``.
    """
    # Forward the caller's seed; previously both branches hard-coded
    # random_state=0, so the parameter had no effect.
    if stratify is not None:
        cv = StratifiedShuffleSplit(train_size=train_size,
                                    test_size=test_size,
                                    random_state=random_state)
        train, test = next(cv.split(X, stratify))
    else:
        cv = ShuffleSplit(train_size=train_size,
                          test_size=test_size,
                          random_state=random_state)
        train, test = next(cv.split(X))
    return X[train], X[test], y[train], y[test]
# Self-check: for every configuration below, the local train_test_split must
# return exactly the same partitions as scikit-learn's implementation on the
# iris data (default and explicit sizes, each without and with stratification).
for size_kwargs, stratified in (
    ({}, False),
    ({"train_size": 0.5, "test_size": 0.2}, False),
    ({}, True),
    ({"train_size": 0.5, "test_size": 0.2}, True),
):
    X, y = load_iris(return_X_y=True)
    split_kwargs = dict(size_kwargs, random_state=0,
                        stratify=y if stratified else None)
    X_train_1, X_test_1, y_train_1, y_test_1 = train_test_split(
        X, y, **split_kwargs)
    X_train_2, X_test_2, y_train_2, y_test_2 = sktrain_test_split(
        X, y, **split_kwargs)
    assert np.allclose(X_train_1, X_train_2)
    assert np.allclose(X_test_1, X_test_2)
    assert np.array_equal(y_train_1, y_train_2)
    assert np.array_equal(y_test_1, y_test_2)