google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Improve CuSolver errors diagnostics #23410

Open PhilipVinc opened 1 week ago

PhilipVinc commented 1 week ago

TLDR:

Often when writing scientific algorithms we have to use some routines from cuSolver, like svd/eigh/qr. Those routines sometimes fail with error messages that are hard to interpret.

A common cause is insufficient memory for their workspace, but the message does not say so (even though, a priori, this could be reported). I would like this to be reported.

Also: Jax is using 32 bit CuSolver API, which might explain why larger workspaces cannot be created. Would it be possible to use a 64 bit API?

Longer story:

On some code we have we are now hitting the following CuSolver error during tracing/compilation which arises exactly during tracing of jnp.linalg.eigh of a ~16k x 16k~ 32k x 32k jnp.float64 matrix residing on an A100-80G gpu (it works with a ~12k x 12k~ 24k x 24k matrix).

Traceback (most recent call last):
  File "/lustre/fswork/projects/rech/iqu/uvm91ap/repos/netket_pro/deepnets/optimization/run.py", line 139, in <module>
    sim_time, n_parameters = opt.standard(
...
  File "/linkhome/rech/gencpt01/uvm91ap/.conda/envs/nk_gpu_mpi_amd/lib/python3.11/site-packages/jax/_src/numpy/linalg.py", line 838, in eigh
    v, w = lax_linalg.eigh(a, lower=lower, symmetrize_input=symmetrize_input)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: RuntimeError: jaxlib/gpu/solver.cc:213: operation gpusolverDnDsyevd_bufferSize( handle.get(), jobz, uplo, n, nullptr, n, nullptr, &lwork) failed: cuSolver invalid value error

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
...
  File "/linkhome/rech/gencpt01/uvm91ap/.conda/envs/nk_gpu_mpi_amd/lib/python3.11/site-packages/jaxlib/gpu_solver.py", line 321, in _syevd_hlo
    lwork, opaque = gpu_solver.build_syevd_descriptor(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: jaxlib/gpu/solver.cc:213: operation gpusolverDnDsyevd_bufferSize( handle.get(), jobz, uplo, n, nullptr, n, nullptr, &lwork) failed: cuSolver invalid value error

We are 99.9% sure this is a memory error, because reducing the size of the matrix makes the problem go away.

However, the message is not clear, and the cuSolver documentation for syevd_bufferSize suggests that the invalid value error should arise in the following cases, none of which I can relate to an OOM error.

[Screenshot: cuSolver documentation excerpt listing the cases that return the invalid value error]

Digging inside Jax, I see that the error comes from jax.lib.gpu_solver._syevd_hlo which calls into the C function solver::BuildSyevdDescriptor, which calls cuda as

      JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevd_bufferSize(
          handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
          &lwork)));

and simply returns the cusolver workspace size (and a pointer I think).

So I wondered if maybe to diagonalise my 16'000^2 matrix the workspace was really so large? So I called this function manually

>>> import jax; import jaxlib; import numpy as np
>>> # call convention is dtype, lower, n_batches, matrix linear size
>>> jaxlib.gpu_solver._cusolver.build_syevd_descriptor(np.dtype(np.float64), False, 1, 16000)
(770064577, b'\x01\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x80>\x00\x00\xc1@\xe6-')

which states that the workspace size should be about 700 MB, which does not seem so large and since XLA preallocates only 75% of memory, this should not be a problem...

I also verified that calling the same function with a much larger matrix size won't work

>>> jaxlib.gpu_solver._cusolver.build_syevd_descriptor(np.dtype(np.float64), False, 1, 27000)
RuntimeError: jaxlib/gpu/solver.cc:213: operation gpusolverDnDsyevd_bufferSize( handle.get(), jobz, uplo, n, nullptr, n, nullptr, &lwork) failed: cuSolver invalid value error

It seems the largest size I can make this work for with a fresh instance of jax is ~26500, corresponding to 2031344577, which is just short of 2GB of memory. However, jax should have left 20GB of memory available on the gpu. Why is that so?

FYI, I tried to disable preallocation and my code still does not work...

PhilipVinc commented 1 week ago

Maybe jax is using the 32 bit CuSolver API and that's what is limiting us? Maybe we should use the 64 bit API?

Looking at the cuda documentation, Dsyevd_bufferSize corresponds to the legacy 32 bit API, so I suspect jaxlib is using that one? This might explain why larger buffer sizes cannot be created... From this thread https://forums.developer.nvidia.com/t/memory-alloc-error-for-cholmod-factorization-in-cusolver-library/76561/6 I see that the 32 bit API is limited to temporary workspaces of size 2^32 single-precision, or 2^31 double precision numbers...

Nevertheless, this does not explain why the function is crashing on my particular use-case, because the matrix is smaller than that.

dfm commented 1 week ago

Thanks for the report and for digging into it so deeply!

For the request that we support the 64-bit API: great idea! We can definitely do that and I'll comment some more over on #23413.

With respect to the workspace allocation errors, I think things should be a little bit better as we update all the solvers to use the new FFI since we can be more explicit about when OOM errors occur (see here, for example). But (like you) I'm still a bit confused about where the issue is coming from in this specific case. To try to narrow down the problem, can you see if you hit this error when lowering (e.g. jax.jit(jnp.linalg.eigh).lower(a)) because that's when the descriptor is built. If it doesn't happen there, but does when you run jax.jit(jnp.linalg.eigh)(a), then we're getting the wrong traceback and it would be useful to know that.
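A minimal sketch of the check suggested above, assuming a JAX version that provides `jax.jit(...).lower(...)`, `Lowered.compile()` and `jax.block_until_ready`. The helper name `failing_stage` is hypothetical and only serves to show in which stage (lowering, compilation, or execution) an exception is raised.

```python
import jax
import jax.numpy as jnp


def failing_stage(fn, *args):
    """Return (stage, exception) for the first stage of jax.jit(fn) that fails.

    stage is one of "lower", "compile", "execute", or "ok".
    This is a diagnostic sketch, not part of JAX.
    """
    try:
        # Tracing and lowering: per the discussion above, this is where
        # the solver descriptor is built.
        lowered = jax.jit(fn).lower(*args)
    except Exception as err:
        return "lower", err
    try:
        compiled = lowered.compile()
    except Exception as err:
        return "compile", err
    try:
        # Block so that asynchronous execution errors surface here.
        jax.block_until_ready(compiled(*args))
    except Exception as err:
        return "execute", err
    return "ok", None
```

Calling `failing_stage(jnp.linalg.eigh, a)` on the problematic matrix would then report `"lower"` if the error really comes from building the descriptor, and `"execute"` if the traceback is pointing at the wrong place.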

PhilipVinc commented 1 week ago

Yes, I can reproduce by running only that

import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

N=32000
a=jnp.ones((N,N), dtype=jnp.float64)
jax.jit(jnp.linalg.eigh).lower(a)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/lustre/fswork/projects/rech/iqu/udb21rp/test_rajah/.venv/lib/python3.12/site-packages/jax/_src/numpy/linalg.py", line 838, in eigh
    v, w = lax_linalg.eigh(a, lower=lower, symmetrize_input=symmetrize_input)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: RuntimeError: jaxlib/gpu/solver.cc:213: operation gpusolverDnDsyevd_bufferSize( handle.get(), jobz, uplo, n, nullptr, n, nullptr, &lwork) failed: cuSolver invalid value error

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/lustre/fswork/projects/rech/iqu/udb21rp/test_rajah/.venv/lib/python3.12/site-packages/jaxlib/gpu_solver.py", line 308, in _syevd_hlo
    lwork, opaque = gpu_solver.build_syevd_descriptor(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: jaxlib/gpu/solver.cc:213: operation gpusolverDnDsyevd_bufferSize( handle.get(), jobz, uplo, n, nullptr, n, nullptr, &lwork) failed: cuSolver invalid value error
--------------------

Sorry for the above being a bit confusing. It was like a stream of consciousness while I was digging...

The reason I have opened #23413 is exactly because I found out that the problem is the workspace size being limited to ~ 2^30 bytes in the 32 bit workspace API.

I found some issues on other libraries around GitHub, and this seems like a known problem. See for example https://github.com/rapidsai/cuml/issues/2597 . From the CuSolver API documentation it seems the 32 bit api is also deprecated.

I'm pretty confident that upgrading the api should fix this issue.
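As a rough sanity check of the 32 bit hypothesis, the workspace sizes reported earlier in this thread can be compared against the largest value a signed 32-bit integer can hold. The sketch below assumes that the legacy bufferSize routine returns the size through a signed 32-bit `int`; the two sizes are the ones printed by build_syevd_descriptor above, and the helper name `fits_in_int32` is made up for illustration.

```python
INT32_MAX = 2**31 - 1  # largest value of a signed 32-bit integer

# Workspace sizes reported earlier in this thread, keyed by matrix size N.
# The second N is approximate ("~26500" in the text).
reported_lwork = {
    16000: 770064577,
    26500: 2031344577,
}


def fits_in_int32(value):
    """True if value can be stored in a signed 32-bit integer."""
    return 0 <= value <= INT32_MAX


for n, lwork in reported_lwork.items():
    fraction = lwork / INT32_MAX
    print(f"N={n}: lwork={lwork}, fits={fits_in_int32(lwork)}, "
          f"fraction of INT32_MAX={fraction:.3f}")
```

Both reported sizes still fit, but the larger one is already close to the limit, which is consistent with slightly larger matrices no longer being representable.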

PhilipVinc commented 1 week ago

As for improving the error message, I understand that it's not trivial. Cuda does not really give you a good reason for failing in this case, so I don't see an alternative to hardcoding an error message if the matrix size is above a certain size?

Or maybe you could catch the error and add a comment saying that probably it's because of 32 bit apis and the solution could be to either use lower precision, smaller matrix, or another algorithm.

I believe this could work, because I see no way that the 'standard' causes of cuSolver invalid value error can arise. The jax code already declares all arguments correctly.
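The "catch the error and add a comment" idea could look roughly like the following user-side sketch. It is not how jaxlib handles errors: the decorator name `with_cusolver_hint` and the hint text are hypothetical, and the only thing taken from the thread is the message fragment "cuSolver invalid value error" that appears in the tracebacks above.

```python
import functools

# Message fragment taken from the RuntimeError shown in the tracebacks above.
CUSOLVER_INVALID_VALUE = "cuSolver invalid value error"

# Hypothetical hint text, following the suggestion in this comment.
HINT = (
    "Hint: this may be caused by the workspace size limit of the 32 bit "
    "cuSolver API. Possible workarounds: lower precision, a smaller "
    "matrix, or another algorithm."
)


def with_cusolver_hint(fn):
    """Wrap fn so that cuSolver 'invalid value' errors carry an extra hint."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except RuntimeError as err:
            if CUSOLVER_INVALID_VALUE not in str(err):
                raise  # unrelated error: leave it untouched
            raise RuntimeError(f"{err}\n{HINT}") from err

    return wrapper
```

A call such as `with_cusolver_hint(jnp.linalg.eigh)(a)` would then re-raise the original error with the hint appended, while all other exceptions pass through unchanged.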

dfm commented 1 week ago

> Yes, I can reproduce by running only that

But I thought you said you were also seeing the same error when N = 16_000? Is that no longer true?

PhilipVinc commented 1 week ago

Ah. Eh. Sorry. I should update the description above.

Turns out I had complex numbers originally, so I was using a matrix twice as large as what I thought...

The problem arises around N=26732.

PhilipVinc commented 1 week ago

Plus, I think if you actually switch to the 64 bit api throwing better errors won't be needed anymore, because this should never happen...

dfm commented 1 week ago

OK, that makes more sense! Yeah, I was confused about why that error was showing up for N=16k and float64 dtype.

I'm on board for updating to the 64-bit API!