ma-gilles opened 10 months ago
Hmm. So JAX and Lineax both basically do the same thing for the LU/QR/Cholesky solvers, which is to use the JAX (and thus probably CUDA) implementation of those decompositions.
The fact that the QR solve is slow is expected I think -- IIRC there's no CUDA implementation of a batched QR decomposition, so vmap is handled by computing the decomposition for each batch element sequentially.
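A small sketch of the kind of batched QR solve in question, written in plain JAX (function names here are illustrative, not the internals of either library); when there is no batched CUDA kernel for the decomposition, the `vmap` is effectively lowered to a loop over batch elements:

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg

def qr_solve(A, b):
    # Solve A x = b via A = Q R, so x = R^{-1} (Q^T b).
    Q, R = jnp.linalg.qr(A)
    return jax.scipy.linalg.solve_triangular(R, Q.T @ b, lower=False)

# Batch of 1000 well-conditioned 10x10 systems.
key = jax.random.PRNGKey(0)
A = jnp.eye(10) + 0.01 * jax.random.normal(key, (1000, 10, 10))
b = jnp.ones((1000, 10))

# On GPU each QR decomposition may run sequentially, hence the slowness.
x = jax.vmap(qr_solve)(A, b)
print(x.shape)  # (1000, 10)
```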
I suspect the issue is probably somewhere in the underlying CUDA (cuSolver?) implementations. I think resolving this will probably need someone to go digging through things at that level, I'm afraid.
Hi Patrick,
Thanks for your answer!
I can't say I really understand how JAX/torch/cupy interact with CUDA code, but what is surprising to me is that this seems to be a bug only in JAX. Both torch/cupy seem to work, even though I would assume they use the same backend.
E.g.:
import numpy as np
import torch

n = int(1e7); m = 10
# Batch of 1e7 identity matrices; the Cholesky factor of the identity is
# the identity, so A - L should be exactly zero.
A = torch.tensor(np.repeat(np.identity(m)[None], n, axis=0))
L = torch.linalg.cholesky(A)
print(torch.linalg.norm(A - L))
Outputs:
tensor(0., dtype=torch.float64)
And the same thing for cupy, but JAX returns NaNs.
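The analogous JAX call, shown here at a much smaller batch size purely for illustration (the reported NaNs only appear at large batch sizes on GPU):

```python
import numpy as np
import jax.numpy as jnp

# Same construction as the torch snippet above, but n is reduced from the
# report's 1e7 so this runs anywhere.
n, m = 1000, 10
A = jnp.asarray(np.repeat(np.identity(m)[None], n, axis=0))
L = jnp.linalg.cholesky(A)
print(jnp.linalg.norm(A - L))  # zero on a correct build; NaN in the buggy case
```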
Oh interesting! Hmm, in that case I'm less certain of the reason. Maybe check that it's not a version issue? PyTorch and JAX tend to use different versions of the underlying NVIDIA libraries.
Thanks for the suggestion! I tried a few different CUDA versions with no change, but updating JAX seems to fix the problem, or at least it passes the few tests I have tried.
Curious! Well, I'm glad it's fixed. :) Possibly an issue with a particular version of jaxlib then, if updating the version fixed things.
Hello,
I opened a similar issue on the main JAX (https://github.com/google/jax/issues/19431) but I thought it may get more attention here.
The batched JAX linear solves seem to be bugged for large batches on GPU, even when they still fit comfortably in GPU memory. In short, if you try to solve a large batch of linear systems, the JAX LU/Cholesky solvers will sometimes return NaNs or other wrong results without throwing an error or warning. The SVD-based solve seems to work better, though it also fails once you get close to filling the full GPU memory. The QR-based solve is too slow for me to test at large batch sizes, strangely. The Lineax solvers show the same behavior, although Lineax does throw an error upon seeing NaNs.
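A rough sketch of the setup described above (identity A, all-ones b), not the original test script, and at a deliberately small batch size; the reported failures require far larger batches on GPU:

```python
import jax
import jax.numpy as jnp

# Small batch for illustration; the report uses batches large enough to
# approach GPU memory limits, where NaNs are returned silently.
n, m = 1000, 4
A = jnp.broadcast_to(jnp.eye(m), (n, m, m))
b = jnp.ones((n, m))

# LU-based batched solve; each solution should be the all-ones vector.
x = jax.vmap(jnp.linalg.solve)(A, b)
print(bool(jnp.isnan(x).any()))  # False when the solve is healthy
```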
Below is a test and its output, where solving Ax = b, with A the identity and b all ones, returns NaNs. I am curious whether someone can reproduce this behavior or has any ideas on what to do.
Thank you for making this nice library! Best, Marc
Output: