patrick-kidger / lineax

Linear solvers in JAX and Equinox. https://docs.kidger.site/lineax
Apache License 2.0

Help debugging solver failure #82

Closed: FFroehlich closed this 9 months ago

FFroehlich commented 9 months ago

I am calling lineax (version 0.0.4) as part of `diffrax.diffeqsolve` with the `Kvaerno5` solver and default options. This produces the following error message:

The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.

If you are trying solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.

If you *were* expecting this solver to work with this operator, then it may be because:

(a) the operator is singular, and your code has a bug; or

(b) the operator was nearly singular (i.e. it had a high condition number:
    `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
    numerical instability issues; or

(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
    that is does not actually satisfy.

I am trying to debug this by setting `export EQX_ON_ERROR=breakpoint; export EQX_ON_ERROR_BREAKPOINT_FRAMES=6`, which suggests that the failure happens during backpropagation. However, after inspecting all the variables, I am struggling to understand what the actual problem is:

(jdb) list
> .../lineax/_solve.py(108)
        result = RESULTS.where(
            (result == RESULTS.successful) & has_nonfinites,
            RESULTS.singular,
            result,
        )
        if throw:
->          solution, result, stats = result.error_if(
                (solution, result, stats),
                result != RESULTS.successful,
            )
        return solution, result, stats
(jdb) pp state
(lu_token,
 ((array([[-1.72331707, -0.07344511],
       [ 0.        , -5.12813112]]),
   array([0, 1], dtype=int32)),
  Static(
  _leaves=[
    ShapeDtypeStruct(shape=(1, 2), dtype=float64),
    ShapeDtypeStruct(shape=(2,), dtype=float64),
    PyTreeDef((*, *))
  ],
  _treedef=PyTreeDef(([*, *], *))
),
  True))
(jdb) pp vector
array([[ 5.23672064e-11, -1.67091217e-13]])
(jdb) pp result
lineax._solution.RESULTS<>
(jdb) pp has_nonfinites
array(False)
(jdb) pp solution
array([-3.03874472e-11,  4.67792372e-13])
(jdb) pp stats
{}
(jdb) pp options
{}
(jdb) pp solver
AutoLinearSolver(well_posed=None)

Assuming `array([[-1.72331707, -0.07344511], [0., -5.12813112]])` is the matrix defining the system of linear equations, there shouldn't be any issue with its conditioning, nor with the inputs or outputs. The solver appears to be an LU solver, and the linear operator appears to satisfy all the properties that solver expects. Any pointers on how I could figure out what's going on?
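A quick numerical check of that reasoning, using the diagnostic the error message itself suggests. This is a sketch: it assumes, as the comment does, that the array printed by `pp state` is the operator's matrix, and it uses plain `jax.numpy` rather than Lineax.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # the arrays in the thread are float64

# Array copied from the `pp state` output; assumed here to be the operator's matrix.
matrix = jnp.array([[-1.72331707, -0.07344511],
                    [0.0, -5.12813112]])

# Diagnostic suggested by the error message: a large condition number
# would indicate a nearly singular operator.
cond = jnp.linalg.cond(matrix)

# For a triangular matrix the determinant is the product of the diagonal,
# so an exactly singular matrix would need a zero on the diagonal.
det = jnp.prod(jnp.diag(matrix))

print(f"cond = {float(cond):.3g}, det = {float(det):.3g}")
```

For this matrix the condition number comes out at roughly 3, which agrees with the conclusion that conditioning is not the problem here.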

patrick-kidger commented 9 months ago

Hmm. I agree, I don't understand this either. Note that the trigger condition for the error is `result != RESULTS.successful`; however,

(jdb) pp result
lineax._solution.RESULTS<>

... otherwise known as RESULTS.successful!

Can you reduce this to an MWE? (Possibly even one that uses just Lineax, without Diffrax or Optimistix.)

_FWIW, technical side note: the matrix you're seeing there is not technically the input matrix; it is the LU decomposition of it. (The `lu_token` in the state comes from `AutoLinearSolver`, and from there we can look up that the first element of the state is the decomposition.) However, by coincidence it just so happens that this is also the input matrix: the upper triangle (including the diagonal) is the U of the decomposition, whilst the lower triangle (with an implicit unit diagonal) is the L of the decomposition. So in fact any upper triangular matrix has a packed LU representation identical to the original matrix._
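The packed layout described in the side note can be checked with `jax.scipy.linalg.lu_factor`. This sketch uses JAX directly rather than Lineax, and assumes Lineax's LU state uses the same packed format (the pivot array `array([0, 1])` in the debugger output is at least consistent with that).

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor

jax.config.update("jax_enable_x64", True)

a = jnp.array([[-1.72331707, -0.07344511],
               [0.0, -5.12813112]])

# lu_factor returns the packed form: U on and above the diagonal, the
# strictly-lower part of L below it (L's unit diagonal is not stored),
# plus the row-pivot indices.
lu, piv = lu_factor(a)

# Unpack L and U to confirm that L @ U reproduces the original matrix.
L = jnp.tril(lu, k=-1) + jnp.eye(a.shape[0])
U = jnp.triu(lu)

print(lu)   # same values as `a`, since `a` is already upper triangular
print(piv)  # no row swaps were needed
```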

FFroehlich commented 9 months ago

Hmm, after some digging, it looks like `EQX_ON_ERROR=breakpoint` doesn't work well with `jax.vmap`. After reducing the problem so that compiling it with a for loop was feasible, the breakpoint correctly fired only on errors. This might be a bug in Equinox, but so far I haven't managed to reduce the issue to something that could be called an MWE.

In the end, the matrix in question is indeed singular, so this doesn't appear to be a problem with lineax; it is likely caused upstream in diffrax.

patrick-kidger commented 9 months ago

Got it! Let me know about any MWEs you're able to put together. By the way, I'm quite conscious of how many of these little issues you seem to be running into. Once we've gotten to the bottom of them, if you have any thoughts on how to make the process more robust, I'd be very happy to try and improve things.

FFroehlich commented 9 months ago

I think part of the issue is that I am still figuring out how best to debug code in JAX. Neither https://github.com/patrick-kidger/optimistix/pull/43 nor https://github.com/google/jax/issues/16732 makes this easier. One thing that would help is a debugging option that saves the inputs of high-level functions such as `diffeqsolve`, root finds, or linear solves to disk, so that the issue can be isolated. I have tried implementing this using `jax.debug.callback`, but have the impression that it led to additional weird side effects.
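A rough sketch of the "dump inputs to disk" idea using `jax.debug.callback`, as a possible starting point. `make_input_dumper`, `my_solve` and the file naming are hypothetical; this is not an existing Diffrax, Optimistix or Lineax option, and it has not been checked against the side effects mentioned above.

```python
import itertools
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np


def make_input_dumper(directory):
    """Hypothetical helper: returns `dump(**arrays)`, which saves its arrays
    to `directory/call_<n>.npz` each time the compiled code executes."""
    counter = itertools.count()

    def _save(**arrays):
        # Runs on the host with concrete values.
        path = os.path.join(directory, f"call_{next(counter)}.npz")
        np.savez(path, **{k: np.asarray(v) for k, v in arrays.items()})

    def dump(**arrays):
        jax.debug.callback(_save, **arrays)

    return dump


out_dir = tempfile.mkdtemp()
dump = make_input_dumper(out_dir)


@jax.jit
def my_solve(matrix, vector):
    dump(matrix=matrix, vector=vector)  # record the inputs before solving
    return jnp.linalg.solve(matrix, vector)


my_solve(jnp.eye(2), jnp.array([1.0, 2.0]))
jax.effects_barrier()  # wait for pending callbacks before reading the files
saved = np.load(os.path.join(out_dir, "call_0.npz"))
```

Under `jax.vmap` the callback runs once per batch element, so each element would get its own file.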

Overall, my impression is that there is some weird bug in backpropagation with equinox/lineax/optimistix/diffrax that only appears when using `jax.vmap`, and that I haven't yet been able to track down in an MWE. On that front, higher-level documentation of how backpropagation is implemented in diffrax would be helpful. I haven't had the chance to get back to the other issues I encountered since fixing the one above, but I have a hunch that the error I reported in https://github.com/patrick-kidger/optimistix/issues/48 might be related.

patrick-kidger commented 9 months ago

Makes sense! Off the top of my head, I don't know of anything else that really does transforms like JAX (compiler optimisation passes are probably the closest reference point I have), so how best to combine transforms with a debugger, navigating stack traces, etc. is something I don't yet have a great answer for. I think you might be one of the first to explore the more difficult parts of that experience.

On `jax.debug.callback`: I think https://github.com/google/jax/pull/17088 might help, as it would let us pass `vectorized=True`. There was no good reason we didn't merge that; I just never got around to prodding Sharad again. If that's useful to you, then maybe test it and resurrect it.
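For context on why `vectorized=True` would matter: as far as I can tell, JAX currently batches `jax.debug.callback` by unrolling it, so under `jax.vmap` the callback is invoked once per batch element with unbatched arguments. A small sketch to observe this (`record` and `calls` are illustrative names):

```python
import jax
import jax.numpy as jnp
import numpy as np

calls = []


def record(x):
    # Host-side callback: note the shape of the argument it receives.
    calls.append(np.shape(x))


def f(x):
    jax.debug.callback(record, x)
    return 2 * x


jax.vmap(f)(jnp.arange(3.0))
jax.effects_barrier()  # make sure all callbacks have run
print(calls)  # one entry per batch element, each for an unbatched scalar
```

With a vectorized callback, the host function would presumably be called once with the whole batch instead; that is what the linked PR proposes, and it is not assumed to be available here.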

High-level documentation on backprop in Diffrax: I think the only thing worth mentioning on top of that is equinox.internal.while_loop(..., kind='checkpointed'), which does the recursive checkpointing. More docs on that could be worthwhile!
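Some background on why a special loop is needed at all: JAX does not support reverse-mode differentiation through `jax.lax.while_loop`, since the number of iterations (and hence the residuals to store) is not known ahead of time. The sketch below is not Equinox's recursive checkpointing; it is a simpler, swapped-in technique (a bounded `lax.scan` whose step is wrapped in `jax.checkpoint`) that illustrates the same basic trade of recomputation for memory. The toy `cond_fun` and `body_fun` are made up.

```python
import jax
import jax.numpy as jnp


def cond_fun(x):
    return x < 10.0


def body_fun(x):
    return 1.5 * x


def loop_lax(x0):
    # Plain while loop: fine for forward evaluation, but jax.grad through it fails.
    return jax.lax.while_loop(cond_fun, body_fun, x0)


def loop_bounded(x0, max_steps=32):
    # Run a fixed number of steps, freezing the state once cond_fun is False.
    # jax.checkpoint makes the backward pass recompute each step's
    # intermediates instead of storing them.
    @jax.checkpoint
    def step(x, _):
        return jnp.where(cond_fun(x), body_fun(x), x), None

    x, _ = jax.lax.scan(step, x0, None, length=max_steps)
    return x


x0 = jnp.float32(1.0)
print(loop_bounded(x0), jax.grad(loop_bounded)(x0))
```

This version still stores the loop state at every one of the `max_steps` steps; as I understand it, the point of `kind='checkpointed'` is to keep only a limited number of checkpoints and recompute between them.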