google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Column-Pivoted QR Decomposition #12897

Open HanGuo97 opened 1 year ago

HanGuo97 commented 1 year ago


Hi, I received the following error when executing jax.scipy.linalg.qr(A, pivoting=True):

NotImplementedError: The pivoting=True case of qr is not implemented.

I'm curious: is there a plan to implement column-pivoted QR decomposition in the near future? Thanks in advance for the help!

jakevdp commented 1 year ago

Hi, thanks for the question. Currently JAX implements QR decomposition via geqrf, which doesn't support pivoting. For that, we would need geqp3 (which is what SciPy uses for the pivoting=True case).
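To make the difference concrete, here is a small sketch using SciPy (not JAX) on the CPU. It assumes NumPy and SciPy are installed. Without pivoting, SciPy's LAPACK path is geqrf and the result satisfies A == Q @ R; with pivoting=True it uses geqp3 and also returns a column permutation P such that A[:, P] == Q @ R, with the magnitudes of R's diagonal in non-increasing order.

```python
import numpy as np
import scipy.linalg

rng = np.random.default_rng(0)
a = rng.standard_normal((5, 4))

# pivoting=False: plain Householder QR (LAPACK ?geqrf), a == q @ r.
q, r = scipy.linalg.qr(a)

# pivoting=True: column-pivoted QR (LAPACK ?geqp3), a[:, p] == q_p @ r_p.
q_p, r_p, p = scipy.linalg.qr(a, pivoting=True)

print(np.allclose(a, q @ r))            # True
print(np.allclose(a[:, p], q_p @ r_p))  # True
print(np.abs(np.diag(r_p)))             # magnitudes are non-increasing
```

The permutation is what makes pivoted QR useful for rank-revealing tasks: the largest remaining column is moved forward at each step, so small trailing diagonal entries of R signal near-dependence.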

Adding support for new LAPACK routines is somewhat involved, because JAX needs to be able to target appropriate implementations on all relevant backends (CPU, GPU, etc.). In this case, it appears that cuBLAS (where JAX gets its geqrf implementation for NVIDIA GPUs) does not currently have any implementation of geqp3 (see https://docs.nvidia.com/cuda/cublas/index.html).
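Since no vendor geqp3 is available on every backend, one backend-agnostic workaround is to write column-pivoted Householder QR directly in jax.numpy, which XLA can then compile for CPU, GPU, or TPU. This is a minimal sketch, not a JAX API: `qr_pivoted` is a hypothetical helper, it recomputes the trailing column norms at every step instead of using the norm downdating that LAPACK's geqp3 performs, and its Python loop is unrolled under jit, so it is only reasonable for small matrices.

```python
import jax
import jax.numpy as jnp


@jax.jit
def qr_pivoted(a):
    """Hypothetical column-pivoted Householder QR: a[:, perm] == q @ r."""
    m, n = a.shape
    r = a
    q = jnp.eye(m, dtype=a.dtype)
    perm = jnp.arange(n)
    cols = jnp.arange(n)
    for j in range(min(m, n)):  # static loop, unrolled by jit
        # Pick the not-yet-processed column with the largest norm in rows j:.
        norms = jnp.sum(r[j:, :] ** 2, axis=0)
        norms = jnp.where(cols >= j, norms, -1.0)
        p = jnp.argmax(norms)
        # Swap columns j and p in r and record the swap in perm.
        swap = cols.at[j].set(p).at[p].set(j)
        r = r[:, swap]
        perm = perm[swap]
        # Householder reflector H = I - beta * v v^T mapping r[j:, j] to alpha * e1.
        x = r[j:, j]
        normx = jnp.linalg.norm(x)
        alpha = jnp.where(x[0] >= 0, -normx, normx)
        v = x.at[0].add(-alpha)
        vv = v @ v
        beta = jnp.where(vv > 0, 2.0 / vv, 0.0)  # beta = 0 leaves a zero column alone
        # Apply H from the left to r, and accumulate q <- q @ H.
        r = r.at[j:, :].add(-beta * jnp.outer(v, v @ r[j:, :]))
        q = q.at[:, j:].add(-beta * jnp.outer(q[:, j:] @ v, v))
    # Entries below the diagonal are zero up to rounding; clear them.
    return q, jnp.triu(r), perm
```

Usage: `q, r, perm = qr_pivoted(a)` for a 2-D floating-point array `a`; then `a[:, perm]` should match `q @ r` up to float tolerance. Rebuilding the whole permutation with a gather keeps every index operation traceable even though the pivot `p` is only known at run time.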

With that in mind, I think it's unlikely that pivoting=True will be supported in the near future, but I'll leave this issue open to track the request.

HanGuo97 commented 1 year ago

Understood; thanks for the explanation!

hawkinsp commented 1 year ago

What hardware platform?

If you need it only on CPU, it wouldn't be a huge lift to add it. Even on GPU we could possibly support it, since MAGMA has magma_?geqp3. I'm not sure we'll have time to do it, but maybe. Contributions are welcome.
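Until native support exists, a CPU stopgap that needs no jaxlib changes is to call SciPy's pivoted QR (which dispatches to LAPACK ?geqp3) from a host callback via `jax.pure_callback`. This is a minimal sketch assuming SciPy is installed; `qr_pivoted_callback` is a hypothetical name. The computation always runs on the host CPU (on a GPU backend the data is copied to the host and back), and `pure_callback` has no built-in differentiation rule, so this does not replace the native implementation discussed above.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg


def _host_qr_pivoted(a):
    # Runs on the host; SciPy calls LAPACK ?geqp3 for pivoting=True.
    q, r, p = scipy.linalg.qr(np.asarray(a), pivoting=True)
    return q.astype(a.dtype), r.astype(a.dtype), p.astype(np.int32)


@jax.jit
def qr_pivoted_callback(a):
    m, n = a.shape
    # Output shapes/dtypes must be declared up front for the callback.
    out_types = (
        jax.ShapeDtypeStruct((m, m), a.dtype),  # q (mode="full")
        jax.ShapeDtypeStruct((m, n), a.dtype),  # r
        jax.ShapeDtypeStruct((n,), jnp.int32),  # column permutation
    )
    return jax.pure_callback(_host_qr_pivoted, out_types, a)
```

The returned triple follows SciPy's convention, so `a[:, p]` should match `q @ r` up to float tolerance.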

HanGuo97 commented 1 year ago

Personally, it'd be great to have it on GPU (and CPU if that's easy, but GPU alone would be fine). Would this be easy?