DanPuzzuoli closed this 1 year ago.
It seems the complex-to-real casting warnings only ever occur in differentiated functions, which suggests this is some odd interaction with JAX transformations (i.e. these warnings are not raised if the function is not differentiated). I need to find a minimal example to understand what's going on.
Edit: It turns out that this warning is raised when computing the gradient of a function that involves a tensordot of a real array with a complex array. I'll see what the JAX developers say about this. It seems the warning should either also be raised when the gradient is not being computed (for consistency), or not be raised during gradient computation.
Asked here: https://github.com/google/jax/discussions/18133
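A minimal sketch of the kind of reproduction described above, assuming a recent JAX install. The arrays x and y and the function f are illustrative and not taken from the codebase. Whether a ComplexWarning actually appears during the gradient computation depends on the JAX version, so the script records and prints whatever warnings are emitted rather than assuming a particular outcome.

```python
import warnings

import jax
import jax.numpy as jnp

# Illustrative inputs: a real array we differentiate with respect to,
# and a fixed complex array.
x = jnp.array([1.0, 2.0])
y = jnp.array([1.0 + 2.0j, 3.0 - 1.0j])


def f(x):
    # tensordot of a real array with a complex array gives a complex scalar;
    # take the real part so that jax.grad receives a real-valued output.
    return jnp.real(jnp.tensordot(x, y, axes=1))


# Forward evaluation only (per the observation above, no warning expected here).
with warnings.catch_warnings(record=True) as forward_warnings:
    warnings.simplefilter("always")
    value = f(x)

# Gradient: the backward pass produces a complex cotangent for the real input x,
# which has to be cast back to x's real dtype.
with warnings.catch_warnings(record=True) as grad_warnings:
    warnings.simplefilter("always")
    g = jax.grad(f)(x)

print("f(x) =", value)
print("grad f(x) =", g)  # mathematically equal to Re(y)
print("warnings during forward pass:", [str(w.message) for w in forward_warnings])
print("warnings during grad:", [str(w.message) for w in grad_warnings])
```

Comparing the two recorded warning lists on a given JAX version shows directly whether the cast warning is specific to the differentiated call.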
As it currently stands:
- A workaround has been implemented in dyson_magnus.py, in the function _setup_dyson_rhs_jax.
- The JAX version bound in setup.py and in the test environment setups has been removed.
- A test has been added to test_dyson_magnus.py that contains a minimal reproduction of the error with the original version of _setup_dyson_rhs_jax. The test is set up to pass if an exception is thrown when executing this code. Once this test starts failing, we will know that the original version of _setup_dyson_rhs_jax should work again, and at that point we can revert it. (The documentation in the test states this.)
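The "passes while the bug exists" test described above follows a canary-test pattern. A minimal sketch using the standard library's unittest, assuming a unittest-style test class; _original_reproduction and the class name are hypothetical stand-ins, and the stand-in simply raises to simulate the upstream failure instead of running the original _setup_dyson_rhs_jax code.

```python
import unittest


def _original_reproduction():
    """Hypothetical stand-in for the minimal reproduction of the original
    _setup_dyson_rhs_jax code path. It raises to simulate the upstream JAX
    failure; a real test would execute the original code here instead."""
    raise TypeError("simulated upstream JAX failure")


class TestOriginalDysonRhsStillBroken(unittest.TestCase):
    """Canary test: passes while the upstream bug exists.

    If this test starts failing, the original version of _setup_dyson_rhs_jax
    should work again and the workaround can be reverted.
    """

    def test_original_version_raises(self):
        # assertRaises(Exception) passes for any exception raised inside the block,
        # and fails if the block completes without raising.
        with self.assertRaises(Exception):
            _original_reproduction()
```

Run it with python -m unittest or pytest. Catching the broad Exception keeps the canary from breaking if a future JAX release changes the exception type before fixing the bug, at the cost of also passing on unrelated errors.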
Summary
Closes #190
As it is unclear when the JAX bug causing an error in the perturbation module will be fixed, I returned to the issue and found a simple workaround that bypasses it. Realistically I should have tried this sooner, but I didn't expect the issue to stay open this long.
The workaround fixes the perturbation module, but many errors and warnings now come up across many different tests. This is to be expected: there have been over 10 minor releases of JAX since we put a bound on the version. I'll need to work through each folder and figure out how to fix them.
Details and comments
I'm working through each submodule to fix errors/warnings in the tests (strikethrough indicates that no warnings or errors are raised when running the tests; no comment and no strikethrough means the module hasn't been checked yet).
t_span and t_eval with JAX: test.dynamics.solvers.test_jax_odeint.TestJaxOdeint.test_transformations_t_eval_arg_overlap. This is the result of a different bug in JAX (see this discussion). I need to review the status of what is supposed to be possible with this merging, but I don't think it is used heavily, and the error only occurs if t_span[-1] == t_eval[-1]. We could leave this as a known issue for now, as it could be solved on the JAX side before the next release. Alternatively, we could make jax_odeint take only t_span, with the user supplying the full range of time values as if calling odeint directly. This would change the interface a bit, but this technicality has caused many hours to be lost.
alpha[0] = y0.conj().T @ projection appears in the lanczos solver).~
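For the alternative jax_odeint interface mentioned above, a minimal sketch of calling jax.experimental.ode.odeint directly with the full range of time values, assuming that is the odeint being referred to. The right-hand side rhs and the decay rate are illustrative only. With a single time array there is no separate t_span/t_eval merge, so the overlap case t_span[-1] == t_eval[-1] would presumably not need special handling.

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint


def rhs(y, t, rate):
    # odeint expects func(y, t, *args); here dy/dt = -rate * y.
    return -rate * y


rate = 0.5
y0 = jnp.array([1.0])
# The caller supplies every time point directly, starting at the initial time,
# instead of passing a t_span and a t_eval that have to be merged.
t = jnp.linspace(0.0, 1.0, 5)

ys = odeint(rhs, y0, t, rate)  # shape (len(t), 1); ys[0] is y0

print(ys[:, 0])
print(jnp.exp(-rate * t))  # exact solution, for comparison
```

The trade-off is the one noted above: users lose the t_span/t_eval convention, but the solver's time grid is exactly what they pass in.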