dask / dask-ml

Scalable Machine Learning with Dask
http://ml.dask.org

How to handle fit_kwargs in Incremental #262

Open TomAugspurger opened 6 years ago

TomAugspurger commented 6 years ago

Things like classes=da.unique(y) may be inefficient: the unique values have to be recomputed for each block of data, which is expensive, especially if y isn't persisted.
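One workaround for the classes case, as a minimal sketch (this isn't something Incremental does automatically), is to compute the unique labels once up front and pass the concrete NumPy result, so every partial_fit call reuses the same small in-memory array instead of recomputing da.unique(y) per block:

import dask.array as da

y = da.random.randint(0, 2, size=1000, chunks=100)

# One pass over y, done up front; the result is a small NumPy array
# (e.g. array([0, 1])) that is cheap to pass to every partial_fit call.
classes = da.unique(y).compute()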

Things like sample_weight are tricky: it's an array of length n_samples that should be chunked along with X and y, and we don't do this correctly right now.

from sklearn.linear_model import SGDClassifier
from dask_ml.datasets import make_classification
from dask_ml.wrappers import Incremental
import dask.array as da

X, y = make_classification(chunks=50)                      # dask arrays in blocks of 50 rows
sample_weight = da.random.uniform(size=len(X), chunks=50)  # per-sample weights, chunked like X and y

sgd = SGDClassifier(max_iter=1000)
inc = Incremental(sgd, scoring='accuracy')

inc.fit(X, y, classes=[0, 1], sample_weight=sample_weight)

raises with

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-3-4063d3b70a7b> in <module>()
     11 inc = Incremental(sgd, scoring='accuracy')
     12
---> 13 inc.fit(X, y, classes=[0, 1], sample_weight=sample_weight)

~/sandbox/dask-ml/dask_ml/wrappers.py in fit(self, X, y, **fit_kwargs)
    370     def fit(self, X, y=None, **fit_kwargs):
    371         estimator = sklearn.base.clone(self.estimator)
--> 372         self._fit_for_estimator(estimator, X, y, **fit_kwargs)
    373         return self
    374

~/sandbox/dask-ml/dask_ml/wrappers.py in _fit_for_estimator(self, estimator, X, y, **fit_kwargs)
    362             result = estimator.partial_fit(X=X, y=y, **fit_kwargs)
    363         else:
--> 364             result = fit(estimator, X, y, **fit_kwargs)
    365
    366         copy_learned_attributes(result, self)

~/sandbox/dask-ml/dask_ml/_partial.py in fit(model, x, y, compute, **kwargs)
    184
    185     if compute:
--> 186         return value.compute()
    187     else:
    188         return value

~/sandbox/dask/dask/base.py in compute(self, **kwargs)
    154         dask.base.compute
    155         """
--> 156         (result,) = compute(self, traverse=False, **kwargs)
    157         return result
    158

~/sandbox/dask/dask/base.py in compute(*args, **kwargs)
    400     keys = [x.__dask_keys__() for x in collections]
    401     postcomputes = [x.__dask_postcompute__() for x in collections]
--> 402     results = schedule(dsk, keys, **kwargs)
    403     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    404

~/sandbox/dask/dask/threaded.py in get(dsk, result, cache, num_workers, **kwargs)
     73     results = get_async(pool.apply_async, len(pool._pool), dsk, result,
     74                         cache=cache, get_id=_thread_get_id,
---> 75                         pack_exception=pack_exception, **kwargs)
     76
     77     # Cleanup pools associated to dead threads

~/sandbox/dask/dask/local.py in get_async(apply_async, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, **kwargs)
    519                         _execute_task(task, data)  # Re-execute locally
    520                     else:
--> 521                         raise_exception(exc, tb)
    522                 res, worker_id = loads(res_info)
    523                 state['cache'][key] = res

~/sandbox/dask/dask/compatibility.py in reraise(exc, tb)
     67         if exc.__traceback__ is not tb:
     68             raise exc.with_traceback(tb)
---> 69         raise exc
     70
     71 else:

~/sandbox/dask/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    288     try:
    289         task, data = loads(task_info)
--> 290         result = _execute_task(task, data)
    291         id = get_id()
    292         result = dumps((result, id))

~/sandbox/dask/dask/local.py in _execute_task(arg, cache, dsk)
    269         func, args = arg[0], arg[1:]
    270         args2 = [_execute_task(a, cache) for a in args]
--> 271         return func(*args2)
    272     elif not ishashable(arg):
    273         return arg

~/sandbox/dask-ml/dask_ml/_partial.py in _partial_fit(model, x, y, kwargs)
    107     start = tic()
    108     logger.info("Starting partial-fit %s", dask.base.tokenize(model, x, y))
--> 109     model.partial_fit(x, y, **kwargs)
    110     stop = tic()
    111     logger.info("Finished partial-fit %s [%0.2f]",

~/sandbox/scikit-learn/sklearn/linear_model/stochastic_gradient.py in partial_fit(self, X, y, classes, sample_weight)
    557                                  learning_rate=self.learning_rate, max_iter=1,
    558                                  classes=classes, sample_weight=sample_weight,
--> 559                                  coef_init=None, intercept_init=None)
    560
    561     def fit(self, X, y, coef_init=None, intercept_init=None,

~/sandbox/scikit-learn/sklearn/linear_model/stochastic_gradient.py in _partial_fit(self, X, y, alpha, C, loss, learning_rate, max_iter, classes, sample_weight, coef_init, intercept_init)
    384         self._expanded_class_weight = compute_class_weight(self.class_weight,
    385                                                            self.classes_, y)
--> 386         sample_weight = self._validate_sample_weight(sample_weight, n_samples)
    387
    388         if getattr(self, "coef_", None) is None or coef_init is not None:

~/sandbox/scikit-learn/sklearn/linear_model/stochastic_gradient.py in _validate_sample_weight(self, sample_weight, n_samples)
    172                                        order="C")
    173         if sample_weight.shape[0] != n_samples:
--> 174             raise ValueError("Shapes of X and sample_weight do not match.")
    175         return sample_weight
    176

ValueError: Shapes of X and sample_weight do not match.

We don't want to simply persist the array, as it may be too large to hold in memory.

TomAugspurger commented 6 years ago

We should handle this the same way as scikit-learn: fit params that are the same length as X are split just like X.
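Roughly, the idea could look like the sketch below (fit_with_split_weights is a hypothetical helper, not dask-ml's implementation, and it assumes dask arrays chunked only along the first axis): rechunk the fit param so its blocks line up with X's row blocks, then hand each partial_fit call the matching slice.

import dask
import dask.array as da
from sklearn.linear_model import SGDClassifier
from dask_ml.datasets import make_classification

def fit_with_split_weights(model, X, y, classes, sample_weight):
    # Align sample_weight's blocks with X's row blocks.
    sample_weight = sample_weight.rechunk((X.chunks[0],))
    blocks = zip(X.to_delayed().ravel(),
                 y.to_delayed().ravel(),
                 sample_weight.to_delayed().ravel())
    for xb, yb, wb in blocks:
        # Materialize one block of X, y, and sample_weight together and
        # pass the matching weights to this partial_fit call.
        xb, yb, wb = dask.compute(xb, yb, wb)
        model.partial_fit(xb, yb, classes=classes, sample_weight=wb)
    return model

X, y = make_classification(chunks=50)
sample_weight = da.random.uniform(size=len(X), chunks=50)
sgd = fit_with_split_weights(SGDClassifier(max_iter=1000), X, y,
                             classes=[0, 1], sample_weight=sample_weight)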

For dask DataFrames, we could perhaps rely on the heuristic of splitting a fit param when its divisions match X's.
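A minimal sketch of that check (should_split is a hypothetical helper, assuming dask.dataframe inputs):

import dask.dataframe as dd

def should_split(X, param):
    # Treat `param` as row-aligned with X only when it is a dask
    # Series/DataFrame whose divisions are known and match X's exactly;
    # zip(X.to_delayed(), param.to_delayed()) then pairs up the
    # corresponding partitions.
    return (
        isinstance(param, (dd.Series, dd.DataFrame))
        and X.known_divisions
        and param.known_divisions
        and param.divisions == X.divisions
    )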

TomAugspurger commented 6 years ago

This also applies to the incremental hyperparameter optimization in _fit:

https://github.com/dask/dask-ml/blob/57690419995c29b09b8e54eea658ac83cc2dd15e/dask_ml/model_selection/_incremental.py#L133-L134