Closed FFroehlich closed 9 months ago
Hmm. I agree, I also don't understand this. Note that the trigger condition for the error is result != RESULTS.successful, however
(jdb) pp result
lineax._solution.RESULTS<>
... otherwise known as RESULTS.successful!
Can you reduce this to a MWE? (Possibly even one that uses just Lineax, without Diffrax or Optimistix.)
_FWIW, technical side note: the matrix you're seeing there is not technically the input matrix; it is the LU decomposition of that (lu_token in the state is coming from AutoLinearSolver, and from there we can look up and see that the first element of the state is the decomposition). However, by coincidence it just so happens that this is also the input matrix: the upper triangle (including the diagonal) is the U of the decomposition, whilst the lower triangle (with unit diagonal) is the L of the decomposition. So in fact any upper triangular matrix will have its packed LU representation be the same as the original matrix._
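The side note about the packed LU format can be checked directly. The following is a minimal sketch, assuming that the LU state follows the same packed convention as jax.scipy.linalg.lu_factor; the matrix is the one printed in the original report, and the variable names are only illustrative.

```python
import jax.numpy as jnp
import jax.scipy.linalg as jsl

# Upper triangular matrix copied from the original report.
A = jnp.array([[-1.72331707, -0.07344511],
               [0.0, -5.12813112]])

# lu_factor returns the packed LU matrix together with the pivot indices.
lu, piv = jsl.lu_factor(A)

# Unpack: U is the upper triangle (including the diagonal), L is the strict
# lower triangle plus a unit diagonal.
U = jnp.triu(lu)
L = jnp.tril(lu, k=-1) + jnp.eye(A.shape[0], dtype=lu.dtype)

# For an upper triangular input no row swaps are needed, so the packed
# factorisation should coincide with the input and L @ U should rebuild it.
packed_equals_input = bool(jnp.allclose(lu, A))
reconstructs_input = bool(jnp.allclose(L @ U, A))
print(packed_equals_input, reconstructs_input)
```

For a general matrix the two would differ, and L @ U would only match the input up to the row permutation encoded in piv.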
Hmm, after some digging it looks like EQX_ON_ERROR=breakpoint doesn't work well with jax.vmap. After reducing the problem so that compiling with a for loop was feasible, the breakpoint correctly fired only on errors. This might be a bug in equinox, but so far I haven't managed to reduce the issue to something that could be called a MWE.
After all, the matrix in question is indeed singular, so this issue doesn't appear to be a problem with lineax, but is likely caused upstream in diffrax.
Got it! Any MWEs you're able to put together, let me know. By the way, I'm quite conscious of how many of these little issues you seem to be running into. If, once we've gotten to the bottom of any of them, you have any thoughts on how to make the process more robust, then I'd be very happy to try and improve things.
I think part of the issue is that I am still figuring out how to optimally debug code in jax. Neither https://github.com/patrick-kidger/optimistix/pull/43 nor https://github.com/google/jax/issues/16732 makes this easier. One thing that would be helpful is a debugging option to save the inputs of high-level functions such as diffeqsolve/rootfind/linear solves to disk, to be able to isolate the issue. I have tried implementing this using jax.debug.callback, but have the impression that this led to additional weird side-effects.
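For reference, the kind of input dump described above could look roughly like the sketch below. It is not an existing option in any of these libraries: DUMP_DIR, _save_inputs and solve_with_dump are hypothetical names, and jnp.linalg.solve merely stands in for the real diffeqsolve / root find / linear solve.

```python
import itertools
import pathlib
import tempfile

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical dump location and counter, chosen for this sketch only.
DUMP_DIR = pathlib.Path(tempfile.mkdtemp(prefix="solve_inputs_"))
_counter = itertools.count()


def _save_inputs(matrix, vector):
    # Runs on the host with concrete arrays, even if the caller is jitted.
    path = DUMP_DIR / f"solve_{next(_counter):05d}.npz"
    np.savez(path, matrix=np.asarray(matrix), vector=np.asarray(vector))


def solve_with_dump(matrix, vector):
    # Record the inputs first, then perform the actual solve.
    jax.debug.callback(_save_inputs, matrix, vector)
    return jnp.linalg.solve(matrix, vector)


A = jnp.array([[-1.72331707, -0.07344511], [0.0, -5.12813112]])
b = jnp.array([1.0, 2.0])  # arbitrary right-hand side, not from the report

x = jax.jit(solve_with_dump)(A, b)
jax.effects_barrier()  # wait for pending callbacks before reading the files

dumped = sorted(DUMP_DIR.glob("*.npz"))
replayed = np.load(dumped[0])  # inputs can now be replayed outside of jit
```

This only records what the forward call received. How the callback behaves under jax.vmap and in the backward pass is exactly the part that seems fragile, so treat the dump as a starting point for isolating inputs, not as a faithful trace of backpropagation.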
Overall, my impression is that there is some weird bug in backpropagation with equinox/lineax/optimistix/diffrax that is only present when using jax.vmap and that I haven't really been able to track down in an MWE. On this front it would be helpful to have higher-level documentation of how backprop is implemented in diffrax. I haven't had the chance to get back to the issues I encountered after fixing the issue above, but have a hunch that the error I reported in https://github.com/patrick-kidger/optimistix/issues/48 might be related.
Makes sense! Off the top of my head I don't know of anything else that really does transforms like JAX -- compiler optimisation passes are probably the closest reference point I have -- so how best to combine transforms with a debugger, navigating stack traces, etc. is something I don't yet have a great answer for. I think you might be one of the first to explore the more difficult parts of that experience.
On jax.debug.callback: I think https://github.com/google/jax/pull/17088 might help, so that we could pass vectorized=True. There was no good reason we didn't merge that, I just never got around to prodding Sharad again. If that's useful to you then maybe test that / resurrect it.
High-level documentation on backprop in Diffrax: I think the only thing worth mentioning on top of that is equinox.internal.while_loop(..., kind='checkpointed'), which does the recursive checkpointing. More docs on that could be worthwhile!
I am calling lineax (version 0.0.4) as part of diffrax.diffeqsolve with the Kvaerno5 solver using default options. This produces the following error message:
I am trying to debug this using
export EQX_ON_ERROR=breakpoint; export EQX_ON_ERROR_BREAKPOINT_FRAMES=6
which suggests that the failure happens during backpropagation. However, after inspecting all the variables I am struggling to understand what the actual problem is:
Assuming
array([[-1.72331707, -0.07344511],
       [ 0.        , -5.12813112]])
is the matrix defining the system of linear equations, there shouldn't be any issue with the conditioning of the matrix, or the inputs/outputs. The solver appears to be an LU solver, for which the linear operator appears to satisfy all expected properties. Any pointers on how I could figure out what's going on?
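One way to back up the conditioning claim is to check the printed matrix in isolation, outside of diffrax and lineax. The sketch below uses plain NumPy; the right-hand side b is an arbitrary assumption, since the actual vector is not shown in the report.

```python
import numpy as np

# Matrix copied from the debugger output above.
A = np.array([[-1.72331707, -0.07344511],
              [0.0, -5.12813112]])
b = np.array([1.0, 2.0])  # arbitrary right-hand side, not from the report

cond = np.linalg.cond(A)  # 2-norm condition number
det = np.linalg.det(A)    # for a triangular matrix: product of the diagonal
x = np.linalg.solve(A, b)
residual = np.linalg.norm(A @ x - b)

print(f"cond={cond:.3g}, det={det:.3g}, residual={residual:.3g}")
```

A small condition number and a negligible residual here would indicate that this particular matrix is harmless, which in turn points at the breakpoint having stopped on a different system than the one that actually failed.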