PhilipVinc opened this issue 2 months ago
I'm running into a similar issue with jax.numpy.linalg.svd.
Yeah, this affects all cuSOLVER APIs, so SVD and the various factorizations as well.
Great suggestion! I'm in the middle of updating all the cuSOLVER wrappers, so I plan to get this in as part of that process. I probably won't be able to land it before the next JAX release, but I'll try!
For reference, it looks like there's an open issue suggesting this for the CPU backend too: https://github.com/google/jax/issues/20904
Thanks a lot for this, @dfm! Even just a PR or branch that uses the 64-bit API would be very useful, since I'm trying to run SVD on some large matrices for a project right now.
Sure - I can prioritize SVD. Just to confirm, you're running on a GPU, @norabelrose?
Yep, that's right. Thanks!
jaxlib links against the 32-bit cuSOLVER API, which has a hard limit on workspace size. As a result, it is not possible to diagonalise matrices larger than about 26k × 26k when using np.float64.
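To illustrate why a 32-bit workspace size caps the matrix size, here is a minimal sketch of the arithmetic. The legacy cuSOLVER routines report their workspace size through an int (for example the lwork output of cusolverDnDsyevd_bufferSize), while the 64-bit API (for example cusolverDnXsyevd_bufferSize) reports it as size_t. The sketch assumes a workspace of `factor * n**2` elements. The multiplier is a hypothetical placeholder, not cuSOLVER's actual formula, which differs per routine. A factor of 3 is used below only because it lands near the ~26k figure reported above. Note that the 26k × 26k input itself (676,000,000 elements) fits comfortably in an int32. It is the larger workspace that overflows.

```python
import math

INT32_MAX = 2**31 - 1  # largest value a signed 32-bit workspace size can hold
INT64_MAX = 2**63 - 1  # largest value a signed 64-bit size can hold


def workspace_fits(n: int, factor: int, limit: int = INT32_MAX) -> bool:
    """Return True if a workspace of factor * n**2 elements fits in `limit`.

    ASSUMPTION: `factor` is a hypothetical multiplier, not cuSOLVER's real
    workspace formula (which depends on the routine, e.g. syevd vs gesvd).
    """
    return factor * n * n <= limit


def max_square_size(factor: int, limit: int = INT32_MAX) -> int:
    """Largest n for which an n x n problem's workspace still fits in `limit`."""
    # Integer square root gives a starting estimate without float rounding.
    n = math.isqrt(limit // factor)
    # Nudge the estimate so that workspace_fits(n) holds and n + 1 does not.
    while workspace_fits(n + 1, factor, limit):
        n += 1
    while not workspace_fits(n, factor, limit):
        n -= 1
    return n


if __name__ == "__main__":
    print(max_square_size(3))                    # cap with a 32-bit size
    print(workspace_fits(30000, 3))              # overflows int32
    print(workspace_fits(30000, 3, INT64_MAX))   # fine with a 64-bit size
```

With the assumed factor of 3, max_square_size(3) returns 26754, and a 30000 × 30000 problem only fits once the size is tracked as a 64-bit integer, which is the motivation for switching jaxlib to the 64-bit API.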