jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.45k stars, 2.8k forks

CuSolver: Switch to 64-bit API to allow for eigh on matrices larger than 26732x26732 #23413

Open PhilipVinc opened 2 months ago

PhilipVinc commented 2 months ago

Jaxlib links against the 32-bit cuSOLVER API, which imposes a hard limit on the workspace size. As a result, it is not possible to diagonalise matrices larger than about 26k x 26k when using np.float64.
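To see why a limit in this range is plausible, here is a back-of-the-envelope sketch. It is not jaxlib or cuSOLVER code. It assumes that the legacy API reports the workspace as a signed 32-bit element count (an `int lwork`) and that the eigh workspace grows roughly as c·n² elements. The constant `c` and the helper names are hypothetical and chosen only for illustration. They are not documented cuSOLVER values. With c = 3, the bound lands close to (but not exactly at) the 26732 in the title. The small gap would presumably come from lower-order terms this model ignores.

```python
import math

# Largest value a signed 32-bit integer (e.g. an `int lwork`) can hold.
INT32_MAX = 2**31 - 1


def workspace_elements(n: int, c: int) -> int:
    """Assumed workspace model: c * n**2 elements for an n x n eigh.

    `c` is a hypothetical constant used for illustration only.
    """
    return c * n * n


def fits_in_int32(n: int, c: int) -> bool:
    """True if the modelled workspace size is representable as int32."""
    return workspace_elements(n, c) <= INT32_MAX


def max_n_int32(c: int) -> int:
    """Largest n whose modelled workspace still fits in int32.

    For integers n and c:  c * n**2 <= INT32_MAX  <=>  n**2 <= INT32_MAX // c.
    """
    return math.isqrt(INT32_MAX // c)


if __name__ == "__main__":
    for c in (1, 2, 3):
        print(f"c={c}: largest n = {max_n_int32(c)}")
    # Under the c = 3 assumption, n = 26732 fits but n = 30000 does not.
    print(fits_in_int32(26732, 3), fits_in_int32(30000, 3))
```

The 64-bit generic API (for example `cusolverDnXsyevd_bufferSize`) reports workspace sizes as `size_t` byte counts, so this particular 32-bit ceiling does not apply there.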

norabelrose commented 2 months ago

I'm running into a similar issue with jax.linalg.svd

PhilipVinc commented 2 months ago

Yeah, this affects all the cuSOLVER APIs, so SVD and various factorizations are affected as well…

dfm commented 2 months ago

Great suggestion! I'm in the midst of updating all the cuSolver wrappers, so I'll plan on getting this in as part of that process. I probably won't be able to land this before the next JAX release, but I'll try!

For reference, it looks like there's an open issue suggesting this for the CPU backend too: https://github.com/google/jax/issues/20904

norabelrose commented 2 months ago

Thanks a lot for this, @dfm! Even just a PR or branch that uses the 64-bit API would be very useful, since I'm trying to run SVD on some large matrices for a project right now.

dfm commented 2 months ago

Sure - I can prioritize SVD. Just to confirm, you're running on a GPU, @norabelrose?

norabelrose commented 2 months ago

Yep, that's right. Thanks!