dask / dask-searchcv

dask-searchcv is now part of dask-ml: https://github.com/dask/dask-ml
BSD 3-Clause "New" or "Revised" License

BaseSearchCV throws IndexError for particular sized optional arguments to BaseSearchCV.fit #74

Closed: stsievert closed this issue 6 years ago

stsievert commented 6 years ago

For certain sizes of optional arguments to estimator.fit, dask-searchcv throws an IndexError. It looks like it indexes the optional array with the training indices of each CV split, as if the array had one entry per example.
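The suspected mechanism can be made concrete with a small numpy-only sketch. It assumes unshuffled 3-fold splits, which is consistent with the `index 9` in the traceback below (the first training fold of 25 samples starts at sample 9). The helper `extract_fit_param` is hypothetical: it only mimics the `safe_indexing(x, self.splits[n][0]) if _is_arraylike(x) else x` line shown in the traceback and is not the library's actual code.

```python
import numpy as np


def is_arraylike(x):
    # Simplified duck-typing check in the spirit of the `_is_arraylike`
    # call in the traceback (this exact version is an assumption).
    return hasattr(x, "__len__") or hasattr(x, "shape") or hasattr(x, "__array__")


def kfold_train_indices(n_samples, n_splits=3):
    # Unshuffled, contiguous K-fold splits: the first n_samples % n_splits
    # folds get one extra sample. Returns the training indices of each fold.
    fold_sizes = np.full(n_splits, n_samples // n_splits)
    fold_sizes[: n_samples % n_splits] += 1
    all_idx = np.arange(n_samples)
    trains, start = [], 0
    for size in fold_sizes:
        test = all_idx[start:start + size]
        trains.append(np.setdiff1d(all_idx, test))
        start += size
    return trains


def extract_fit_param(x, train_idx):
    # Hypothetical stand-in for the buggy step: every array-like fit
    # parameter is assumed to hold one entry per sample and is subset
    # with the training indices, whatever its actual length.
    return np.take(x, train_idx, axis=0) if is_arraylike(x) else x


def raises_index_error(x, n_samples=25):
    # True if subsetting `x` fails for any of the CV splits.
    try:
        for train_idx in kfold_train_indices(n_samples):
            extract_fit_param(x, train_idx)
    except IndexError:
        return True
    return False


if __name__ == "__main__":
    trains = kfold_train_indices(25)
    print(trains[0][:3])                              # [ 9 10 11]
    print(raises_index_error(np.array([0, 1])))       # True
    print(raises_index_error(np.linspace(0, 1, 24)))  # True
    print(raises_index_error(np.linspace(0, 1, 25)))  # False
```

Under these assumptions, any array shorter than the number of samples fails as soon as a training index points past its end.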

Minimal working example:

from sklearn.base import BaseEstimator
import numpy as np
import dask_searchcv as dcv
from sklearn.datasets import make_classification
import pytest

class Dummy(BaseEstimator):
    def __init__(self, alpha=0):
        # BaseEstimator.get_params/set_params expect params stored as attributes
        self.alpha = alpha

    def fit(self, X, y, classes=None):
        return self

    def score(self, X, y):
        return 1

if __name__ == "__main__":
    X, y = make_classification(n_samples=25,
                               n_classes=2, random_state=0)

    clf = Dummy()
    grid = {'alpha': np.logspace(-3, 0)}
    # 24 entries: fewer than the 25 samples in X
    classes = np.linspace(0, 1, num=24)

    gs = dcv.RandomizedSearchCV(clf, grid)

    with pytest.raises(IndexError):
        gs.fit(X, y, classes=classes)

    gs.fit(X, y)

Exceptions are raised only when len(classes) < 25, i.e. when classes has fewer entries than there are samples. As expected, fitting succeeds with scikit-learn, and the same exception is also thrown with dask_searchcv.GridSearchCV.
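One possible direction for a fix is to subset only the fit parameters that look per-sample and pass everything else through. The sketch below is an assumption, not dask-searchcv's or dask-ml's actual patch; `extract_fit_param_checked` is a hypothetical name. Note that the length check is a heuristic: a non-per-sample array that happens to have exactly n_samples entries would still be subset.

```python
import numpy as np


def extract_fit_param_checked(x, train_idx, n_samples):
    # Hypothetical fix: only subset parameters that look per-sample, i.e.
    # array-likes whose length equals the number of samples in X.
    # Anything else (like `classes`) is passed through unchanged.
    if hasattr(x, "__len__") and len(x) == n_samples:
        return np.take(x, train_idx, axis=0)
    return x


if __name__ == "__main__":
    n_samples = 25
    train_idx = np.arange(9, 25)  # training indices of the first 3-fold split
    classes = np.array([0, 1])
    weights = np.ones(n_samples)

    print(extract_fit_param_checked(classes, train_idx, n_samples))        # [0 1]
    print(extract_fit_param_checked(weights, train_idx, n_samples).shape)  # (16,)
```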

I ran into this issue while integrating #72.

Traceback when the pytest check is removed:

Traceback (most recent call last):
  File "test2.py", line 35, in <module>
    gs.fit(X, y, classes=classes)
  File "/Users/ssievert/anaconda3/lib/python3.6/site-packages/dask_searchcv/model_selection.py", line 867, in fit
    out = scheduler(dsk, keys, num_workers=n_jobs)
  File "/Users/ssievert/anaconda3/lib/python3.6/site-packages/dask/threaded.py", line 75, in get
    pack_exception=pack_exception, **kwargs)
  File "/Users/ssievert/anaconda3/lib/python3.6/site-packages/dask/local.py", line 521, in get_async
    raise_exception(exc, tb)
  File "/Users/ssievert/anaconda3/lib/python3.6/site-packages/dask/compatibility.py", line 67, in reraise
    raise exc
  File "/Users/ssievert/anaconda3/lib/python3.6/site-packages/dask/local.py", line 290, in execute_task
    result = _execute_task(task, data)
  File "/Users/ssievert/anaconda3/lib/python3.6/site-packages/dask/local.py", line 270, in _execute_task
    args2 = [_execute_task(a, cache) for a in args]
  File "/Users/ssievert/anaconda3/lib/python3.6/site-packages/dask/local.py", line 270, in <listcomp>
    args2 = [_execute_task(a, cache) for a in args]
  File "/Users/ssievert/anaconda3/lib/python3.6/site-packages/dask/local.py", line 271, in _execute_task
    return func(*args2)
  File "/Users/ssievert/anaconda3/lib/python3.6/site-packages/dask_searchcv/methods.py", line 141, in cv_extract_params
    return {k: cvs.extract_param(tok, v, n) for (k, tok), v in zip(keys, vals)}
  File "/Users/ssievert/anaconda3/lib/python3.6/site-packages/dask_searchcv/methods.py", line 141, in <dictcomp>
    return {k: cvs.extract_param(tok, v, n) for (k, tok), v in zip(keys, vals)}
  File "/Users/ssievert/anaconda3/lib/python3.6/site-packages/dask_searchcv/methods.py", line 93, in extract_param
    out = safe_indexing(x, self.splits[n][0]) if _is_arraylike(x) else x
  File "/Users/ssievert/anaconda3/lib/python3.6/site-packages/sklearn/utils/__init__.py", line 160, in safe_indexing
    return X.take(indices, axis=0)
IndexError: index 9 is out of bounds for size 2

TomAugspurger commented 6 years ago

Migrate this to Dask-ML?

stsievert commented 6 years ago

Closing in favor of https://github.com/dask/dask-ml/issues/254