dask / dask-xgboost

BSD 3-Clause "New" or "Revised" License

Provide informative error message on bad type input to predict #35

Open mrocklin opened 5 years ago

mrocklin commented 5 years ago

Any insight on what's up with the tests?

On Thu, Feb 21, 2019 at 11:16 AM Tom Augspurger notifications@github.com wrote:

@TomAugspurger approved this pull request.


TomAugspurger commented 5 years ago

I suspect the test_sparse failure is similar to what we ran into with dask. IIRC, sparse changed to be stricter about not converting to dense automatically.

No idea about the other failures, unfortunately :/ Possibly something with pytest-xdist?

FWIW, I have a local (unpushed) branch called test-fixup with this diff:

diff --git a/.circleci/config.yml b/.circleci/config.yml
index f1463079..72faf516 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -16,7 +16,7 @@ jobs:
             conda config --add channels conda-forge
             conda create -q -n test-environment python=${PYTHON}
             source activate test-environment
-            conda install -q coverage flake8 pytest pytest-cov pytest-xdist numpy pandas xgboost dask distributed scikit-learn sparse scipy
+            conda install -q coverage flake8 pytest pytest-cov numpy pandas xgboost dask distributed scikit-learn sparse scipy
             pip install -e .
             conda list test-environment
       - run:
diff --git a/dask_xgboost/core.py b/dask_xgboost/core.py
index 6bf29d78..c843a000 100644
--- a/dask_xgboost/core.py
+++ b/dask_xgboost/core.py
@@ -34,7 +34,7 @@ def parse_host_port(address):
     return host, port

-def start_tracker(host, n_workers):
+def start_tracker(host, n_workers, dask_scheduler=None):
     """ Start Rabit tracker """
     env = {'DMLC_NUM_WORKER': n_workers}
     rabit = RabitTracker(hostIP=host, nslave=n_workers)
@@ -45,6 +45,7 @@ def start_tracker(host, n_workers):
     thread = Thread(target=rabit.join)
     thread.daemon = True
     thread.start()
+    dask_scheduler.xgboost_thread = thread
     return env

@@ -155,6 +156,13 @@ def _train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
     num_class = params.get("num_class")
     if num_class:
         result.set_attr(num_class=str(num_class))
+
+    def wait_on_tracker_thread(dask_scheduler):
+        dask_scheduler.xgboost_thread.join()
+        del dask_scheduler.xgboost_thread
+
+    yield client.run_on_scheduler(wait_on_tracker_thread)
+
     raise gen.Return(result)

diff --git a/setup.cfg b/setup.cfg
index 2348f495..11894603 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -5,4 +5,4 @@ universal=1
 exclude = tests/data,docs,benchmarks,scripts

 [tool:pytest]
-addopts = -rsx -v -n 1 --boxed
+addopts = -rsx -v

On closer inspection, that looks like the issue discussed in https://github.com/dask/dask-xgboost/pull/29#issuecomment-430596828
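The core of the test-fixup diff is a handoff between two functions. start_tracker keeps a reference to the Rabit tracker thread on the scheduler, and _train later joins that thread through client.run_on_scheduler before returning, so a tracker from one training run cannot outlive it. Below is a minimal, stdlib-only sketch of that handoff under stated assumptions, not the dask-xgboost code itself: SchedulerStandIn is a hypothetical placeholder for the real scheduler object that run_on_scheduler passes in as dask_scheduler, and the `work` callable stands in for rabit.join.

```python
import threading
import time


class SchedulerStandIn:
    """Hypothetical placeholder for the object that client.run_on_scheduler
    passes to a function as its ``dask_scheduler`` argument."""


def start_tracker(dask_scheduler, work):
    # Run the long-lived "tracker" in a daemon thread (as the diff does with
    # rabit.join) and remember that thread on the shared scheduler object.
    thread = threading.Thread(target=work)
    thread.daemon = True
    thread.start()
    dask_scheduler.xgboost_thread = thread
    return thread


def wait_on_tracker_thread(dask_scheduler):
    # Block until the tracker thread finishes, then drop the reference so the
    # next training run starts from a clean scheduler. If no thread was
    # stored, do nothing instead of raising AttributeError.
    thread = getattr(dask_scheduler, "xgboost_thread", None)
    if thread is not None:
        thread.join()
        del dask_scheduler.xgboost_thread


if __name__ == "__main__":
    scheduler = SchedulerStandIn()
    finished = []

    def work():
        # Simulated tracker lifetime.
        time.sleep(0.1)
        finished.append(True)

    start_tracker(scheduler, work)
    wait_on_tracker_thread(scheduler)
    print(finished, hasattr(scheduler, "xgboost_thread"))
```

Unlike the wait_on_tracker_thread in the diff, this sketch tolerates being called when no thread was stored, which may be useful if training can fail before start_tracker runs.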