patrick-kidger / lineax

Linear solvers in JAX and Equinox. https://docs.kidger.site/lineax
Apache License 2.0

Batched JAX linear solves bugged for large batches #79

Open ma-gilles opened 10 months ago

ma-gilles commented 10 months ago

Hello,

I opened a similar issue on the main JAX repository (https://github.com/google/jax/issues/19431), but I thought it might get more attention here.

The batched JAX linear solves seem to be bugged for large batches on GPU, even when the problem still fits comfortably in GPU memory. In short, if you try to solve a large batch of linear systems, the JAX LU/Cholesky solvers will sometimes return NaNs (or otherwise wrong results) without throwing an error or warning. The SVD-based solve seems to work better, though it also fails if you get close enough to filling the full GPU memory. Strangely, the QR-based solve is too slow for me to test at large batch sizes. The lineax solves show the same behavior, although lineax does throw an error upon seeing NaNs.

Below is a test and its output: solving Ax = b, where A is the identity and b is all ones, returns NaNs at the larger batch size. I am curious whether someone can reproduce this behavior and has any ideas on what to do.

Thank you for making this nice library! Best, Marc

import jax
import lineax as lx

import jax.numpy as jnp
from jax import random
device = jax.local_devices()[0]
print('on device:', device)

m = 10

batched_solve_lu = jax.vmap( lambda matrix, vector: lx.linear_solve(lx.MatrixLinearOperator(matrix), vector, solver=lx.LU()).value)
batched_solve_SVD = jax.vmap( lambda matrix, vector: lx.linear_solve(lx.MatrixLinearOperator(matrix), vector, solver=lx.SVD()).value)

solve_fns = [jax.scipy.linalg.solve, batched_solve_SVD, batched_solve_lu]

for solve_fn in solve_fns:
    for n in [ int(1e6), int(1e7)]:
        A = jnp.repeat(jnp.identity(m)[None], n, axis = 0)

        x = jnp.ones([n,m])
        b = jax.lax.batch_matmul(A,x[...,None])[...,0]

        x_solved = solve_fn(A,b)
        print(f"Average error with n ={n}, {jnp.mean(jnp.linalg.norm(x - x_solved, axis=-1))} ")
        print("Memory info ", device.memory_stats())

Output:

on device: cuda:0
Average error with n =1000000, 0.0 
Memory info  {'bytes_in_use': 520001024, 'bytes_limit': 63880937472, 'bytes_reserved': 0, 'largest_alloc_size': 804000512, 'largest_free_block_bytes': 0, 'num_allocs': 29, 'peak_bytes_in_use': 1324001536, 'peak_bytes_reserved': 0, 'peak_pool_bytes': 63880937472, 'pool_bytes': 63880937472}
Average error with n =10000000, nan 
Memory info  {'bytes_in_use': 5200002048, 'bytes_limit': 63880937472, 'bytes_reserved': 0, 'largest_alloc_size': 8040000512, 'largest_free_block_bytes': 0, 'num_allocs': 58, 'peak_bytes_in_use': 13280002560, 'peak_bytes_reserved': 0, 'peak_pool_bytes': 63880937472, 'pool_bytes': 63880937472}
Average error with n =1000000, 0.0 
Memory info  {'bytes_in_use': 520000000, 'bytes_limit': 63880937472, 'bytes_reserved': 0, 'largest_alloc_size': 8040000512, 'largest_free_block_bytes': 0, 'num_allocs': 98, 'peak_bytes_in_use': 13280002560, 'peak_bytes_reserved': 0, 'peak_pool_bytes': 63880937472, 'pool_bytes': 63880937472}
Average error with n =10000000, 0.0 
Memory info  {'bytes_in_use': 5200000000, 'bytes_limit': 63880937472, 'bytes_reserved': 0, 'largest_alloc_size': 8040000512, 'largest_free_block_bytes': 0, 'num_allocs': 136, 'peak_bytes_in_use': 18120001024, 'peak_bytes_reserved': 0, 'peak_pool_bytes': 63880937472, 'pool_bytes': 63880937472}
Average error with n =1000000, 0.0 
Memory info  {'bytes_in_use': 560002560, 'bytes_limit': 63880937472, 'bytes_reserved': 0, 'largest_alloc_size': 8040000512, 'largest_free_block_bytes': 0, 'num_allocs': 174, 'peak_bytes_in_use': 18120001024, 'peak_bytes_reserved': 0, 'peak_pool_bytes': 63880937472, 'pool_bytes': 63880937472}
2024-01-24 15:17:43.130711: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: CpuCallback error: EqxRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.

If you are trying solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.

If you *were* expecting this solver to work with this operator, then it may be because:

(a) the operator is singular, and your code has a bug; or

(b) the operator was nearly singular (i.e. it had a high condition number:
    `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
    numerical instability issues; or

(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
    that is does not actually satisfy.

At:
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/equinox/_errors.py(70): raises
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/callback.py(258): _flat_callback
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/callback.py(52): pure_callback_impl
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/callback.py(188): _callback
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/interpreters/mlir.py(2267): _wrapped_callback
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py(1152): __call__
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/profiler.py(314): wrapper
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(1151): _pjit_call_impl_python
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(1195): call_impl_cache_miss
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(1211): _pjit_call_impl
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(869): process_primitive
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(389): bind_with_trace
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(2657): bind
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(1427): _pjit_batcher
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/interpreters/batching.py(433): process_primitive
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(389): bind_with_trace
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/core.py(2657): bind
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(166): _python_pjit_helper
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/pjit.py(255): cache_miss
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/traceback_util.py(177): reraise_with_filtered_traceback
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/equinox/_jit.py(200): _call
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/equinox/_module.py(935): __call__
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/equinox/_jit.py(206): __call__
  /tmp/ipykernel_3270959/20536111.py(9): <lambda>
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/linear_util.py(190): call_wrapped
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/api.py(1260): vmap_f
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/traceback_util.py(177): reraise_with_filtered_traceback
  /tmp/ipykernel_3270959/20536111.py(22): <module>
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3526): run_code
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3466): run_ast_nodes
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3284): run_cell_async
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3079): _run_cell
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3024): run_cell
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/zmqshell.py(546): run_cell
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/ipkernel.py(422): do_execute
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelbase.py(740): execute_request
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelbase.py(412): dispatch_shell
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelbase.py(505): process_one
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelbase.py(516): dispatch_queue
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/asyncio/events.py(80): _run
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/asyncio/base_events.py(1905): _run_once
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/asyncio/base_events.py(601): run_forever
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/tornado/platform/asyncio.py(195): start
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel/kernelapp.py(736): start
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/traitlets/config/application.py(1053): launch_instance
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/site-packages/ipykernel_launcher.py(17): <module>
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/runpy.py(87): _run_code
  /home/mg6942/.conda/envs/recovar2/lib/python3.9/runpy.py(197): _run_module_as_main

2024-01-24 15:17:43.130798: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2711] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: CpuCallback error: EqxRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.

If you are trying solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.

If you *were* expecting this solver to work with this operator, then it may be because:

(a) the operator is singular, and your code has a bug; or

(b) the operator was nearly singular (i.e. it had a high condition number:
    `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
    numerical instability issues; or

(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
    that is does not actually satisfy.

; current tracing scope: custom-call.101; current profiling annotation: XlaModule:#hlo_module=jit_linear_solve,program_id=40#.
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[1], line 22
     19 x = jnp.ones([n,m])
     20 b = jax.lax.batch_matmul(A,x[...,None])[...,0]
---> 22 x_solved = solve_fn(A,b)
     23 print(f"Average error with n ={n}, {jnp.mean(jnp.linalg.norm(x - x_solved, axis=-1))} ")
     24 print("Memory info ", device.memory_stats())

    [... skipping hidden 3 frame]

Cell In[1], line 9, in <lambda>(matrix, vector)
      5 print('on device:', device)
      7 m = 10
----> 9 batched_solve_lu = jax.vmap( lambda matrix, vector: lx.linear_solve(lx.MatrixLinearOperator(matrix), vector, solver=lx.LU()).value)
     10 batched_solve_SVD = jax.vmap( lambda matrix, vector: lx.linear_solve(lx.MatrixLinearOperator(matrix), vector, solver=lx.SVD()).value)
     12 solve_fns = [jax.scipy.linalg.solve, batched_solve_SVD, batched_solve_lu]

    [... skipping hidden 14 frame]

File ~/.conda/envs/recovar2/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py:1152, in ExecuteReplicated.__call__(self, *args)
   1150   self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
   1151 else:
-> 1152   results = self.xla_executable.execute_sharded(input_bufs)
   1153 if dispatch.needs_check_special():
   1154   out_arrays = results.disassemble_into_single_device_arrays()

XlaRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.

If you are trying solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.

If you *were* expecting this solver to work with this operator, then it may be because:

(a) the operator is singular, and your code has a bug; or

(b) the operator was nearly singular (i.e. it had a high condition number:
    `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
    numerical instability issues; or

(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
    that is does not actually satisfy.
-------
This error occurred during the runtime of your JAX program. Setting the environment
variable `EQX_ON_ERROR=breakpoint` is usually the most useful way to debug such errors.
(This can be navigated using most of the usual commands for the Python debugger:
`u` and `d` to move through stack frames, the name of a variable to print its value,
etc.) See also `https://docs.kidger.site/equinox/api/errors/#equinox.error_if` for more
information.

patrick-kidger commented 9 months ago

Hmm. So JAX and Lineax both basically do the same thing for the LU/QR/Cholesky solvers, which is to use the JAX (and thus probably CUDA) implementation of those decompositions.

The fact that the QR solve is slow is expected I think -- IIRC there's no CUDA implementation of a batched QR decomposition, so vmap is handled by computing the decomposition for each batch element sequentially.

I suspect the issue is probably somewhere in the underlying CUDA (cuSolver?) implementations. I think resolving this will probably need someone to go digging through things at that level, I'm afraid.
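If the failure really is tied to the size of a single batched call, one possible stopgap is to split the batch into smaller pieces and solve each piece separately. The sketch below is only an illustration of that idea, not something confirmed in this thread: `chunked_solve` and `chunk_size` are hypothetical names, and the chunk size shown is not a known-safe threshold.

```python
import jax
import jax.numpy as jnp


def chunked_solve(A, b, chunk_size=100_000):
    """Solve A[i] @ x[i] = b[i] for every i, one slice of the batch at a time.

    `chunk_size` is a hypothetical tuning knob; pick it empirically.
    """
    solve = jax.vmap(jnp.linalg.solve)  # per-element (m, m) and (m,) solve
    pieces = []
    for start in range(0, A.shape[0], chunk_size):
        stop = start + chunk_size  # slicing past the end is clipped
        pieces.append(solve(A[start:stop], b[start:stop]))
    return jnp.concatenate(pieces, axis=0)


# Small demonstration on the same identity-matrix problem as in the report.
n, m = 1_000, 10
A = jnp.repeat(jnp.identity(m)[None], n, axis=0)
b = jnp.ones((n, m))
x = chunked_solve(A, b, chunk_size=256)
print(x.shape, bool(jnp.isfinite(x).all()))
```

Each chunk is an ordinary batched solve, so this trades some launch overhead for smaller individual calls into the underlying decomposition.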

ma-gilles commented 9 months ago

Hi Patrick,

Thanks for your answer!

I can't say I really understand how JAX/torch/cupy interact with CUDA code, but what is surprising to me is that this seems to be a bug only in JAX. Both torch and cupy seem to work, even though I would assume they use the same backend.

E.g.:

  import numpy as np
  import torch
  n = int(1e7); m = 10
  A = torch.tensor(np.repeat(np.identity(m)[None], n, axis = 0))
  L = torch.linalg.cholesky(A)
  print(torch.linalg.norm(A - L))

Outputs:

  tensor(0., dtype=torch.float64)

And the same thing for cupy, but JAX returns NaNs.

patrick-kidger commented 9 months ago

Oh interesting! Hmm, in that case I'm less certain of the reason. Maybe check that it's not a version issue? PyTorch and JAX tend to use different versions of the underlying NVIDIA libraries.
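For comparing setups, a quick way to record which versions are in play is to print the installed package versions and the visible devices. This is a minimal sketch using only the standard library's `importlib.metadata` plus `jax.devices()`. It does not report the CUDA or cuSolver versions themselves, which would need to be checked separately.

```python
from importlib.metadata import version

import jax

# Installed distribution versions, as reported by the package metadata.
versions = {pkg: version(pkg) for pkg in ("jax", "jaxlib")}
print(versions)

# Devices JAX can see; on a GPU machine this should list the CUDA device(s).
print("devices:", jax.devices())
```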

ma-gilles commented 9 months ago

Thanks for the suggestion! I tried a few different versions of CUDA, which made no difference, but updating jax seems to fix the problem, or at least it passes the few tests I have tried.
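Because the plain JAX solvers returned NaNs without any error in the original report, a check of this kind can be wrapped around the solve. The following is a sketch under assumptions: `checked_batched_solve` and `tol` are hypothetical names, and the check runs outside `jit`, since it converts results to Python values.

```python
import jax
import jax.numpy as jnp


def checked_batched_solve(A, b, tol=1e-5):
    """Batched solve that raises if the result is non-finite or inaccurate."""
    x = jax.vmap(jnp.linalg.solve)(A, b)
    # NaN never compares as "too large", so test finiteness explicitly.
    if not bool(jnp.isfinite(x).all()):
        raise FloatingPointError("batched solve returned NaN or inf")
    # Largest per-system residual norm ||A[i] @ x[i] - b[i]||.
    residual = jnp.einsum("nij,nj->ni", A, x) - b
    worst = float(jnp.max(jnp.linalg.norm(residual, axis=-1)))
    if worst > tol:
        raise FloatingPointError(f"residual {worst} exceeds tolerance {tol}")
    return x


# Same identity-matrix problem as in the report, at a small batch size.
n, m = 1_000, 10
A = jnp.repeat(jnp.identity(m)[None], n, axis=0)
b = jnp.ones((n, m))
x_checked = checked_batched_solve(A, b)
print(x_checked.shape)
```

Running the same check at the batch size that previously failed gives a direct before/after comparison when changing jax or jaxlib versions.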

patrick-kidger commented 9 months ago

Curious! Well, I'm glad it's fixed. :) Possibly an issue with a particular version of jaxlib then, if updating the version fixed things.