Open mannsean opened 3 years ago
It's great to see this use case! Topology optimization is very near to my heart and is exactly the sort of thing we wanted to enable with GMRES.
It's not immediately obvious to me what is going on, but hopefully we can figure it out. It does look like a bug -- I think this is supposed to work (or at least give a better error message).
I am curious why you need a custom JVP here. At first glance, what you're doing looks quite similar to implementing the implicit function theorem, which is exactly what lax.custom_root does. Have you tried using custom_root? I guess this could save you a few lines of code, but unfortunately it is also broken with GMRES, with a very similar error message (https://github.com/google/jax/pull/5321).
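For reference, here is a minimal sketch of what the custom_root route could look like for a linear system. It is an illustration under assumptions, not the code from this issue: the helper names (linear_solve_root, newton_step, tangent_solve) and the 2x2 system are made up, and dense jnp.linalg.solve stands in for GMRES, since GMRES is the part reported as broken here.

```python
import jax
import jax.numpy as jnp
from jax import lax

def tangent_solve(g, y):
    # g is linear, so its Jacobian is a constant matrix; solve g(x) = y densely.
    return jnp.linalg.solve(jax.jacobian(g)(y), y)

def newton_step(f, x0):
    # f is affine in this sketch, so one Newton step from x0 lands on the root.
    return x0 - jnp.linalg.solve(jax.jacobian(f)(x0), f(x0))

def linear_solve_root(A, b):
    # Hypothetical helper: write "solve A x = b" as the root of f(x) = A x - b.
    f = lambda x: A @ x - b
    return lax.custom_root(f, jnp.zeros_like(b), newton_step, tangent_solve)

# Small made-up system, only to exercise the primal and reverse-mode paths.
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, 2.0])
x = linear_solve_root(A, b)
grad_b = jax.grad(lambda rhs: linear_solve_root(A, rhs).sum())(b)
```

The gradient with respect to b comes from the implicit function theorem via tangent_solve, not from differentiating through newton_step.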
Great! Thanks for the suggestion -- I wrote the custom JVP with IFT in mind, but really wasn't aware of lax.custom_root. It feels like custom_root is a little bit less flexible, hence messier to fit into our use case: keyword arguments aren't allowed, and from some quick tests it seems like I'm no longer allowed to use native Python loops and logic for the solve and tangent_solve supplied to custom_root (using lax.while_loop seems to slow the code down). Another thing is the linearization of f -- as we implemented our own (sparse) Jacobian for f, computing g is unnecessary. Would you say using custom_root brings any speedup relative to custom_jvp?
Hope we can fix the bug, and thank you for developing JAX!
custom_root is just a particular implementation of the IFT, also written using custom_jvp. There is nothing wrong with rolling your own version, and indeed that might be desirable in some cases, e.g., if you want to use Python control flow.
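A rough sketch of the roll-your-own approach, assuming a small dense system: the primal solver is a plain Python loop (a fixed number of Jacobi sweeps, chosen only for illustration), and the JVP rule applies the IFT directly. The name jacobi_solve and the test matrices are hypothetical, and jnp.linalg.solve is a stand-in for whatever tangent solver you would actually use.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def jacobi_solve(A, b):
    # Primal solve with ordinary Python control flow (fixed Jacobi sweeps).
    # Assumes A is diagonally dominant so that the iteration converges.
    d = jnp.diag(A)           # diagonal entries of A
    R = A - jnp.diag(d)       # off-diagonal part of A
    x = jnp.zeros_like(b)
    for _ in range(50):
        x = (b - R @ x) / d
    return x

@jacobi_solve.defjvp
def jacobi_solve_jvp(primals, tangents):
    A, b = primals
    dA, db = tangents
    x = jacobi_solve(A, b)
    # IFT for A x = b: differentiating gives A dx = db - dA x.
    # The tangent solve is linear in (dA, db), so JAX can transpose it
    # for reverse mode; the Python loop above is never differentiated.
    dx = jnp.linalg.solve(A, db - dA @ x)
    return x, dx

A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, 2.0])
loss = lambda A, b: jacobi_solve(A, b).sum()
grad_A, grad_b = jax.grad(loss, argnums=(0, 1))(A, b)
```

Because only the JVP rule is differentiated, the primal solver can be swapped for any iterative method without changing the derivative code.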
@froystig I co-assigned you just now b/c this is relevant to stuff we've been talking about.
@romanodev and I are working on topology optimization using JAX’s newly released GMRES function. To avoid differentiating through the iterative solver, we are using @jax.custom_jvp with implicit functions. However, I am having trouble with reverse-mode AD, while both forward and reverse-mode work if we replace GMRES with something like np.linalg.solve. Here is a minimal example of the issue (to keep things tidy, I just have the identity function), while the full example is in the colab (https://colab.research.google.com/drive/1fIZgoB2zdMpErqi-q54k9CRsyUfIFbfp?usp=sharing). Any idea what is happening here? Thanks!