HanGuo97 opened this issue 1 year ago.
Hi, thanks for the question. Currently JAX implements QR decomposition via geqrf, which doesn't support pivoting. For that, we would need geqp3 (which is what SciPy uses for the pivoting=True case).
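For reference, this is what the pivoting=True contract looks like in SciPy, which is what a JAX implementation would need to match. The sketch below uses only NumPy and SciPy; the matrix size and random seed are arbitrary choices for illustration. With pivoting enabled, SciPy returns an extra integer permutation P, the factorization is of the column-permuted matrix, and the magnitudes on the diagonal of R come out in non-increasing order.

```python
import numpy as np
import scipy.linalg

# A small random matrix; any real 2-D array works here.
rng = np.random.default_rng(0)
A = rng.standard_normal((5, 4))

# pivoting=True makes SciPy use LAPACK geqp3 and return a permutation P
# (an integer index array) in addition to Q and R.
Q, R, P = scipy.linalg.qr(A, pivoting=True)

# The factorization is of the column-permuted matrix: A[:, P] == Q @ R.
print(np.allclose(A[:, P], Q @ R))

# Greedy column pivoting orders |diag(R)| from largest to smallest.
d = np.abs(np.diag(R))
print(np.all(d[:-1] >= d[1:]))
```

This ordering of diag(R) is what makes pivoted QR useful for numerical rank estimation, which plain geqrf does not provide.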
Adding support for new LAPACK routines is somewhat involved, because JAX needs to be able to target appropriate implementations on all relevant backends (e.g. CPU, GPU, etc.). In this case, it appears that cuBLAS (where JAX gets its geqrf implementation for NVIDIA GPUs) does not currently have any implementation of geqp3 (see https://docs.nvidia.com/cuda/cublas/index.html).
With that in mind, I think it's unlikely that pivoting=True will be supported in the near future, but I'll leave this issue open to track the request.
Understood; thanks for the explanation!
What hardware platform?
If you need it only on CPU, it wouldn't be a huge lift to add it. Even on GPU we could probably support it, since MAGMA has magma_?geqp3. I'm not sure we'll have time to do it, but maybe. Contributions are welcome.
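As a stopgap until a native geqp3 lowering exists, one option is to let SciPy compute the pivoted factorization on the host and call it from JAX through jax.pure_callback. This is a workaround sketch, not how JAX would implement the feature natively. It assumes real floating-point input and economic-mode output, and the helper name qr_pivoted_via_callback is hypothetical.

```python
import numpy as np
import scipy.linalg
import jax
import jax.numpy as jnp


def qr_pivoted_via_callback(a):
    # Hypothetical helper: economic-mode column-pivoted QR of a real 2-D
    # array, computed by SciPy (LAPACK geqp3) on the host.
    # Returns (q, r, p) such that a[:, p] == q @ r.
    m, n = a.shape
    k = min(m, n)

    def host_qr(x):
        # Executes outside of JAX tracing, on NumPy data.
        x = np.asarray(x)
        q, r, p = scipy.linalg.qr(x, mode="economic", pivoting=True)
        # Cast so the outputs match the declared shapes/dtypes exactly.
        return q.astype(x.dtype), r.astype(x.dtype), p.astype(np.int32)

    # pure_callback needs the output shapes and dtypes declared up front.
    out_types = (
        jax.ShapeDtypeStruct((m, k), a.dtype),
        jax.ShapeDtypeStruct((k, n), a.dtype),
        jax.ShapeDtypeStruct((n,), jnp.int32),
    )
    return jax.pure_callback(host_qr, out_types, a)


A = jnp.asarray(np.random.default_rng(0).standard_normal((6, 4)), dtype=jnp.float32)
q, r, p = jax.jit(qr_pivoted_via_callback)(A)
print(jnp.allclose(A[:, p], q @ r, atol=1e-5))
```

The callback has no differentiation rule, and when the computation runs on GPU every call copies the matrix to the host and back, so this is only reasonable for small or infrequent factorizations.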
Personally, it'd be great to have it on GPU (and CPU if that's easy, but GPU alone would be fine). Would this be easy?
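For GPU use without depending on a LAPACK or MAGMA routine, a column-pivoted Householder QR can also be written directly in jax.numpy. The sketch below is unblocked and applies one full reflector per step, so it will be much slower than a tuned geqp3 for large matrices; it only illustrates the same greedy rule (pick the column with the largest remaining norm). It assumes a real floating-point input, returns "full" mode Q and R like SciPy's default, and the function name qr_column_pivoted is hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def qr_column_pivoted(a):
    # Hypothetical helper: returns (q, r, perm) with a[:, perm] == q @ r.
    m, n = a.shape
    rows = jnp.arange(m)
    cols = jnp.arange(n)

    def step(k, carry):
        q, r, perm = carry
        row_mask = rows >= k
        # Norms of the trailing parts r[k:, j]; columns before k are excluded.
        trailing = jnp.where(row_mask[:, None], r, 0.0)
        norms = jnp.sqrt(jnp.sum(trailing * trailing, axis=0))
        norms = jnp.where(cols >= k, norms, -1.0)
        j = jnp.argmax(norms)
        # Swap columns k and j (a no-op when j == k).
        swap = jnp.where(cols == k, j, jnp.where(cols == j, k, cols))
        r = r[:, swap]
        perm = perm[swap]
        # Householder reflector H = I - 2 v v^T that maps r[k:, k] onto a
        # multiple of e_k; the sign choice avoids cancellation.
        x = jnp.where(row_mask, r[:, k], 0.0)
        norm_x = jnp.linalg.norm(x)
        alpha = -jnp.where(x[k] >= 0, 1.0, -1.0) * norm_x
        v = jnp.where(rows == k, x - alpha, x)
        norm_v = jnp.linalg.norm(v)
        # A zero trailing column gives v == 0, i.e. H = I (no division by zero).
        v = jnp.where(norm_v > 0, v / jnp.where(norm_v > 0, norm_v, 1.0), 0.0)
        r = r - 2.0 * jnp.outer(v, v @ r)  # r <- H r
        q = q - 2.0 * jnp.outer(q @ v, v)  # q <- q H
        return q, r, perm

    init = (jnp.eye(m, dtype=a.dtype), a, jnp.arange(n))
    q, r, perm = jax.lax.fori_loop(0, min(m, n), step, init)
    # Clear round-off left below the diagonal.
    return q, jnp.triu(r), perm


A = jnp.asarray(np.random.default_rng(0).standard_normal((6, 4)), dtype=jnp.float32)
Q, R, P = qr_column_pivoted(A)
print(jnp.allclose(A[:, P], Q @ R, atol=1e-5))
```

Because every shape is static and the loop bound is a Python integer, the whole factorization compiles as a single jitted program on CPU or GPU.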
Hi, I received the following error when executing jax.scipy.linalg.qr(A, pivoting=True). Is there a plan to implement column-pivoted QR decomposition in the near future? Thanks in advance for the help!