Describe the bug
When run with `CUDA_VISIBLE_DEVICES=''` and within a `using_device_type('cpu')` block, experimental FIL throws a CUDA error on any `predict` call. This makes CPU FIL effectively unusable on a machine without a GPU, because the cuML CPU package does not currently include CPU FIL. I believe this is a regression introduced by an upstream change some time after CPU FIL was added, but I am not certain. It can likely be fixed by making the `synchronize` call at line 300 of fil.pyx conditional on the current device type being GPU.
Error output below:
File "/raid/whicks/proj_xgboost/taxi_example/benchmark.py", line 282, in <module>
fil_model.optimize(batch_size=features.shape[0])
File "/raid/whicks/miniforge3/envs/triton_benchmark/lib/python3.12/site-packages/cuml/internals/api_decorators.py", line 190, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "fil.pyx", line 1470, in cuml.experimental.fil.fil.ForestInference.optimize
File "/raid/whicks/miniforge3/envs/triton_benchmark/lib/python3.12/site-packages/cuml/internals/api_decorators.py", line 188, in wrapper
ret = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/raid/whicks/miniforge3/envs/triton_benchmark/lib/python3.12/site-packages/nvtx/nvtx.py", line 116, in inner
result = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "fil.pyx", line 1258, in cuml.experimental.fil.fil.ForestInference.predict
File "fil.pyx", line 312, in cuml.experimental.fil.fil.ForestInference_impl.predict
File "fil.pyx", line 300, in cuml.experimental.fil.fil.ForestInference_impl._predict
RuntimeError: CUDA error encountered at: file=/raid/whicks/miniforge3/envs/triton_benchmark/include/raft/core/interruptible.hpp line=303:
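
To illustrate the proposed fix, here is a minimal sketch of the guard in plain Python. It does not use cuML itself: `DeviceType`, `FakeHandle`, `predict_unguarded`, and `predict_guarded` are hypothetical stand-ins invented for this example, and the real code in fil.pyx may look quite different. The only point is that the synchronize call is skipped when the current device type is CPU.

```python
from enum import Enum, auto


class DeviceType(Enum):
    """Hypothetical stand-in for the current device type (not cuML's enum)."""
    CPU = auto()
    GPU = auto()


class FakeHandle:
    """Hypothetical stand-in for a handle whose sync fails without a GPU."""

    def __init__(self, gpu_available):
        self.gpu_available = gpu_available
        self.sync_calls = 0

    def sync(self):
        # Mimics the failure mode reported above: synchronizing a CUDA
        # stream when no device is visible raises an error.
        self.sync_calls += 1
        if not self.gpu_available:
            raise RuntimeError("CUDA error encountered")


def predict_unguarded(handle, device_type):
    """Models the current behavior: always synchronize after predicting."""
    handle.sync()
    return "ok"


def predict_guarded(handle, device_type):
    """Models the proposed fix: synchronize only for GPU execution."""
    if device_type is DeviceType.GPU:
        handle.sync()
    return "ok"
```

With `FakeHandle(gpu_available=False)`, `predict_unguarded(handle, DeviceType.CPU)` raises a `RuntimeError`, while `predict_guarded(handle, DeviceType.CPU)` returns without ever touching the handle. GPU execution still synchronizes as before.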