import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_breast_cancer, load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.feature_selection import RFE as skRFE
class RFE:
    """Minimal recursive feature elimination mirroring scikit-learn's ``RFE``.

    On each round a fresh clone of ``estimator`` is fitted on the currently
    selected columns. The single least important column is then dropped
    (step of 1). This repeats until ``n_features_to_select`` columns remain.
    Feature importance is computed the way scikit-learn's RFE does it:
    ``coef_`` is preferred over ``feature_importances_``, values are squared,
    and 2-D coefficient arrays are summed over their first axis.

    Parameters
    ----------
    estimator : object
        Unfitted estimator. It must support ``sklearn.base.clone`` and
        ``fit(X, y)``, and expose ``coef_`` or ``feature_importances_``
        once fitted.
    n_features_to_select : int or None, default=None
        Number of features to keep. ``None`` keeps half of the features,
        rounded down but never fewer than one.

    Attributes (set by ``fit``)
    ---------------------------
    support_ : ndarray of bool, shape (n_features,)
        Mask of the selected features.
    ranking_ : ndarray of int, shape (n_features,)
        Selected features have rank 1. A feature eliminated earlier gets a
        larger rank.
    estimator_ : object
        Clone of ``estimator`` fitted on the selected features only.
    """

    def __init__(self, estimator, n_features_to_select=None):
        self.estimator = estimator
        self.n_features_to_select = n_features_to_select

    @staticmethod
    def _as_2d(X):
        """Return ``X`` as a 2-D indexable object, raising ValueError otherwise."""
        # Only objects without a shape (e.g. nested lists) are converted, so
        # ndarrays and sparse matrices pass through untouched, as before.
        if not hasattr(X, "shape"):
            X = np.asarray(X)
        if len(X.shape) != 2:
            raise ValueError(f"Expected a 2-D X, got shape {X.shape}.")
        return X

    @staticmethod
    def _get_importances(est):
        """Return one non-negative importance score per input feature of fitted ``est``."""
        # Same lookup order and "square" transform as scikit-learn's RFE, so
        # rankings agree with it (including for multiclass coef_).
        if hasattr(est, "coef_"):
            raw = np.asarray(est.coef_)
        elif hasattr(est, "feature_importances_"):
            raw = np.asarray(est.feature_importances_)
        else:
            raise ValueError(
                "The estimator must expose a 'coef_' or 'feature_importances_' "
                "attribute after fitting."
            )
        squared = raw ** 2
        return squared if squared.ndim == 1 else squared.sum(axis=0)

    def fit(self, X, y):
        """Run recursive elimination, then fit ``estimator_`` on the survivors.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        y : array-like
            Targets, passed unchanged to the estimator's ``fit``.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``X`` is not 2-D, if ``n_features_to_select`` is outside
            ``[1, n_features]``, or if the fitted estimator has neither
            ``coef_`` nor ``feature_importances_``.
        """
        X = self._as_2d(X)
        n_features = X.shape[1]
        if self.n_features_to_select is None:
            # Default to half of the features, but never eliminate all of them.
            n_keep = max(1, n_features // 2)
        else:
            n_keep = self.n_features_to_select
            if not 1 <= n_keep <= n_features:
                raise ValueError(
                    f"n_features_to_select must be in [1, {n_features}], "
                    f"got {n_keep!r}."
                )

        # Builtin dtypes: the np.bool / np.int aliases were removed in NumPy 1.24.
        support = np.ones(n_features, dtype=bool)
        ranking = np.ones(n_features, dtype=int)
        while support.sum() > n_keep:
            features = np.flatnonzero(support)
            est = clone(self.estimator)
            est.fit(X[:, features], y)
            importances = self._get_importances(est)
            # argsort rather than argmin, so that ties are resolved the same
            # way scikit-learn's RFE resolves them.
            worst = features[np.argsort(importances)[0]]
            support[worst] = False
            # Every feature eliminated so far moves one rank further down.
            ranking[~support] += 1

        self.support_ = support
        self.ranking_ = ranking
        self.estimator_ = clone(self.estimator)
        self.estimator_.fit(X[:, support], y)
        return self

    def transform(self, X):
        """Return ``X`` restricted to the columns selected by ``fit``."""
        return self._as_2d(X)[:, self.support_]
def _assert_matches_sklearn(X, y, estimator, label):
    """Assert that the local ``RFE`` reproduces scikit-learn's ``RFE`` on ``(X, y)``.

    Both implementations run with default settings on the same unfitted
    ``estimator``. The local ``RFE`` clones it before every fit, so sharing
    it does not leak fitted state. ``label`` names the case in failure
    messages.
    """
    ours = RFE(estimator=estimator).fit(X, y)
    ref = skRFE(estimator=estimator).fit(X, y)
    assert np.array_equal(ours.support_, ref.support_), f"{label}: support_ differs"
    assert np.array_equal(ours.ranking_, ref.ranking_), f"{label}: ranking_ differs"
    assert np.allclose(ours.transform(X), ref.transform(X)), (
        f"{label}: transform output differs"
    )


# Load each dataset once and reuse it across the estimator cases.
X_cancer, y_cancer = load_breast_cancer(return_X_y=True)
X_iris, y_iris = load_iris(return_X_y=True)

for case_label, X, y, clf in (
    ("breast_cancer/RandomForest", X_cancer, y_cancer,
     RandomForestClassifier(random_state=0)),
    ("breast_cancer/LogisticRegression", X_cancer, y_cancer,
     LogisticRegression(max_iter=15000, random_state=0)),
    ("iris/LogisticRegression", X_iris, y_iris,
     LogisticRegression(max_iter=15000, random_state=0)),
):
    _assert_matches_sklearn(X, y, clf, case_label)