Open jjyyxx opened 2 years ago
I'd be happy to offer this if you can determine an API for it. Because this happens inside a `custom_vjp`, as far as I can see JAX doesn't offer anywhere to return auxiliary information.

If you want a hacky version then you could just modify the source code for `BacksolveAdjoint` and print/save the stats as a side effect using `jax.experimental.host_callback`.
Thanks for your quick reply!
For the first way, it may relate to https://github.com/google/jax/issues/2796 and https://github.com/google/jax/pull/2574, neither of which currently has an elegant solution. Passing a dummy stats argument, as proposed in https://github.com/google/jax/pull/2574, might work, but I'm not sure whether you consider it worth implementing.
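The "dummy stats argument" trick from that PR can be sketched without JAX at all. The idea: add an extra input that the forward pass ignores, purely so the custom backward pass has one more cotangent slot through which solver statistics can escape. Everything below (`forward`, `backward`, the `num_steps` value) is a hypothetical stand-in, not Diffrax or JAX code:

```python
# Hypothetical sketch of the "dummy stats argument" trick: a dummy input is
# added so the backward pass has an extra cotangent slot to carry stats.

def forward(y0, dummy_stats):
    # The forward solve ignores dummy_stats entirely.
    y1 = y0 * 2.0                     # stand-in for the ODE solve
    residuals = (y0,)                 # values saved for the backward pass
    return y1, residuals

def backward(residuals, grad_y1):
    (y0,) = residuals
    grad_y0 = grad_y1 * 2.0           # the real cotangent for y0
    stats = {"num_steps": 17}         # stand-in for backward-solve stats
    # stats ride in the cotangent slot reserved by dummy_stats
    return grad_y0, stats

y1, res = forward(3.0, dummy_stats={})
grad_y0, stats = backward(res, grad_y1=1.0)
```

In real JAX this would be the `bwd` half of a `jax.custom_vjp`, which must return exactly one cotangent per primal input; the dummy input is what makes room for the extra return value.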
For the hacky way, I implemented a draft as follows:
diff --git a/adjoint.py b/adjoint2.py
index 79dae8e..3658339 100644
--- a/adjoint.py
+++ b/adjoint2.py
@@ -10,6 +10,7 @@ from .misc import nondifferentiable_output, ω
from .saveat import SaveAt
from .term import AbstractTerm, AdjointTerm
+import jax.experimental.host_callback
class AbstractAdjoint(eqx.Module):
"""Abstract base class for all adjoint methods."""
@@ -180,7 +181,7 @@ def _loop_backsolve_bwd(
def _scan_fun(_state, _vals, first=False):
_t1, _t0, _y0, _grad_y0 = _vals
- _a0, _solver_state, _controller_state = _state
+ _a0, _solver_state, _controller_state, _stats = _state
_a_y0, _a_diff_args0, _a_diff_term0 = _a0
_a_y0 = (_a_y0**ω + _grad_y0**ω).ω
_aug0 = (_y0, _a_y0, _a_diff_args0, _a_diff_term0)
@@ -205,10 +206,12 @@ def _loop_backsolve_bwd(
_a1 = (_a_y1, _a_diff_args1, _a_diff_term1)
_solver_state = _sol.solver_state
_controller_state = _sol.controller_state
+ _sol.stats.pop("max_steps")
+ _stats = _sol.stats if _stats is None else jax.tree_map(lambda x, y: x + y, _stats, _sol.stats)
- return (_a1, _solver_state, _controller_state), None
+ return (_a1, _solver_state, _controller_state, _stats), None
- state = ((zeros_like_y, zeros_like_diff_args, zeros_like_diff_terms), None, None)
+ state = ((zeros_like_y, zeros_like_diff_args, zeros_like_diff_terms), None, None, None)
del zeros_like_y, zeros_like_diff_args, zeros_like_diff_terms
# We always start backpropagating from `ts[-1]`.
@@ -240,7 +243,7 @@ def _loop_backsolve_bwd(
val = (ts[0], ts[1], ω(ys)[1].ω, ω(grad_ys)[1].ω)
state, _ = _scan_fun(state, val, first=True)
- aug1, _, _ = state
+ aug1, _, _, _stats = state
a_y1, a_diff_args1, a_diff_terms1 = aug1
a_y1 = (ω(a_y1) + ω(grad_ys)[0]).ω
@@ -261,8 +264,9 @@ def _loop_backsolve_bwd(
val = (t0, ts[0], ω(ys)[0].ω, ω(grad_ys)[0].ω)
state, _ = _scan_fun(state, val, first=True)
- aug1, _, _ = state
+ aug1, _, _, _stats = state
a_y1, a_diff_args1, a_diff_terms1 = aug1
+ jax.experimental.host_callback.id_print(_stats)
return a_y1, a_diff_args1, a_diff_terms1
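The per-step accumulation in the diff can be sketched in isolation, assuming stats are flat dicts of numeric counters (as `sol.stats` roughly is). `merge_stats` and the key names are hypothetical; this version also avoids mutating the incoming dict:

```python
# Non-mutating sketch of accumulating backward-solve stats across steps,
# dropping non-additive entries such as "max_steps" instead of pop()-ing them.

def merge_stats(acc, new, exclude=("max_steps",)):
    """Sum the counters in `new` into `acc`, skipping excluded keys."""
    filtered = {k: v for k, v in new.items() if k not in exclude}
    if acc is None:                 # first backward step: nothing accumulated yet
        return filtered
    return {k: acc[k] + filtered[k] for k in filtered}

stats = None
for step_stats in [{"num_steps": 12, "max_steps": 4096},
                   {"num_steps": 9, "max_steps": 4096}]:
    stats = merge_stats(stats, step_stats)
# stats == {"num_steps": 21}
```

Inside the actual scan the summation would go through `jax.tree_map`, as in the diff, since the stats live in a pytree of traced values.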
I'm not sure I've understood your code correctly. Could you help review it?
Fortunately, the hacky solution did not have an obvious negative impact on performance, but the biggest pain is that it's inconsistent with the overall code execution flow (e.g. logging to TensorBoard) and requires some further hacks with `id_tap` and my code outside `jax.jit`.
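The `id_tap` pattern mentioned above can be illustrated without JAX: the traced value passes through unchanged, while a host-side callback records it so code outside `jit` (e.g. a TensorBoard writer) can consume it later. `id_tap_like`, `log_stats`, and `host_side_log` are hypothetical stand-ins for `jax.experimental.host_callback.id_tap` and a user-supplied tap function:

```python
# Hypothetical pure-Python analogue of host_callback.id_tap: identity on the
# value, plus a host-side side effect that escapes the compiled computation.

host_side_log = []

def id_tap_like(tap_func, arg):
    tap_func(arg)        # side effect runs on the host
    return arg           # identity on the (traced) value

def log_stats(stats):
    host_side_log.append(stats)

result = id_tap_like(log_stats, {"num_steps": 21})
```

This is why the logging code ends up living outside `jax.jit`: the tap delivers its payload on the host, asynchronously with respect to the compiled computation.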
I probably wouldn't mutate using `pop`, but other than that it LGTM.

Something else that might make logging easier is using `equinox.experimental.{get,set}_state` instead of `host_callback` directly. See here. This wraps `host_callback` to provide an interface for stateful operations; in this case, saving data and retrieving it at a later time.
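The stateful save/retrieve pattern being described can be illustrated in plain Python. `StateStore` below is a hypothetical stand-in for the get/set-state interface, not the Equinox API itself; the point is only the shape of the pattern, write at one point in the program, read back later:

```python
# Hypothetical keyed store standing in for a get/set-state style interface:
# values are saved during one phase (e.g. the backward solve) and read back
# in another (e.g. when logging).

class StateStore:
    def __init__(self):
        self._state = {}

    def set_state(self, key, value):
        self._state[key] = value      # save data under a key

    def get_state(self, key):
        return self._state[key]       # retrieve it at a later time

store = StateStore()
store.set_state("backward_stats", {"num_steps": 21})  # written during the solve
retrieved = store.get_state("backward_stats")         # read when logging
```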
Thanks for this excellent library!

When using the `BacksolveAdjoint` adjoint method, it's easy to get (and log) the forward pass's stats. I believe the backward pass also involves solving an ODE, but I did not figure out a way to get its stats. Could you suggest how to achieve this? Or is it something limited by the JAX API?