scikit-learn / scikit-learn

scikit-learn: machine learning in Python
https://scikit-learn.org
BSD 3-Clause "New" or "Revised" License

Use SYRK instead of GEMM in pairwise distance #27894

Closed: darshanp4 closed 9 months ago

darshanp4 commented 10 months ago

Discussed in https://github.com/scikit-learn/scikit-learn/discussions/27877

Originally posted by **darshanp4** November 30, 2023

Hello, I was looking at the DBSCAN algorithm, where most of the time is spent computing pairwise distances. That computation includes the term $-2 X X^T$, for which sklearn currently uses the BLAS `_gemm` routine, and this consumes much of the time. Could we use SYRK instead? It is also a level 3 BLAS routine (https://pyclblas.readthedocs.io/en/latest/SYRK.html) and is more optimized for this type of operation. Has anyone seen or tried a similar approach? Thank you.

glemaitre commented 10 months ago

Pinging people that might know the answer straight away: @jjerphan @ogrisel @fcharras @Micky774?

glemaitre commented 10 months ago

I was looking at the following, which could be of interest as well: https://www.osti.gov/servlets/purl/1557469

It seems that SYRK should skip some of the multiplications that GEMM performs.

jjerphan commented 10 months ago

SYRK must in theory be more appropriate than GEMM in this case (i.e. when $A = B$), but only comparing implementations will tell us.
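
A minimal sketch of what such a comparison could start from, assuming SciPy's `scipy.linalg.blas` wrappers `dgemm` and `dsyrk` (scikit-learn itself calls BLAS through Cython, so this is only an illustration, and the array sizes are arbitrary). GEMM fills every entry of $-2 X X^T$, while SYRK fills a single triangle that then has to be mirrored if the full matrix is needed:

```python
import numpy as np
from scipy.linalg.blas import dgemm, dsyrk

rng = np.random.default_rng(0)
X = rng.standard_normal((500, 20))  # arbitrary illustrative size

# GEMM: computes all n * n entries of -2 * X @ X.T.
G_gemm = dgemm(alpha=-2.0, a=X, b=X, trans_b=True)

# SYRK: computes only one triangle (the upper one by default).
G_syrk = dsyrk(alpha=-2.0, a=X)

# Mirror the upper triangle to recover the full symmetric matrix.
upper = np.triu(G_syrk)
G_full = upper + np.triu(upper, 1).T
```

Timing both calls (for example with `timeit`) on the target BLAS build would show whether the skipped work translates into an actual speed-up.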

darshanp4 commented 10 months ago

@glemaitre thanks for adding more folks. I will also take a look at the paper: https://www.osti.gov/servlets/purl/1557469

@jjerphan yes, in theory it is more appropriate, but I don't have a comparison yet. I will try adding SYRK and see. Do you have any reference on how we can add BLAS functions?

jjerphan commented 10 months ago

You can easily add them from SciPy.

See:

https://github.com/scikit-learn/scikit-learn/blob/2d9f4c442030450cd6246765df2a546e6440b8ac/sklearn/utils/_cython_blas.pxd#L38

https://github.com/scikit-learn/scikit-learn/blob/2d9f4c442030450cd6246765df2a546e6440b8ac/sklearn/utils/_cython_blas.pyx#L184-L186

darshanp4 commented 10 months ago

thank you!

ogrisel commented 10 months ago

Looking forward to a PR with some quick benchmarks.

lorentzenchr commented 10 months ago

For information, numpy uses syrk under the hood if it detects X.T @ X.

jjerphan commented 10 months ago

Also, the optimization will work for full matrices; for patterns relying on chunks (à la PairwiseDistancesReduction), we can only benefit from SYRK on the diagonal chunks.

darshanp4 commented 9 months ago

@jjerphan so for PairwiseDistancesReduction, which already uses chunks of size 256, how can we identify the diagonal chunks? With `X_start == Y_start and X_end == Y_end`? Any thoughts on how we can make it work for full matrices?

darshanp4 commented 9 months ago

Also, can you help me understand why scikit-learn relies on chunks, given that the OpenBLAS backend already does this as well?

jjerphan commented 9 months ago

`X is Y and X_start == Y_start and X_end == Y_end` should be sufficient to identify when to use SYRK, but this might add too much complexity, and it might cause a regression due to branching and branch misprediction for every n_chunks + 1 iterations here:

https://github.com/scikit-learn/scikit-learn/blob/3b06962d280f8776e3e94c7aed862081a82a9cc6/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp#L445-L446
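
To make the condition above concrete, here is a rough pure-Python sketch. It is not the Cython code linked above: `middle_term_chunked` is a hypothetical helper, and it assumes SciPy's `dgemm` and `dsyrk` wrappers. It dispatches to SYRK only on diagonal chunks and counts how often that branch is taken:

```python
import numpy as np
from scipy.linalg.blas import dgemm, dsyrk


def middle_term_chunked(X, Y, chunk_size=256):
    """Hypothetical helper: compute -2 * X @ Y.T chunk by chunk."""
    same_data = X is Y
    out = np.empty((X.shape[0], Y.shape[0]))
    n_syrk_calls = 0
    for X_start in range(0, X.shape[0], chunk_size):
        X_end = min(X_start + chunk_size, X.shape[0])
        for Y_start in range(0, Y.shape[0], chunk_size):
            Y_end = min(Y_start + chunk_size, Y.shape[0])
            if same_data and X_start == Y_start and X_end == Y_end:
                # Diagonal chunk: SYRK fills one triangle, which we mirror.
                upper = np.triu(dsyrk(alpha=-2.0, a=X[X_start:X_end]))
                block = upper + np.triu(upper, 1).T
                n_syrk_calls += 1
            else:
                # Off-diagonal chunk: a general product is required.
                block = dgemm(
                    alpha=-2.0, a=X[X_start:X_end], b=Y[Y_start:Y_end], trans_b=True
                )
            out[X_start:X_end, Y_start:Y_end] = block
    return out, n_syrk_calls


rng = np.random.default_rng(0)
X = rng.standard_normal((1000, 10))  # 4 chunks of at most 256 rows
result, n_syrk_calls = middle_term_chunked(X, X)
```

With 4 chunks per side, only 4 of the 16 chunk pairs take the SYRK branch, and the share shrinks as the number of chunks grows.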

You can try and perform benchmarks with https://github.com/scikit-learn/pairwise-distances-reductions-asv-suite, but a priori using SYRK in PairwiseDistancesReduction would not help significantly.

Chunks are used for PairwiseDistancesReduction as explained in the private submodule documentation, here:

https://github.com/scikit-learn/scikit-learn/blob/3b06962d280f8776e3e94c7aed862081a82a9cc6/sklearn/metrics/_pairwise_distances_reduction/__init__.py#L8-L29

Let us know if something is unclear.

darshanp4 commented 9 months ago

Chunks keep the intermediate data structures in the CPU cache, so the cache size matters for performance, and it differs for every ISA. But here the chunk size is hardcoded to 256, whereas OpenBLAS handles this per ISA. So shouldn't we leave it to OpenBLAS to handle this in a more optimized way?

lorentzenchr commented 9 months ago

The point is that the pairwise distances perform reductions (aggregation functions) over a chunk of distance values. This chunk size was empirically tested, and 256 was chosen as a good middle ground. This is quite different from BLAS functions: matrix-matrix operations, for example, are just multiplications and additions, nothing else. Therefore, the cache is even more important there and, I guess, more engineering hours have gone into optimizing it than into any other algorithm!

Back to this issue: as SYRK would only apply to the diagonal chunks, it is a second-order effect. I'm -1 on it considering the trade-offs with code complexity.
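
As a rough illustration of a reduction over chunks of distance values, the sketch below computes a nearest-neighbor argmin chunk by chunk, so that only one `chunk_size` by `chunk_size` block of distances exists at a time. `chunked_argmin` is a hypothetical NumPy helper, not scikit-learn's implementation, and it assumes squared Euclidean distances expanded as $\|x\|^2 - 2 x \cdot y + \|y\|^2$:

```python
import numpy as np


def chunked_argmin(X, Y, chunk_size=256):
    """Hypothetical helper: index of the nearest row of Y for each row of X."""
    Y_sq = np.einsum("ij,ij->i", Y, Y)  # squared norms of the rows of Y
    best_dist = np.full(X.shape[0], np.inf)
    best_idx = np.zeros(X.shape[0], dtype=np.intp)
    for xs in range(0, X.shape[0], chunk_size):
        Xc = X[xs:xs + chunk_size]
        X_sq = np.einsum("ij,ij->i", Xc, Xc)
        dist_view = best_dist[xs:xs + chunk_size]  # views into the running results
        idx_view = best_idx[xs:xs + chunk_size]
        for ys in range(0, Y.shape[0], chunk_size):
            Yc = Y[ys:ys + chunk_size]
            # One small block of squared distances; the middle term is the GEMM part.
            d = X_sq[:, None] - 2.0 * (Xc @ Yc.T) + Y_sq[None, ys:ys + chunk_size]
            # Reduce the block immediately instead of storing it.
            local_idx = d.argmin(axis=1)
            local_min = d[np.arange(d.shape[0]), local_idx]
            better = local_min < dist_view
            dist_view[better] = local_min[better]
            idx_view[better] = local_idx[better] + ys
    return best_idx


rng = np.random.default_rng(0)
X = rng.standard_normal((300, 5))
Y = rng.standard_normal((700, 5))
nearest = chunked_argmin(X, Y)
```

The full 300 by 700 distance matrix is never materialized here, which is the property the chunk size is tuned around.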