Comparison of dask_glm and scikit-learn on the SUSY dataset.

In [2]:
import numpy as np
import pandas as pd

import dask
from distributed import Client
import dask.array as da
from sklearn import linear_model
from dask_glm.estimators import LogisticRegression
In [3]:
df = pd.read_csv("SUSY.csv.gz", header=None)
df.head()
Out[3]:
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
0 0.0 0.972861 0.653855 1.176225 1.157156 -1.739873 -0.874309 0.567765 -0.175000 0.810061 -0.252552 1.921887 0.889637 0.410772 1.145621 1.932632 0.994464 1.367815 0.040714
1 1.0 1.667973 0.064191 -1.225171 0.506102 -0.338939 1.672543 3.475464 -1.219136 0.012955 3.775174 1.045977 0.568051 0.481928 0.000000 0.448410 0.205356 1.321893 0.377584
2 1.0 0.444840 -0.134298 -0.709972 0.451719 -1.613871 -0.768661 1.219918 0.504026 1.831248 -0.431385 0.526283 0.941514 1.587535 2.024308 0.603498 1.562374 1.135454 0.180910
3 1.0 0.381256 -0.976145 0.693152 0.448959 0.891753 -0.677328 2.033060 1.533041 3.046260 -1.005285 0.569386 1.015211 1.582217 1.551914 0.761215 1.715464 1.492257 0.090719
4 1.0 1.309996 -0.690089 -0.676259 1.589283 -0.693326 0.622907 1.087562 -0.381742 0.589204 1.365479 1.179295 0.968218 0.728563 0.000000 1.083158 0.043429 1.154854 0.094859
In [4]:
len(df)
Out[4]:
5000000

We have 5,000,000 rows of all-numeric data. We'll skip any feature engineering and preprocessing.
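A quick sanity check of the column dtypes and the in-memory footprint (output not shown) can be done with pandas alone:

df.dtypes.value_counts()                  # every column should be float64
df.memory_usage(deep=True).sum() / 1e9    # rough in-memory size in GB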

In [5]:
y = df[0].values
X = df.drop(0, axis=1).values
In [6]:
C = 10     # for scikit-learn
λ = 1 / C  # for dask_glm
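scikit-learn parameterizes the penalty by C (larger C means weaker regularization), while dask-glm takes the penalty strength directly, so λ = 1 / C. For reference, λ would be passed to the dask-glm estimator roughly as sketched below; the keyword names (lamduh, regularizer) are assumed from dask_glm's API and may differ by version, and the fit further down just uses the estimator's defaults.

# hypothetical -- assumed dask_glm keyword names, not executed here:
# dk = LogisticRegression(regularizer='l1', lamduh=λ)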
In [7]:
from sklearn.preprocessing import scale

X = scale(X)
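scale standardizes each column to zero mean and unit variance, which helps both solvers converge; a quick check (output not shown):

X.mean(axis=0).round(6)   # approximately all zeros
X.std(axis=0).round(6)    # approximately all ones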

Scikit-learn

First, we fit scikit-learn's LogisticRegression, with an L1 penalty and the saga solver, on the full in-memory dataset.

In [8]:
%%time
lm = linear_model.LogisticRegression(penalty='l1', C=C, solver='saga')
lm.fit(X, y)
CPU times: user 1min 8s, sys: 52 ms, total: 1min 8s
Wall time: 1min 8s
In [9]:
%%time
lm.score(X, y)
CPU times: user 936 ms, sys: 316 ms, total: 1.25 s
Wall time: 780 ms
Out[9]:
0.78832060000000004
In [10]:
# %%time
# lm = linear_model.LogisticRegression(penalty='l1', C=C)
# lm.fit(X, y)
In [11]:
# %%time
# lm.score(X, y)
In [12]:
lm.coef_
Out[12]:
array([[  1.59889253e+00,  -2.16531507e-04,  -2.07209382e-03,
          3.07578963e-01,   1.31631754e-03,  -7.76999581e-05,
          4.08804094e+00,   3.57268409e-03,  -3.64687752e-01,
          3.18132406e-01,   1.27980300e-01,  -9.35147191e-01,
         -8.07122679e-01,   8.41146611e-02,  -1.26453868e+00,
          3.32235015e-01,  -2.71631453e-01,   2.17774653e-01]])

Dask GLM

Now for the dask-glm version.

In [13]:
client = Client()

# chunk the data into blocks of 100,000 rows
K = 100000
dX = da.from_array(X, chunks=(K, X.shape[-1]))
dy = da.from_array(y, chunks=(K,))

# persist the chunked arrays on the cluster and spread them across the workers
dX, dy = dask.persist(dX, dy)
client.rebalance([dX, dy])
distributed.deploy.local - INFO - To start diagnostics web server please install Bokeh
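Before fitting, the chunk layout and total size of the persisted arrays can be confirmed (output not shown); a minimal sketch using standard dask.array attributes:

dX.chunks          # 50 blocks of 100,000 rows x 18 columns
dX.nbytes / 1e9    # total size in GB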
In [14]:
%%time
dk = LogisticRegression()
dk.fit(dX, dy)
Converged! 6
CPU times: user 17.2 s, sys: 26 s, total: 43.1 s
Wall time: 6min 3s
In [15]:
%%time
dk.score(dX, dy)
CPU times: user 532 ms, sys: 328 ms, total: 860 ms
Wall time: 438 ms
Out[15]:
0.78832460000000004
In [16]:
dk.coef_
Out[16]:
array([  1.58671772e+00,  -1.90417079e-04,  -2.01987034e-03,
         3.04813020e-01,   1.38868265e-03,  -1.10436676e-04,
         4.05569762e+00,   3.50754224e-03,  -3.61962379e-01,
         3.15746332e-01,   1.26745609e-01,  -9.27452934e-01,
        -8.00558632e-01,   8.32899295e-02,  -1.25427689e+00,
         3.29617974e-01,  -2.69505213e-01,   2.15981990e-01])
To summarize the wall-clock training times and in-sample scores:

Library       Training time  Score
scikit-learn  1:08           .788
dask-glm      6:03           .788

The saga fit is not perfect, though: its accuracy is very slightly lower and its coefficients are not identical to dask-glm's:

In [19]:
np.max(np.abs(dk.coef_ - lm.coef_))
Out[19]:
0.032343321467148911
In [20]:
np.abs(dk.coef_ - lm.coef_)
Out[20]:
array([[  1.21748057e-02,   2.61144277e-05,   5.22234856e-05,
          2.76594302e-03,   7.23651159e-05,   3.27367177e-05,
          3.23433215e-02,   6.51418546e-05,   2.72537334e-03,
          2.38607405e-03,   1.23469089e-03,   7.69425681e-03,
          6.56404753e-03,   8.24731672e-04,   1.02617860e-02,
          2.61704189e-03,   2.12624013e-03,   1.79266276e-03]])