import numpy as np
from sklearn.metrics import f1_score as skf1_score
def f1_score(y_true, y_pred, average):
    """Compute the F1 score of predictions against ground-truth labels.

    Parameters
    ----------
    y_true : array-like, 1-D
        Ground-truth class labels.
    y_pred : array-like, 1-D
        Predicted class labels, same length as ``y_true``.
    average : {None, "binary", "micro", "macro", "weighted"}
        * ``None``       -- return one F1 score per label that occurs in
          ``y_true`` or ``y_pred``, ordered by sorted label value.
        * ``"binary"``   -- F1 score of the positive class, i.e. label ``1``.
        * ``"micro"``    -- F1 computed from counts pooled over all labels.
        * ``"macro"``    -- unweighted mean of the per-label F1 scores.
        * ``"weighted"`` -- mean of the per-label F1 scores weighted by the
          number of true samples of each label.

    Returns
    -------
    numpy.ndarray or numpy.float64
        Array of per-label scores when ``average`` is ``None``, otherwise a
        single scalar. Undefined ratios (zero denominators) count as 0.

    Raises
    ------
    ValueError
        If ``average`` is not one of the supported options, or if ``y_true``
        and ``y_pred`` do not have the same shape.
    """
    if average not in (None, "binary", "micro", "macro", "weighted"):
        raise ValueError(
            "average must be one of None, 'binary', 'micro', 'macro', "
            "'weighted'; got %r" % (average,)
        )

    # Accept plain sequences as well as arrays.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            "y_true and y_pred must have the same shape; got %s and %s"
            % (y_true.shape, y_pred.shape)
        )

    if average == "binary":
        # Count the positive class (label 1) directly so the result is well
        # defined (0.0) even when label 1 never occurs in the data.
        true_pos = y_true == 1
        pred_pos = y_pred == 1
        tp = np.array([np.count_nonzero(true_pos & pred_pos)])
        true_sum = np.array([np.count_nonzero(true_pos)])
        pred_sum = np.array([np.count_nonzero(pred_pos)])
    else:
        # Map the labels that actually occur onto 0..n_labels-1 so that
        # non-contiguous (or non-integer) labels neither create phantom
        # classes nor produce count arrays of mismatched length.
        labels, codes = np.unique(
            np.concatenate([y_true, y_pred]), return_inverse=True
        )
        n_labels = len(labels)
        true_codes = codes[:y_true.size]
        pred_codes = codes[y_true.size:]

        true_sum = np.bincount(true_codes, minlength=n_labels)
        pred_sum = np.bincount(pred_codes, minlength=n_labels)
        tp = np.bincount(
            true_codes[true_codes == pred_codes], minlength=n_labels
        )

        if average == "micro":
            # Pool the counts over every label before forming the ratios.
            tp = np.array([tp.sum()])
            true_sum = np.array([true_sum.sum()])
            pred_sum = np.array([pred_sum.sum()])

    # Ratios with a zero denominator are left at 0 instead of dividing.
    precision = np.divide(
        tp, pred_sum, out=np.zeros(len(pred_sum)), where=pred_sum != 0
    )
    recall = np.divide(
        tp, true_sum, out=np.zeros(len(true_sum)), where=true_sum != 0
    )
    denom = precision + recall
    fscore = np.divide(
        2 * precision * recall, denom, out=np.zeros(len(denom)),
        where=denom != 0
    )

    if average is None:
        return fscore
    if average == "weighted":
        # np.average raises when the weights sum to zero (empty input).
        if true_sum.sum() == 0:
            return np.float64(0.0)
        return np.average(fscore, weights=true_sum)
    if fscore.size == 0:
        # Mean of an empty array would be NaN; report 0 like other
        # undefined cases.
        return np.float64(0.0)
    return np.mean(fscore)
# Binary case: compare against scikit-learn on ten seeded random samples.
for i in range(10):
    rng = np.random.RandomState(i)
    y_true = rng.randint(2, size=10)
    y_pred = rng.randint(2, size=10)
    score1 = f1_score(y_true, y_pred, average="binary")
    score2 = skf1_score(y_true, y_pred, average="binary")
    assert np.isclose(score1, score2)

# Multiclass case: every averaging mode is checked on the same seeded sample.
for i in range(10):
    # The sample depends only on the seed, so draw it once per seed rather
    # than once per averaging mode.
    rng = np.random.RandomState(i)
    y_true = rng.randint(3, size=10)
    y_pred = rng.randint(3, size=10)
    for average in (None, "micro", "macro", "weighted"):
        score1 = f1_score(y_true, y_pred, average=average)
        score2 = skf1_score(y_true, y_pred, average=average)
        # Compare with a tolerance: the two implementations may order their
        # floating-point operations differently, so bit-exact equality of
        # the per-class array (average=None) is not guaranteed.
        assert np.allclose(score1, score2)
# Output captured when running the checks above (warnings from scikit-learn):
#
# d:\github\scikit-learn\sklearn\metrics\classification.py:1430: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no true samples.
#   'recall', 'true', average, warn_for)
# d:\github\scikit-learn\sklearn\metrics\classification.py:1428: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no predicted samples.
#   'precision', 'predicted', average, warn_for)