patrick-kidger / lineax

Linear solvers in JAX and Equinox. https://docs.kidger.site/lineax
Apache License 2.0

Raise error while throw=False #77

Closed · mariogeiger closed this issue 10 months ago

mariogeiger commented 10 months ago

I hope I'll find the problem and close this issue myself, but I'm posting it now anyway.

I use dx.diffeqsolve and lx.linear_solve to integrate a geodesic equation. In both functions I set throw=False and check solution.result so that I can handle any error on my side. However, it still raises errors, as shown below.

More details: I use dx.ODETerm, dx.Dopri5, dx.PIDController, lx.JacobianLinearOperator, and lx.AutoLinearSolver(well_posed=False).
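For concreteness, here is a minimal sketch of the setup being described (the residual function and all numbers are made up for illustration; the actual code integrates a geodesic equation):

```python
import diffrax as dx
import jax.numpy as jnp
import lineax as lx


def vector_field(t, y, args):
    # Hypothetical residual standing in for the geodesic right-hand side.
    def residual(z, _):
        return jnp.sin(z) + 0.1 * z

    operator = lx.JacobianLinearOperator(residual, y, None)
    sol = lx.linear_solve(
        operator, y, solver=lx.AutoLinearSolver(well_posed=False), throw=False
    )
    # With throw=False the solve returns instead of raising; sol.result
    # records whether it succeeded.
    ok = sol.result == lx.RESULTS.successful
    return jnp.where(ok, -sol.value, jnp.zeros_like(y))


solution = dx.diffeqsolve(
    dx.ODETerm(vector_field),
    dx.Dopri5(),
    t0=0.0,
    t1=1.0,
    dt0=0.1,
    y0=jnp.ones(3),
    stepsize_controller=dx.PIDController(rtol=1e-5, atol=1e-8),
    throw=False,  # check solution.result ourselves instead of raising
)
```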

2024-01-11 18:10:46.719440: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: CpuCallback error: EqxRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.

If you are trying solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.

If you *were* expecting this solver to work with this operator, then it may be because:

(a) the operator is singular, and your code has a bug; or

(b) the operator was nearly singular (i.e. it had a high condition number:
    `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
    numerical instability issues; or

(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
    that is does not actually satisfy.

At:
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/equinox/_errors.py(70): raises
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/callback.py(262): _flat_callback
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/callback.py(53): pure_callback_impl
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/callback.py(192): _callback
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py(2367): _wrapped_callback
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py(1151): __call__
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/profiler.py(336): wrapper
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py(2806): call
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/stages.py(538): call
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/stages.py(566): cpp_call_fallback
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/stages.py(569): __call__
  /home/mgeiger/git/test/train_regression.py(241): train
  /home/mgeiger/git/test/train_regression.py(274): <module>

2024-01-11 18:10:46.719782: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: CpuCallback error: EqxRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.

If you are trying solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.

If you *were* expecting this solver to work with this operator, then it may be because:

(a) the operator is singular, and your code has a bug; or

(b) the operator was nearly singular (i.e. it had a high condition number:
    `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
    numerical instability issues; or

(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
    that is does not actually satisfy.

At:
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/equinox/_errors.py(70): raises
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/callback.py(262): _flat_callback
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/callback.py(53): pure_callback_impl
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/callback.py(192): _callback
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py(2367): _wrapped_callback
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py(1151): __call__
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/profiler.py(336): wrapper
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py(2806): call
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/stages.py(538): call
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/stages.py(566): cpp_call_fallback
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/stages.py(569): __call__
  /home/mgeiger/git/test/train_regression.py(241): train
  /home/mgeiger/git/test/train_regression.py(274): <module>
; current tracing scope: custom-call.54; current profiling annotation: XlaModule:#hlo_module=jit_update,program_id=773#.
Traceback (most recent call last):
  File "/home/mgeiger/git/test/train_regression.py", line 274, in <module>
    with open(logfile, "at") as fhandle:
^^^^^^^
  File "/home/mgeiger/git/test/train_regression.py", line 241, in train

  File "/home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/stages.py", line 569, in __call__
    return self._call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/stages.py", line 566, in cpp_call_fallback
    outs, _, _ = Compiled.call(params, *args, **kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/stages.py", line 538, in call
    out_flat = params.executable.call(*args_flat)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2806, in call
    return self.unsafe_call(*args)  # pylint: disable=not-callable
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1151, in __call__
    results = self.xla_executable.execute_sharded(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: CpuCallback error: EqxRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.

If you are trying solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.

If you *were* expecting this solver to work with this operator, then it may be because:

(a) the operator is singular, and your code has a bug; or

(b) the operator was nearly singular (i.e. it had a high condition number:
    `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
    numerical instability issues; or

(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
    that is does not actually satisfy.

At:
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/equinox/_errors.py(70): raises
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/callback.py(262): _flat_callback
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/callback.py(53): pure_callback_impl
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/callback.py(192): _callback
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py(2367): _wrapped_callback
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py(1151): __call__
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/profiler.py(336): wrapper
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py(2806): call
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/stages.py(538): call
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/stages.py(566): cpp_call_fallback
  /home/mgeiger/git/test/.venv/lib/python3.11/site-packages/jax/_src/stages.py(569): __call__
  /home/mgeiger/git/test/train_regression.py(241): train
  /home/mgeiger/git/test/train_regression.py(274): <module>
; current tracing scope: custom-call.54; current profiling annotation: XlaModule:#hlo_module=jit_update,program_id=773#.
mariogeiger commented 10 months ago

Ok, I can see from https://github.com/google/lineax/blob/main/lineax/_solve.py#L234 that it comes from the JVP rule. I guess it's hard/impossible to handle errors properly in the backward pass :/

I understand why you did it like that; it's kind of neat.

patrick-kidger commented 10 months ago

Yeah, unfortunately there's no good place to pipe errors to if they occur in the backward pass.
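To illustrate the point, here is a minimal sketch (an illustrative reconstruction, not lineax's actual code) of why an error raised inside a JVP rule can't be suppressed: the rule has no Solution object to report a failure into, so its only option is a runtime error check, independent of any throw=False on the primal call.

```python
import equinox as eqx
import jax
import jax.numpy as jnp


@jax.custom_jvp
def solve(x):
    return 1.0 / x


@solve.defjvp
def solve_jvp(primals, tangents):
    (x,), (tx,) = primals, tangents
    y = solve(x)
    ty = -tx / x**2
    # Inside the JVP rule there is no Solution object to report into, so the
    # only way to surface a failure is a runtime error check.
    ty = eqx.error_if(ty, ~jnp.isfinite(ty), "non-finite tangent")
    return y, ty


x0 = jnp.array(0.0)
solve(x0)  # returns inf; no error on the primal
jax.jvp(solve, (x0,), (jnp.array(1.0),))  # raises EqxRuntimeError from the JVP rule
```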

(Unrelatedly, if you have an equinox.filter_jit up at the top level, then I think it should filter out a lot of the noise in that error message.)
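That suggestion would look something like the following, where `train` is a hypothetical stand-in for the top-level step in the traceback above:

```python
import equinox as eqx
import jax.numpy as jnp


@eqx.filter_jit  # runtime errors re-raised by filter_jit drop much of the JAX/XLA frame noise
def train(y):
    # Stand-in for the top-level update step.
    return y - 0.1 * y


train(jnp.ones(3))
```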