import numpy as np
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import SGDClassifier as skSGDClassifier
def _loss(x, y, coef, intercept):
p = np.dot(x, coef) + intercept
z = p * y
if z <= 1:
return 1 - z
else:
return 0
def _grad(x, y, coef, intercept):
p = np.dot(x, coef) + intercept
z = p * y
if z <= 1:
dloss = -y
else:
dloss = 0
# clip gradient (consistent with scikit-learn)
dloss = np.clip(dloss, -1e12, 1e12)
coef_grad = dloss * x
intercept_grad = dloss
return coef_grad, intercept_grad
class SGDClassifier():
    """Linear classifier trained with per-sample stochastic gradient descent.

    The per-sample loss and its gradient come from the module-level
    ``_loss`` and ``_grad`` functions that are in scope when ``fit`` runs.
    The step size follows an inverse-scaling schedule,
    ``eta = eta0 / t ** power_t``, where ``t`` counts the updates made for
    the current class (starting at 1).

    Two-class problems train a single model (positive class is the larger
    label); otherwise one model per class is trained against the rest.

    Parameters
    ----------
    penalty : only ``"l2"`` has an effect (weights are shrunk by
        ``1 - eta * alpha`` before every update); any other value disables
        the shrinkage.
    alpha : strength of the L2 shrinkage.
    max_iter : maximum number of passes over the data per class.
    tol : an epoch counts as "no improvement" when its summed loss is
        greater than ``best_loss - tol * n_samples``.
    shuffle : whether to re-permute the sample order before every epoch.
    random_state : seed for the ``numpy.random.RandomState`` used to shuffle.
    eta0 : initial learning rate. With the default of 0 every step size is
        0, so the weights never move; pass a positive value to train.
    power_t : exponent of the inverse-scaling schedule.
    n_iter_no_change : training of a class stops after this many
        consecutive epochs without improvement.

    Attributes
    ----------
    classes_ : sorted unique labels seen in ``fit``.
    coef_ : array of shape (n_models, n_features).
    intercept_ : array of shape (n_models,).
    n_iter_ : largest number of epochs run for any single class.
    """

    def __init__(self, penalty="l2", alpha=0.0001, max_iter=1000, tol=1e-3,
                 shuffle=True, random_state=0,
                 # use learning_rate = 'invscaling' for simplicity
                 eta0=0, power_t=0.5, n_iter_no_change=5):
        self.penalty = penalty
        self.alpha = alpha
        self.max_iter = max_iter
        self.tol = tol
        self.shuffle = shuffle
        self.random_state = random_state
        self.eta0 = eta0
        self.power_t = power_t
        self.n_iter_no_change = n_iter_no_change

    def _encode(self, y):
        """Turn labels into +1/-1 target columns.

        Returns ``(classes, y_train)`` where ``classes`` are the sorted
        unique labels and ``y_train[i, k]`` is +1 when sample ``i`` belongs
        to ``classes[k]`` and -1 otherwise. With exactly two classes only
        the column of the second class is kept, giving shape (n_samples, 1).
        """
        y = np.asarray(y)
        classes = np.unique(y)
        y_train = np.where(y[:, None] == classes[None, :], 1, -1)
        if len(classes) == 2:
            y_train = y_train[:, 1].reshape(-1, 1)
        return classes, y_train

    def fit(self, X, y):
        """Fit one linear model per target column and return ``self``.

        Parameters
        ----------
        X : 2-D array of shape (n_samples, n_features).
        y : 1-D array of labels, one per row of ``X``.
        """
        X = np.asarray(X)
        self.classes_, y_train = self._encode(y)
        n_samples, n_features = X.shape
        n_models = y_train.shape[1]
        coef = np.zeros((n_models, n_features))
        intercept = np.zeros(n_models)
        n_iter = 0
        rng = np.random.RandomState(self.random_state)
        for class_ind in range(n_models):
            cur_y = y_train[:, class_ind]
            cur_coef = np.zeros(n_features)
            cur_intercept = 0
            best_loss = np.inf
            no_improvement_count = 0
            t = 1
            epochs_run = 0
            # Shuffle an index array rather than X itself.  Rebinding X to
            # its permuted rows used to leak into the next class, whose
            # targets are still in the original order, so every class after
            # the first was trained on mismatched (x, y) pairs.
            order = np.arange(n_samples)
            for epoch in range(self.max_iter):
                epochs_run = epoch + 1
                # different from how data is shuffled in scikit-learn
                if self.shuffle:
                    # Permutations are composed across epochs, which keeps
                    # the visiting order of the first class unchanged from
                    # the previous implementation.
                    order = order[rng.permutation(n_samples)]
                sumloss = 0
                for i in order:
                    x_i, y_i = X[i], cur_y[i]
                    # Loss is measured before the update for this sample.
                    sumloss += _loss(x_i, y_i, cur_coef, cur_intercept)
                    eta = self.eta0 / np.power(t, self.power_t)
                    coef_grad, intercept_grad = _grad(x_i, y_i, cur_coef, cur_intercept)
                    if self.penalty == "l2":
                        cur_coef *= 1 - eta * self.alpha
                    cur_coef -= eta * coef_grad
                    cur_intercept -= eta * intercept_grad
                    t += 1
                if sumloss > best_loss - self.tol * n_samples:
                    no_improvement_count += 1
                else:
                    no_improvement_count = 0
                if no_improvement_count == self.n_iter_no_change:
                    break
                if sumloss < best_loss:
                    best_loss = sumloss
            coef[class_ind] = cur_coef
            intercept[class_ind] = cur_intercept
            n_iter = max(n_iter, epochs_run)
        self.coef_ = coef
        self.intercept_ = intercept
        self.n_iter_ = n_iter
        return self

    def decision_function(self, X):
        """Return ``X . coef_.T + intercept_``.

        The result is 1-D (one score per sample) when a single model was
        fitted, otherwise 2-D with one column per class.
        """
        scores = np.dot(X, self.coef_.T) + self.intercept_
        if scores.shape[1] == 1:
            return scores.ravel()
        return scores

    def predict(self, X):
        """Return the predicted label for every row of ``X``."""
        scores = self.decision_function(X)
        if scores.ndim == 1:
            # Single model: a positive score selects the second class.
            indices = (scores > 0).astype(int)
        else:
            indices = np.argmax(scores, axis=1)
        return self.classes_[indices]
# Hinge loss: compare this implementation with scikit-learn on
#   1. the two-class subset of iris (labels 0 and 1),
#   2. all three classes with the default alpha,
#   3. all three classes with alpha in {0.1, 1, 10}.
# Both estimators keep their default penalty ("l2"); shuffling is switched
# off because the two implementations shuffle differently.
X_iris, y_iris = load_iris(return_X_y=True)
two_class = y_iris != 2
for X_raw, y, extra in [
    (X_iris[two_class], y_iris[two_class], {}),
    (X_iris, y_iris, {}),
    (X_iris, y_iris, {"alpha": 0.1}),
    (X_iris, y_iris, {"alpha": 1}),
    (X_iris, y_iris, {"alpha": 10}),
]:
    X = StandardScaler().fit_transform(X_raw)
    clf1 = SGDClassifier(eta0=0.1, shuffle=False, **extra).fit(X, y)
    clf2 = skSGDClassifier(learning_rate='invscaling', eta0=0.1, shuffle=False, **extra).fit(X, y)
    # Fitted parameters and raw scores must agree numerically ...
    for ours, theirs in (
        (clf1.coef_, clf2.coef_),
        (clf1.intercept_, clf2.intercept_),
        (clf1.decision_function(X), clf2.decision_function(X)),
    ):
        assert np.allclose(ours, theirs)
    # ... and the predicted labels must be identical.
    assert np.array_equal(clf1.predict(X), clf2.predict(X))
def _loss(x, y, coef, intercept):
p = np.dot(x, coef) + intercept
z = 1 - p * y
if z > 0:
return z * z
else:
return 0
def _grad(x, y, coef, intercept):
p = np.dot(x, coef) + intercept
z = 1 - p * y
if z > 0:
dloss = -2 * y * z
else:
dloss = 0
# clip gradient (consistent with scikit-learn)
dloss = np.clip(dloss, -1e12, 1e12)
coef_grad = dloss * x
intercept_grad = dloss
return coef_grad, intercept_grad
# Squared hinge loss, shuffle=False: compare with scikit-learn on the full
# iris data, first with the default alpha and then with alpha in {0.1, 1, 10}.
# Both estimators keep their default penalty ("l2") throughout.
for extra in ({}, {"alpha": 0.1}, {"alpha": 1}, {"alpha": 10}):
    X, y = load_iris(return_X_y=True)
    X = StandardScaler().fit_transform(X)
    clf1 = SGDClassifier(eta0=0.1, shuffle=False, **extra).fit(X, y)
    clf2 = skSGDClassifier(loss="squared_hinge", learning_rate='invscaling', eta0=0.1, shuffle=False, **extra).fit(X, y)
    # Fitted parameters and raw scores must agree numerically ...
    for ours, theirs in (
        (clf1.coef_, clf2.coef_),
        (clf1.intercept_, clf2.intercept_),
        (clf1.decision_function(X), clf2.decision_function(X)),
    ):
        assert np.allclose(ours, theirs)
    # ... and the predicted labels must be identical.
    assert np.array_equal(clf1.predict(X), clf2.predict(X))
def _loss(x, y, coef, intercept):
p = np.dot(x, coef) + intercept
z = p * y
if z > 1:
return 0
elif z > -1:
return (1 - z) * (1 - z)
else:
return -4 * z
def _grad(x, y, coef, intercept):
p = np.dot(x, coef) + intercept
z = p * y
if z > 1:
dloss = 0
elif z > -1:
dloss = -2 * (1 - z) * y
else:
dloss = -4 * y
# clip gradient (consistent with scikit-learn)
dloss = np.clip(dloss, -1e12, 1e12)
coef_grad = dloss * x
intercept_grad = dloss
return coef_grad, intercept_grad
# Modified Huber loss, shuffle=False: compare with scikit-learn on the full
# iris data, first with the default alpha and then with alpha in {0.1, 1, 10}.
# Both estimators keep their default penalty ("l2") throughout.
for extra in ({}, {"alpha": 0.1}, {"alpha": 1}, {"alpha": 10}):
    X, y = load_iris(return_X_y=True)
    X = StandardScaler().fit_transform(X)
    clf1 = SGDClassifier(eta0=0.1, shuffle=False, **extra).fit(X, y)
    clf2 = skSGDClassifier(loss="modified_huber", learning_rate='invscaling', eta0=0.1, shuffle=False, **extra).fit(X, y)
    # Fitted parameters and raw scores must agree numerically ...
    for ours, theirs in (
        (clf1.coef_, clf2.coef_),
        (clf1.intercept_, clf2.intercept_),
        (clf1.decision_function(X), clf2.decision_function(X)),
    ):
        assert np.allclose(ours, theirs)
    # ... and the predicted labels must be identical.
    assert np.array_equal(clf1.predict(X), clf2.predict(X))
def _loss(x, y, coef, intercept):
p = np.dot(x, coef) + intercept
z = p * y
# follow scikit-learn
if z > 18:
return np.exp(-z)
elif z < -18:
return -z
else:
return np.log(1 + np.exp(-z))
def _grad(x, y, coef, intercept):
p = np.dot(x, coef) + intercept
z = p * y
if z > 18:
dloss = -np.exp(-z) * y
elif z < -18:
dloss = -y
else:
dloss = -y / (1 + np.exp(z))
# clip gradient (consistent with scikit-learn)
dloss = np.clip(dloss, -1e12, 1e12)
coef_grad = dloss * x
intercept_grad = dloss
return coef_grad, intercept_grad
# Logistic loss, shuffle=False: compare with scikit-learn on the full iris
# data, first with the default alpha and then with alpha in {0.1, 1, 10}.
# Both estimators keep their default penalty ("l2") throughout.
for extra in ({}, {"alpha": 0.1}, {"alpha": 1}, {"alpha": 10}):
    X, y = load_iris(return_X_y=True)
    X = StandardScaler().fit_transform(X)
    clf1 = SGDClassifier(eta0=0.1, shuffle=False, **extra).fit(X, y)
    sk_params = dict(learning_rate='invscaling', eta0=0.1, shuffle=False, **extra)
    # scikit-learn renamed loss="log" to "log_loss" in 1.1 and removed the
    # old spelling in 1.3.  Try the current name first; releases older than
    # 1.1 reject it with a ValueError, in which case use the legacy name.
    try:
        clf2 = skSGDClassifier(loss="log_loss", **sk_params).fit(X, y)
    except ValueError:
        clf2 = skSGDClassifier(loss="log", **sk_params).fit(X, y)
    # Fitted parameters and raw scores must agree numerically ...
    assert np.allclose(clf1.coef_, clf2.coef_)
    assert np.allclose(clf1.intercept_, clf2.intercept_)
    prob1 = clf1.decision_function(X)
    prob2 = clf2.decision_function(X)
    assert np.allclose(prob1, prob2)
    # ... and the predicted labels must be identical.
    pred1 = clf1.predict(X)
    pred2 = clf2.predict(X)
    assert np.array_equal(pred1, pred2)