scikit-learn / scikit-learn

scikit-learn: machine learning in Python
https://scikit-learn.org
BSD 3-Clause "New" or "Revised" License

Dask Joblib backend cancels required futures #12315

Open mrocklin opened 5 years ago

mrocklin commented 5 years ago

Here is a failing example that uses the Dask joblib backend to parallelize a RandomForestClassifier nested within a RandomizedSearchCV:

from scipy.stats import randint as sp_randint

from sklearn.model_selection import RandomizedSearchCV
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier

from dask.distributed import Client

def test_random_forest_random_search():
    with Client(processes=False) as client:
        # Taken from http://scikit-learn.org/stable/auto_examples/model_selection/plot_randomized_search.html

        # get some data
        digits = load_digits()
        X, y = digits.data, digits.target

        # build a classifier
        clf = RandomForestClassifier(n_estimators=20)
        # specify parameters and distributions to sample from
        param_dist = {"max_depth": [3, None],
                      "max_features": sp_randint(1, 11),
                      "min_samples_split": sp_randint(2, 11),
                      "bootstrap": [True, False],
                      "criterion": ["gini", "entropy"]}

        n_iter_search = 20
        random_search = RandomizedSearchCV(clf, param_distributions=param_dist,
                                           n_iter=n_iter_search, cv=5)
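        # sklearn.externals.joblib is the joblib copy vendored in scikit-learn at the
        # time; it was removed in scikit-learn 0.23, where `import joblib` is used instead.
        from sklearn.externals import joblib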
        with joblib.parallel_backend('dask'):
            random_search.fit(X, y)

What happened

By placing a breakpoint in Future.release in distributed/client.py:

diff --git a/distributed/client.py b/distributed/client.py
index f9cf7f94..e47d0d16 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -314,6 +314,8 @@ class Future(WrappedKey):
     def release(self, _in_destructor=False):
         # NOTE: this method can be called from different threads
         # (see e.g. Client.get() or Future.__del__())
+        if not default_client().cluster.scheduler.story(self.key):
+            import pdb; pdb.set_trace()
         if not self._cleared and self.client.generation == self._generation:
             self._cleared = True
             try:

We find that joblib has stopped the computation:

https://github.com/scikit-learn/scikit-learn/blob/3804ccd2770ac4c026a823d019091917e4a2c70e/sklearn/externals/joblib/parallel.py#L1002-L1004

This in turn clears out futures that are still live:

https://github.com/scikit-learn/scikit-learn/blob/3804ccd2770ac4c026a823d019091917e4a2c70e/sklearn/externals/joblib/_dask.py#L156-L159

The subsequent call to retrieve then fails when it fetches the job result:

https://github.com/scikit-learn/scikit-learn/blob/3804ccd2770ac4c026a823d019091917e4a2c70e/sklearn/externals/joblib/parallel.py#L901
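The sequence above (stop, clear live futures, then retrieve) can be reproduced in miniature. The sketch below is a toy model that swaps Dask for the standard library's concurrent.futures so it runs without a cluster; ToyBackend, abort_everything and blocker are hypothetical names chosen for illustration and do not reproduce joblib's actual backend code. A single-worker pool keeps one task running and one pending, the abort cancels what it can, and retrieval then trips over the cancelled future.

```python
import concurrent.futures as cf
import threading


class ToyBackend:
    """Hypothetical stand-in for joblib's Dask backend (not joblib's real API)."""

    def __init__(self, executor):
        self.executor = executor
        self.task_futures = []

    def submit(self, fn, *args):
        fut = self.executor.submit(fn, *args)
        self.task_futures.append(fut)
        return fut

    def abort_everything(self):
        # Cancel and forget every outstanding future, whether or not a caller
        # still intends to fetch its result.
        for fut in self.task_futures:
            fut.cancel()  # only succeeds for futures that have not started yet
        self.task_futures.clear()


def blocker(started, gate):
    started.set()         # signal that the single worker is now busy
    gate.wait(timeout=5)  # hold the worker until the main thread releases it
    return "done"


def main():
    started, gate = threading.Event(), threading.Event()
    with cf.ThreadPoolExecutor(max_workers=1) as executor:
        backend = ToyBackend(executor)
        running = backend.submit(blocker, started, gate)
        started.wait(timeout=5)
        # The only worker is busy, so this future is still pending.
        pending = backend.submit(pow, 2, 10)

        backend.abort_everything()  # the computation is stopped early
        gate.set()                  # let the running task finish

        results = []
        for fut in (running, pending):  # retrieval still expects every result
            try:
                results.append(fut.result(timeout=5))
            except cf.CancelledError:
                results.append("cancelled")
        return results


print(main())
```

Running it prints ['done', 'cancelled']: the task that was already executing survives the abort, while the pending one is cancelled even though retrieval still asks for its result. In the real Dask case the failure surfaces differently (the future's key has been released on the scheduler), so treat this only as an illustration of the ordering problem.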

I'm curious why we're stopping things while we still have live futures. Have we found a "good enough" result and are stopping early? If so, should we be getting the results of the submitted-but-discarded futures? Is there something that Dask can do to help here?
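One possible answer to the last question is to settle outstanding work instead of cancelling it when the computation stops. The sketch below, again using the standard library's concurrent.futures rather than Dask, shows that variant; DrainingBackend, stop and blocker are hypothetical names and this is not how joblib currently behaves. Its stop method waits for every submitted future before clearing the list, so later retrieval never sees a future that was torn down underneath it.

```python
import concurrent.futures as cf
import threading


class DrainingBackend:
    """Hypothetical backend variant (not joblib's real API).

    When the computation stops, it waits for every submitted future to settle
    instead of cancelling them, so later retrieval never finds a future that
    was torn down while it was still needed.
    """

    def __init__(self, executor):
        self.executor = executor
        self.task_futures = []

    def submit(self, fn, *args):
        fut = self.executor.submit(fn, *args)
        self.task_futures.append(fut)
        return fut

    def stop(self, timeout=None):
        # Block until all submitted futures finish or the timeout expires.
        done, not_done = cf.wait(self.task_futures, timeout=timeout)
        self.task_futures.clear()
        return len(done), len(not_done)


def blocker(started, gate):
    started.set()         # signal that the single worker is now busy
    gate.wait(timeout=5)  # hold the worker until the main thread releases it
    return "done"


def main():
    started, gate = threading.Event(), threading.Event()
    with cf.ThreadPoolExecutor(max_workers=1) as executor:
        backend = DrainingBackend(executor)
        running = backend.submit(blocker, started, gate)
        started.wait(timeout=5)
        pending = backend.submit(pow, 2, 10)  # queued behind the running task

        gate.set()                         # the running task may now complete
        counts = backend.stop(timeout=5)   # drain instead of cancelling
        return counts, [running.result(), pending.result()]


print(main())
```

The trade-off is that work joblib intended to discard still runs to completion, which can be expensive inside a large search; a finer-grained option would cancel only futures that nothing will retrieve, at the cost of tracking which results are still needed.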

jnothman commented 5 years ago

Ping @ogrisel, @tomMoral

Shotgunosine commented 4 years ago

Can anyone provide suggestions for how I could test whether this is what's causing scikit-learn/scikit-learn#15383?
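One hedged approach is a non-interactive version of the breakpoint in the diff above: record the call stack every time a future is released, then check whether joblib's abort path (the _dask.py lines linked above) shows up in those stacks while running the #15383 workload. The sketch uses only the standard library; record_call_stacks, ToyFuture and cancel_everything are hypothetical names, and pointing the helper at distributed.client.Future and "release" is an untested assumption rather than something from this report.

```python
import functools
import traceback


def record_call_stacks(cls, method_name, log):
    """Wrap cls.method_name so every call appends the caller's stack to log.

    Returns a function that restores the original method.
    """
    original = getattr(cls, method_name)

    @functools.wraps(original)
    def wrapper(self, *args, **kwargs):
        # Drop the wrapper's own frame so the last entry is the real caller.
        log.append(traceback.format_stack()[:-1])
        return original(self, *args, **kwargs)

    setattr(cls, method_name, wrapper)

    def restore():
        setattr(cls, method_name, original)

    return restore


class ToyFuture:
    """Hypothetical stand-in for distributed.client.Future."""

    def __init__(self, key):
        self.key = key
        self.released = False

    def release(self):
        self.released = True


def cancel_everything(futures):
    # Plays the role of the code path that clears out live futures.
    for fut in futures:
        fut.release()


calls = []
restore = record_call_stacks(ToyFuture, "release", calls)
try:
    cancel_everything([ToyFuture("a"), ToyFuture("b")])
finally:
    restore()  # always put the original method back

print(len(calls))
print(calls[0][-1])
```

Each entry in calls is a list of formatted frames, oldest first, so the last frame names the direct caller of release; in the toy run that is cancel_everything.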