dask / dask-ml

Scalable Machine Learning with Dask
http://ml.dask.org
BSD 3-Clause "New" or "Revised" License

Revise documentation on hyperparameter search #88

Open jimmywan opened 7 years ago

jimmywan commented 7 years ago

Right around the time I stumbled upon dask-searchcv, scikit-learn 0.19 was released, and in that release pipelines gained support for the memory parameter: http://scikit-learn.org/stable/whats_new.html#version-0-19

As such, perhaps this section of the docs should be revised:

With the regular scikit-learn version, each stage of the pipeline must be fit for each of the combinations of the parameters, even if that step isn’t being searched over.

I'd be interested to hear whether any work has been done to compare dask-searchcv with scikit-learn's changes in 0.19. Presumably dask-searchcv makes it easier to harness multiple machines, but some rudimentary benchmarks might still be interesting.

From: http://dask-ml.readthedocs.io/en/latest/hyper-parameter-search.html#efficient-search
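The quoted claim, and the kind of reuse that Pipeline's memory parameter (and, per the docs quoted above, dask-searchcv) aims for, can be illustrated with a toy count of step fits. This is a pure-Python sketch: `count_fits` is a hypothetical helper, no real fitting happens, and it does not reflect how either library is implemented internally. The grid shape mirrors the benchmark posted later in this thread (no scaler parameters searched, 16 classifier settings) with 3-fold CV, scikit-learn 0.19's default; the final refit is ignored.

```python
from itertools import product


def count_fits(transform_grid, estimator_grid, n_folds, cache):
    """Count step fits in a toy two-step pipeline grid search.

    Returns (transform_fits, estimator_fits). When cache is True, a fitted
    transform is reused for any (transform params, fold) pair seen before,
    which is the idea behind caching a shared pipeline prefix.
    """
    transform_fits = 0
    estimator_fits = 0
    fitted = set()  # (transform params, fold) pairs already fitted

    for t_params, _e_params in product(transform_grid, estimator_grid):
        for fold in range(n_folds):
            key = (t_params, fold)
            if not (cache and key in fitted):
                transform_fits += 1  # fit the transform step on this fold
                fitted.add(key)
            estimator_fits += 1  # the final step is fit for every candidate


    return transform_fits, estimator_fits


# One scaler setting (nothing searched), 16 classifier settings, 3 folds.
print(count_fits([()], range(16), n_folds=3, cache=False))  # (48, 48)
print(count_fits([()], range(16), n_folds=3, cache=True))   # (3, 48)
```

In this toy count, caching saves 45 of 48 scaler fits but none of the classifier fits. When the cached step is as cheap as StandardScaler, that saving may be too small to show up in wall-clock timings.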

jimmywan commented 7 years ago

Perhaps it would also make sense to change the benchmarks below to iterate over a small, fixed set of random_state values so that the comparison is as apples-to-apples as possible?
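Iterating over a fixed set of seeds could look roughly like the helper below. It is a minimal sketch: `time_over_seeds` is a hypothetical name, and the commented usage assumes the `get_clf`, `param_grid`, `X` and `y` names from the benchmark script posted later in this thread.

```python
import statistics
import time


def time_over_seeds(run, seeds):
    """Call run(seed) once per seed and time each call.

    Returns (timings, median) with timings in seconds. perf_counter is used
    because it is monotonic and has higher resolution than time.time().
    """
    timings = []
    for seed in seeds:
        start = time.perf_counter()
        run(seed)
        timings.append(time.perf_counter() - start)
    return timings, statistics.median(timings)


# Hypothetical usage with names from the benchmark script later in the thread:
# timings, median = time_over_seeds(
#     lambda seed: ms.GridSearchCV(get_clf(seed), param_grid=param_grid).fit(X, y),
#     seeds=range(3),
# )
```

Reporting the median alongside the individual runs makes a single slow run (for example, one that pays a worker start-up cost) easier to spot.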

mrocklin commented 7 years ago

@jimmywan it would be great to get updated benchmarks. Is this something that you would feel comfortable contributing?

jimmywan commented 7 years ago

My primary environment is an Ubuntu 14.04 VM running on a Windows 10 host, so I was thinking you might want fewer variables involved, but if that's not a concern, I'd be happy to try to help out a bit. From what I have read, I like what you guys have been doing and the direction you're going with dask-*, so I'm happy to help.

TomAugspurger commented 7 years ago

so I was thinking you might want fewer variables involved

I presume the differences would cancel each other out, but perhaps not. Either way, I'd be happy to run the benchmark as well if you're able to put one together.

Also, this issue may interest you: https://github.com/scikit-learn/scikit-learn/issues/10068

jimmywan commented 7 years ago

Took a stab at this; here are the results. Hope this helps.

Results

scikit-learn GridSearchCV took 5.26 seconds for 16 candidate parameter settings.
scikit-learn GridSearchCV took 5.23 seconds for 16 candidate parameter settings.
scikit-learn GridSearchCV took 4.96 seconds for 16 candidate parameter settings.
scikit-learn GridSearchCV with memory took 5.55 seconds for 16 candidate parameter settings.
scikit-learn GridSearchCV with memory took 5.40 seconds for 16 candidate parameter settings.
scikit-learn GridSearchCV with memory took 4.90 seconds for 16 candidate parameter settings.
dask-searchcv GridSearchCV took 6.63 seconds for 16 candidate parameter settings.
dask-searchcv GridSearchCV took 5.59 seconds for 16 candidate parameter settings.
dask-searchcv GridSearchCV took 5.96 seconds for 16 candidate parameter settings.
scikit-learn RandomizedSearchCV took 86.75 seconds for 250 candidate parameter settings.
scikit-learn RandomizedSearchCV with memory took 83.36 seconds for 250 candidate parameter settings.
dask-searchcv RandomizedSearchCV took 75.33 seconds for 250 candidate parameter settings.
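A quick way to compare these runs is to average the repeated grid-search timings and express each randomized-search timing relative to the plain scikit-learn run. The numbers below are copied from the results; the dictionary names are illustrative.

```python
from statistics import mean

# Grid-search timings (seconds, 16 candidates), three runs each.
grid_timings = {
    "scikit-learn GridSearchCV": [5.26, 5.23, 4.96],
    "scikit-learn GridSearchCV with memory": [5.55, 5.40, 4.90],
    "dask-searchcv GridSearchCV": [6.63, 5.59, 5.96],
}

# Randomized-search timings (seconds, 250 candidates), one run each.
random_timings = {
    "scikit-learn RandomizedSearchCV": 86.75,
    "scikit-learn RandomizedSearchCV with memory": 83.36,
    "dask-searchcv RandomizedSearchCV": 75.33,
}

grid_means = {name: round(mean(ts), 2) for name, ts in grid_timings.items()}
for name, avg in grid_means.items():
    print(f"{name}: mean {avg:.2f} s over {len(grid_timings[name])} runs")

baseline = random_timings["scikit-learn RandomizedSearchCV"]
for name, seconds in random_timings.items():
    print(f"{name}: {seconds / baseline:.2f}x the scikit-learn time")
```

On these numbers the grid-search means are 5.15 s, 5.28 s and 6.06 s, and the dask-searchcv randomized run takes about 0.87x the scikit-learn time. With only three grid runs and a single randomized run per method, some of these differences may be within run-to-run noise.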

Benchmark code

I started with this, and made some modifications that I thought would make the benchmarks more comparable:

I didn't put a whole lot more effort into changing the pipeline. I'm not super familiar with classification tasks or this dataset.

if __name__ == '__main__':
    from dask_searchcv import GridSearchCV, RandomizedSearchCV
    from distributed import Client
    from scipy import stats
    from shutil import rmtree
    from sklearn import model_selection as ms
    from sklearn.datasets import load_digits
    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from tempfile import mkdtemp
    from time import time

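    def get_clf(random_state, memory_dir=None):
        # Scale the features, then fit a logistic regression. If memory_dir is
        # given, Pipeline caches the fitted transformers (here, StandardScaler)
        # in that directory. solver='liblinear' handles both the "l1" and "l2"
        # penalties searched below; it was the default solver in scikit-learn
        # 0.19 but not in later releases.
        clf = Pipeline(
            [
                ('scaler', StandardScaler()),
                ('clf', LogisticRegression(random_state=random_state, solver='liblinear')),
            ],
            memory=memory_dir,
        )
        return clf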

    client = Client()

    # get some data
    digits = load_digits()
    X, y = digits.data, digits.target

    # use a full grid over all parameters
    param_grid = {
        "C": [1e-5, 1e-3, 1e-1, 1],
        "fit_intercept": [True, False],
        "penalty": ["l1", "l2"]
    }
    # Prefix each key with the pipeline step name so the search targets the 'clf' step.
    param_grid = {'clf__' + name: values for name, values in param_grid.items()}

    # scikit-learn
    for i in range(3):
        sk_grid_search = ms.GridSearchCV(get_clf(i), param_grid=param_grid, n_jobs=-1)
        start = time()
        sk_grid_search.fit(X, y)
        print("scikit-learn GridSearchCV took %.2f seconds for %d candidate parameter settings."
              % (time() - start, len(sk_grid_search.cv_results_['params'])))
        del sk_grid_search

    # scikit-learn w/ memory
    for i in range(3):
        tmp_dir = mkdtemp()
        try:
            sk_grid_search_with_memory = ms.GridSearchCV(get_clf(i, tmp_dir), param_grid=param_grid, n_jobs=-1)
            start = time()
            sk_grid_search_with_memory.fit(X, y)

            print("scikit-learn GridSearchCV with memory took %.2f seconds for %d candidate parameter settings."
                  % (time() - start, len(sk_grid_search_with_memory.cv_results_['params'])))
            del sk_grid_search_with_memory
        finally:
            rmtree(tmp_dir)

    # dask
    for i in range(3):
        dk_grid_search = GridSearchCV(get_clf(i), param_grid=param_grid, n_jobs=-1)
        start = time()
        dk_grid_search.fit(X, y)
        print("dask-searchcv GridSearchCV took %.2f seconds for %d candidate parameter settings."
              % (time() - start, len(dk_grid_search.cv_results_['params'])))
        del dk_grid_search

    param_dist = {
        "C": stats.beta(1, 3),
        "fit_intercept": [True, False],
        "penalty": ["l1", "l2"]
    }
    param_dist = {'clf__' + name: dist for name, dist in param_dist.items()}

    n_iter_search = 250

    # scikit-learn randomized search
    sk_random_search = ms.RandomizedSearchCV(get_clf(42), param_distributions=param_dist, n_iter=n_iter_search, n_jobs=-1, random_state=42)
    start = time()
    sk_random_search.fit(X, y)
    print("scikit-learn RandomizedSearchCV took %.2f seconds for %d candidate"
          " parameter settings." % ((time() - start), n_iter_search))
    del sk_random_search

    # scikit-learn randomized search with memory
    tmp_dir = mkdtemp()
    try:
        sk_random_search_with_memory = ms.RandomizedSearchCV(
                get_clf(42, tmp_dir), param_distributions=param_dist, n_iter=n_iter_search, n_jobs=-1, random_state=42
        )
        start = time()
        sk_random_search_with_memory.fit(X, y)
        print("scikit-learn RandomizedSearchCV with memory took %.2f seconds for %d candidate"
              " parameter settings." % ((time() - start), n_iter_search))
        del sk_random_search_with_memory
    finally:
        rmtree(tmp_dir)

    # dask randomized search
    dk_random_search = RandomizedSearchCV(get_clf(42), param_distributions=param_dist, n_iter=n_iter_search, n_jobs=-1, random_state=42)
    start = time()
    dk_random_search.fit(X, y)
    print("dask-searchcv RandomizedSearchCV took %.2f seconds for %d candidate"
          " parameter settings." % ((time() - start), n_iter_search))
    del dk_random_search

Environment

Python virtualenv contents:

bokeh                     0.12.10                  py36_0    conda-forge
boto3                     1.4.7            py36h4cc92d5_0
botocore                  1.7.20           py36h085fff1_0
bumpversion               0.5.3                     <pip>
certifi                   2017.7.27.1      py36h8b7b77e_0
click                     6.7                      py36_0    conda-forge
cloudpickle               0.4.0                    py36_0    conda-forge
dask                      0.15.4                     py_0    conda-forge
dask-core                 0.15.4                     py_0    conda-forge
dask-glm                  0.1.0                         0    conda-forge
dask-ml                   0.3.1                    py36_0    conda-forge
dask-searchcv             0.1.0                      py_0    conda-forge
distributed               1.19.3                   py36_0    conda-forge
docutils                  0.14             py36hb0f60f5_0
heapdict                  1.0.0                    py36_0    conda-forge
intel-openmp              2018.0.0             h15fc484_7
jinja2                    2.9.6                    py36_0    conda-forge
jmespath                  0.9.3            py36hd3948f9_0
joblib                    0.11                     py36_0
libedit                   3.1                  heed3624_0
libffi                    3.2.1                h4deb6c0_3
libgcc-ng                 7.2.0                h7cc24e2_2
libgfortran               3.0.0                         1
libgfortran-ng            7.2.0                h9f7466a_2
libstdcxx-ng              7.2.0                h7a57d05_2
locket                    0.2.0                    py36_1    conda-forge
markupsafe                1.0                      py36_0    conda-forge
mkl                       2018.0.0             hb491cac_4
msgpack-python            0.4.8                    py36_0    conda-forge
multipledispatch          0.4.9                    py36_0    conda-forge
ncurses                   6.0                  h06874d7_1
nose                      1.3.7            py36hcdf7029_2
numpy                     1.13.3           py36ha12f23b_0
openssl                   1.0.2l                        0
pandas                    0.20.3           py36h842e28d_2
partd                     0.3.8                    py36_0    conda-forge
patsy                     0.4.1            py36ha3be15e_0
pip                       9.0.1            py36h8ec8b28_3
psutil                    5.4.0                    py36_0    conda-forge
python                    3.6.3                hcad60d5_0
python-dateutil           2.6.1            py36h88d3b88_1
pytz                      2017.2           py36hc2ccc2a_1
pyyaml                    3.12                     py36_1    conda-forge
readline                  7.0                  hac23ff0_3
s3fs                      0.1.2                    py36_0
s3transfer                0.1.10           py36h0257dcc_1
scikit-learn              0.19.1           py36h7aa7ec6_0
scipy                     0.19.1           py36h9976243_3
setuptools                36.5.0           py36he42e2e1_0
six                       1.10.0           py36hcac75e4_1
sortedcontainers          1.5.7                    py36_0    conda-forge
sqlite                    3.20.1               h6d8b0f3_1
statsmodels               0.8.0            py36h8533d0b_0
tblib                     1.3.2                    py36_0    conda-forge
tk                        8.6.7                h5979e9b_1
toolz                     0.8.2                     <pip>
toolz                     0.8.2                      py_2    conda-forge
tornado                   4.5.2                    py36_0    conda-forge
wheel                     0.29.0           py36he7f4e38_1
xz                        5.2.3                         0
yaml                      0.1.6                         0    conda-forge
zict                      0.1.3                      py_0    conda-forge
zlib                      1.2.11                        0
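For future runs, a shorter report covering just the packages that matter for this benchmark could be generated from inside Python. This is a minimal sketch using the standard-library `importlib.metadata` (Python 3.8+, so it would not run on the Python 3.6 environment above); `benchmark_versions` is a hypothetical helper and the package list is an assumption.

```python
import platform
from importlib import metadata

# Distribution names assumed relevant to this benchmark; adjust as needed.
BENCHMARK_PACKAGES = ("scikit-learn", "dask", "distributed", "dask-searchcv",
                      "numpy", "scipy", "joblib")


def benchmark_versions(packages=BENCHMARK_PACKAGES):
    """Map 'python' and each distribution name to its installed version.

    Packages that are not installed map to None instead of raising.
    """
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None
    return report


for name, version in benchmark_versions().items():
    print(f"{name:15} {version}")
```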

Host OS details:

Microsoft Windows [Version 10.0.16299.19]
(c) 2017 Microsoft Corporation. All rights reserved.

VM details:

VBoxManage.exe --version
5.1.24r117012

Guest OS details:

$ lsb_release -a
No LSB modules are available.
Distributor ID: Ubuntu
Description:    Ubuntu 14.04 LTS
Release:        14.04
Codename:       trusty

TomAugspurger commented 7 years ago

Thanks @jimmywan! That'll be very helpful.

I'd like to add a "Benchmarks" or "Performance" section to docs/source/hyper-parameter-search.rst with

Are you interested in making a PR with that? Otherwise, I'll get to it by the end of the week.

jimmywan commented 7 years ago

@TomAugspurger I don't think I have time to put more work into it this week. Feel free to take and modify as you see fit.