jeremiedbb opened this issue 5 months ago
This is indeed interesting and might also be of interest for other people in the ecosystem, e.g. @rgommers and numpy/scipy developers interested in multithreading.
Wow, this is amazing! From looking through https://github.com/OpenMathLib/OpenBLAS/pull/4577, I think we'll need to write and register a callback that hooks up scikit-learn's vendored OpenMP runtime with OpenBLAS.
If this backend-specific callback code is useful for other projects, is there a way to share it through threadpoolctl? Concretely, something like threadpoolctl.register_openblas_backend("openmp"). Although this does increase the scope of threadpoolctl, the feature feels related.
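To make the callback idea concrete, here is a toy Python model of the contract. It is not the C interface added in that OpenBLAS PR (check the OpenBLAS documentation for its actual signature). Instead of starting its own threads, the library passes a job function, a job count and the shared job data to a callback registered by the user, and the callback runs the jobs on a pool the user already owns. ToyBlas, set_threads_callback and make_pool_callback are hypothetical names, and a concurrent.futures.ThreadPoolExecutor stands in for scikit-learn's OpenMP runtime.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Optional, Sequence

# One unit of work: (job index, shared job data) -> None.
DoJob = Callable[[int, Sequence[Any]], None]
# A user-supplied threading backend: runs `numjobs` calls of `dojob`.
ThreadsCallback = Callable[[DoJob, int, Sequence[Any]], None]


class ToyBlas:
    """Hypothetical stand-in for a BLAS library with a pluggable threading backend."""

    def __init__(self) -> None:
        self._callback: Optional[ThreadsCallback] = None

    def set_threads_callback(self, callback: Optional[ThreadsCallback]) -> None:
        # Register (or clear, with None) the backend used to run parallel jobs.
        self._callback = callback

    def parallel_sum_of_squares(self, chunks: list[list[float]]) -> float:
        partial = [0.0] * len(chunks)

        def dojob(i: int, data: Sequence[Any]) -> None:
            # Each job writes only its own slot, so no locking is needed.
            partial[i] = sum(x * x for x in data[i])

        if self._callback is None:
            # Default path: run serially (a real library would use its own pool).
            for i in range(len(chunks)):
                dojob(i, chunks)
        else:
            # Hand every job to the user's backend instead of spawning threads.
            self._callback(dojob, len(chunks), chunks)
        return sum(partial)


def make_pool_callback(pool: ThreadPoolExecutor) -> ThreadsCallback:
    """Build a callback that runs all jobs on an existing, shared pool."""

    def callback(dojob: DoJob, numjobs: int, jobdata: Sequence[Any]) -> None:
        futures = [pool.submit(dojob, i, jobdata) for i in range(numjobs)]
        for future in futures:
            future.result()  # wait for completion and re-raise any job error

    return callback


if __name__ == "__main__":
    with ThreadPoolExecutor(max_workers=4) as shared_pool:
        blas = ToyBlas()
        blas.set_threads_callback(make_pool_callback(shared_pool))
        print(blas.parallel_sum_of_squares([[1.0, 2.0], [3.0], [4.0]]))  # prints 30.0
```

The library still decides how the work is split, while the caller decides which threads execute it; in the real setting, a single OpenMP pool would then serve both the BLAS calls and the prange loops.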
I am not sure we can do that as part of threadpoolctl itself. More precisely, if we do it as part of threadpoolctl, then OpenBLAS will use the OpenMP runtime linked against a native extension shipped with threadpoolctl, but we would have no guarantee that this is the same runtime as the one linked to sklearn's Cython native extensions with prange loops.
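Whether two different OpenMP runtimes are loaded in a process can be checked at run time. A minimal sketch, assuming threadpoolctl's threadpool_info(), which returns one dict per detected native library with keys such as "user_api" and "filepath": the helper loaded_openmp_runtimes (a hypothetical name) returns the distinct OpenMP library files. More than one entry means that a runtime registered with OpenBLAS could differ from the one used by scikit-learn's prange loops.

```python
import os
from typing import Any


def loaded_openmp_runtimes(info: list[dict[str, Any]]) -> list[str]:
    """Return the distinct OpenMP runtime files listed in a threadpool_info() result."""
    paths = {
        # realpath collapses symlinks so the same file is not counted twice
        os.path.realpath(entry["filepath"])
        for entry in info
        if entry.get("user_api") == "openmp"
    }
    return sorted(paths)


if __name__ == "__main__":
    import sklearn.cluster  # noqa: F401  (KMeans' Cython helpers are built with OpenMP)
    from threadpoolctl import threadpool_info

    runtimes = loaded_openmp_runtimes(threadpool_info())
    if len(runtimes) > 1:
        print("Several OpenMP runtimes are loaded:", *runtimes, sep="\n  ")
    else:
        print("OpenMP runtimes:", runtimes)
```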
> if we do it as part of threadpoolctl, then openblas will use the openmp runtime linked against a native extension shipped with threadpoolctl
I do not think threadpoolctl should ship any runtimes.
I was thinking of using threadpoolctl to connect the OpenMP runtime loaded by sklearn with OpenBLAS. Concretely, scikit-learn would be responsible for calling threadpoolctl.register_openblas_backend("openmp"). If we want to be more verbose:
PATH_TO_LIBOMP = ".../lib/python3.12/site-packages/sklearn/.dylibs/libomp.dylib"
threadpoolctl.register_openblas_backend(PATH_TO_LIBOMP)
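Rather than hard-coding PATH_TO_LIBOMP, the path could be looked up at run time. A minimal sketch, assuming threadpoolctl's threadpool_info() and that the vendored runtime lives under a known directory: find_vendored_openmp (a hypothetical name) returns the first OpenMP library whose real path is inside package_dir. In the macOS layout above, that directory is the sklearn package itself; Linux wheels repaired with auditwheel usually place vendored libraries in a sibling scikit_learn.libs directory instead, which is why the directory is a parameter. The sketch only finds the path; register_openblas_backend is still the proposal from this thread and is not called.

```python
import os
from typing import Any, Optional


def find_vendored_openmp(info: list[dict[str, Any]], package_dir: str) -> Optional[str]:
    """Return the OpenMP runtime from `info` whose file lives under `package_dir`, if any."""
    package_dir = os.path.realpath(package_dir)
    for entry in info:
        if entry.get("user_api") != "openmp":
            continue
        path = os.path.realpath(entry["filepath"])
        try:
            # The library is inside package_dir iff their common path is package_dir.
            if os.path.commonpath([path, package_dir]) == package_dir:
                return path
        except ValueError:
            # commonpath raises for paths on different drives (Windows).
            pass
    return None


if __name__ == "__main__":
    import sklearn
    import sklearn.cluster  # noqa: F401  (imports OpenMP-enabled Cython modules)
    from threadpoolctl import threadpool_info

    sklearn_dir = os.path.dirname(sklearn.__file__)
    PATH_TO_LIBOMP = find_vendored_openmp(threadpool_info(), sklearn_dir)
    # None if the runtime lives outside the package directory (e.g. Linux wheels)
    print(PATH_TO_LIBOMP)
```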
OpenBLAS v0.3.28 will have a new feature allowing OpenBLAS to use the threadpool chosen by the user (see https://github.com/OpenMathLib/OpenBLAS/pull/4577).
This is very interesting because it would solve a performance issue that occurs when BLAS calls and OpenMP (prange) calls follow each other in quick succession. The issue happens when OpenBLAS and OpenMP don't share the same threadpool: both threadpools actively wait (spin) while idle, so the idle threads of one pool keep occupying the CPU cores that the other pool needs (see https://github.com/OpenMathLib/OpenBLAS/issues/3187 for details). This is currently the case because the numpy and scipy wheels are built against OpenBLAS with the pthreads threading layer.
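Whether a given environment is in this situation can be checked from Python. A minimal sketch, assuming threadpoolctl's threadpool_info() reports "internal_api", "version" and "threading_layer" for each OpenBLAS it finds (recent threadpoolctl releases do): describe_openblas lists each loaded OpenBLAS, its threading layer, and whether its version is at least 0.3.28, the release this issue expects to ship the callback feature. describe_openblas and parse_version are hypothetical helper names, and a recent enough version only means the feature should exist, not that any callback has been registered.

```python
import re
from typing import Any, Optional

# First OpenBLAS release expected to include the callback feature, per this issue.
CALLBACK_MIN_VERSION = (0, 3, 28)


def parse_version(version: Optional[str]) -> Optional[tuple[int, ...]]:
    """Parse the leading 'X.Y.Z' of an OpenBLAS version string such as '0.3.27.dev'."""
    if not version:
        return None
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    return tuple(int(g) for g in match.groups()) if match else None


def describe_openblas(info: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Summarize each OpenBLAS library found in a threadpool_info() result."""
    summary = []
    for entry in info:
        if entry.get("internal_api") != "openblas":
            continue
        version = parse_version(entry.get("version"))
        summary.append({
            "filepath": entry.get("filepath"),
            # "pthreads" for the numpy/scipy wheels discussed above
            "threading_layer": entry.get("threading_layer"),
            "has_callback_api": version is not None and version >= CALLBACK_MIN_VERSION,
        })
    return summary


if __name__ == "__main__":
    import numpy  # noqa: F401  (loads the OpenBLAS shipped with the numpy wheel)
    from threadpoolctl import threadpool_info

    for lib in describe_openblas(threadpool_info()):
        print(lib)
```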
This issue currently impacts several estimators and functions, such as KMeans (https://github.com/scikit-learn/scikit-learn/issues/20642), NMF (https://github.com/scikit-learn/scikit-learn/pull/16439) and pairwise_distances (https://github.com/scikit-learn/scikit-learn/issues/26097), among others.
Being able to configure OpenBLAS to use our OpenMP threadpool would allow us to get rid of this issue, even if numpy and scipy keep building their wheels against OpenBLAS with pthreads (which is very likely).
I'm not sure yet whether or how https://github.com/OpenMathLib/OpenBLAS/pull/4577 would make this possible, so I'm opening this issue to track progress on this subject.