intel / scikit-learn-intelex

Intel(R) Extension for Scikit-learn is a seamless way to speed up your Scikit-learn application
https://intel.github.io/scikit-learn-intelex/
Apache License 2.0

Bug/unclear behavior in _assert_all_finite #2165

Open AdamosX opened 2 days ago

AdamosX commented 2 days ago

Describe the bug

When running RandomForestClassifier I encountered the following text in the logs: sklearn.utils.validation._assert_all_finite: patching failed with cause - X dtype is not float32 or float64. This happens only when the number of samples is sufficiently large. I don't understand how this relates to the optimizations: are they actually enabled or not?

The function _assert_all_finite in daal4py.sklearn.utils.validation contains the following code:

...
    # Inputs with fewer than 32768 elements are handed straight to stock
    # scikit-learn; the accelerated path is only attempted for larger arrays.
    if hasattr(X, "size"):
        if X.size < 32768:
            if sklearn_check_version("1.1"):
                _sklearn_assert_all_finite(
                    X,
                    allow_nan=allow_nan,
                    msg_dtype=msg_dtype,
                    estimator_name=estimator_name,
                    input_name=input_name,
                )
            else:
                _sklearn_assert_all_finite(X, allow_nan=allow_nan, msg_dtype=msg_dtype)
            return
...

To Reproduce

Run the following snippet:

from sklearnex import patch_sklearn
patch_sklearn()
import logging
logging.basicConfig(level=logging.DEBUG)  # basicConfig attaches a handler; setLevel alone won't print DEBUG records

from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
X, y = make_classification(
    n_samples=33000, n_features=4, n_informative=2, n_redundant=0,
    random_state=0, shuffle=False,
)
clf = RandomForestClassifier(max_depth=3, n_jobs=-1, n_estimators=100)
clf.fit(X, y)

Set n_samples=100 or n_samples=33000 and compare the log output.
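
Why only the larger setting logs the message follows from the size check in the excerpt above (a minimal sketch; it assumes the dtype check sits on the accelerated path, past the early return shown):

import numpy as np

# n_samples=100: X has 100 * 4 = 400 elements, below the 32768 threshold,
# so the patched _assert_all_finite returns via stock scikit-learn before
# any dtype check is attempted.
print(np.zeros((100, 4)).size < 32768)    # True

# n_samples=33000: X has 132000 elements (and y has 33000), above the
# threshold, so the accelerated path is tried and can fail on dtype.
print(np.zeros((33000, 4)).size < 32768)  # False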

Expected behavior

I haven't managed to work out whether this check really disables the optimizations, whether the logging is wrong, or whether I misunderstood its meaning. EDIT: I now think it only disables the optimization for _assert_all_finite; the random forest classifier itself works as expected.

Output/Screenshots

For large n_samples the following appears in the logs:

INFO:sklearnex: sklearn.utils.validation._assert_all_finite: running accelerated version on CPU
DEBUG:sklearnex: sklearn.utils.validation._assert_all_finite: debugging for the patch is enabled to track the usage of Intel® oneAPI Data Analytics Library (oneDAL)
DEBUG:sklearnex: sklearn.utils.validation._assert_all_finite: patching failed with cause - X dtype is not float32 or float64.
INFO:sklearnex: sklearn.utils.validation._assert_all_finite: fallback to original Scikit-learn
INFO:sklearnex: sklearn.ensemble.RandomForestClassifier.fit: running accelerated version on CPU
DEBUG:sklearnex: sklearn.utils.validation._assert_all_finite: debugging for the patch is enabled to track the usage of Intel® oneAPI Data Analytics Library (oneDAL)
DEBUG:sklearnex: sklearn.utils.validation._assert_all_finite: patching failed with cause - X dtype is not float32 or float64.
INFO:sklearnex: sklearn.utils.validation._assert_all_finite: fallback to original Scikit-learn
INFO:sklearnex: sklearn.utils.validation._assert_all_finite: running accelerated version on CPU
DEBUG:sklearnex: sklearn.utils.validation._assert_all_finite: debugging for the patch is enabled to track the usage of Intel® oneAPI Data Analytics Library (oneDAL)
DEBUG:sklearnex: sklearn.utils.validation._assert_all_finite: patching failed with cause - X dtype is not float32 or float64.
INFO:sklearnex: sklearn.utils.validation._assert_all_finite: fallback to original Scikit-learn

Whereas for small n_samples:

INFO:sklearnex: sklearn.ensemble.RandomForestClassifier.fit: running accelerated version on CPU

Environment:

icfaust commented 2 days ago

Hello @AdamosX, the finite-checker acceleration only works with input data that is float32 or float64, so it falls back to the stock scikit-learn finite checker. The y values that come out of make_classification are integers, which are finite by definition (and won't be finite-checked at all), so there is no difference in the overall speed. In this case, the log was a bit of a red herring.
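
A quick way to see this (a minimal sketch; the exact integer dtype of the labels can be platform-dependent):

import numpy as np
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=33000, n_features=4, n_informative=2,
                           n_redundant=0, random_state=0, shuffle=False)
print(X.dtype, y.dtype)  # float64 int64 -- the integer labels trigger the fallback

# Casting the labels to a supported dtype silences the message; as noted
# above, it makes no difference to the overall speed.
y_float = y.astype(np.float64)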

icfaust commented 2 days ago

I will try to correct the log messages: they are hard-coded to say X, which is a misnomer; they should use input_name instead.
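
Roughly what that could look like (a hypothetical sketch for illustration only; the helper name and signature are invented, not the actual patch):

# Hypothetical helper: report the array that actually failed the dtype check
# instead of the hard-coded "X".
def _fallback_message(input_name):
    name = input_name if input_name else "X"  # default to "X" only when unnamed
    return f"patching failed with cause - {name} dtype is not float32 or float64."

print(_fallback_message("y"))
# patching failed with cause - y dtype is not float32 or float64.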