Closed lxuechen closed 4 years ago
Does #1255 fix this?
Correct me if I'm wrong, but the custom vjp functions here still aren't using Python control-flow. I understand that there are workarounds using the lax-based control-flow, but it would be nice if we could just write Python loops and if-statements.
This is perhaps not necessary for the ODE adjoint, and the example code doesn't seem that complicated. But it might present some overhead when the research we're trying to implement has more complicated control-flow.
Unsurprisingly, it seems that unless you want to use tf.function or graph-mode execution, writing a few Python if-statements and loops in the gradient function used for tf.custom_gradient doesn't cause any trouble for TF2.0.
Being able to do this in JAX would be quite helpful to researchers, whose goal in many cases is to test out ideas quickly.
One option for using arbitrary control flow is to write new JAX primitives: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html
We’re working on this. The reason it isn’t as simple as it was for Autograd is that JAX uses a new autodiff design in which we only have forward mode and derive reverse mode automatically (by composing forward mode with other transformations). That confers several advantages, but a disadvantage is that, since the system itself doesn’t work in terms of VJPs, supporting custom VJPs is tricky. (You can already write custom JVPs, i.e. forward-mode rules, with arbitrary Python control flow.)
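To illustrate the point about custom JVPs, here is a minimal sketch of a forward-mode rule that uses ordinary Python loops and if-statements. It assumes the `jax.custom_jvp` decorator from current JAX releases, which may differ from the API available when this thread was written, and it assumes the function is called outside `jit`, so that the values being branched on are concrete. The function `newton_sqrt`, its tolerance, and its iteration cap are hypothetical choices made for this example.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def newton_sqrt(x):
    # Newton iteration for sqrt(x), assuming x > 0.
    # The number of iterations depends on the value of x (Python control flow).
    y = jnp.maximum(x, 1.0)
    for _ in range(50):  # hypothetical cap so the loop always terminates
        if abs(y * y - x) <= 1e-6 * x:
            break
        y = 0.5 * (y + x / y)
    return y


@newton_sqrt.defjvp
def newton_sqrt_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    y = newton_sqrt(x)
    # Python if-statement on a concrete primal value.
    if y > 0:
        y_dot = x_dot / (2.0 * y)  # d sqrt(x) / dx = 1 / (2 sqrt(x))
    else:
        y_dot = 0.0 * x_dot  # arbitrary convention for this sketch at x == 0
    return y, y_dot


# Forward mode directly, and reverse mode derived from the same JVP rule.
value, tangent = jax.jvp(newton_sqrt, (4.0,), (1.0,))
grad_at_9 = jax.grad(newton_sqrt)(9.0)
```

Because reverse mode is derived from the forward-mode rule, `jax.grad` works here as well. Under `jit` the same Python branches would see traced values and would need lax-based control flow instead.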
Thanks for consistently pushing this forward and the amazing work!
For fitting parameter values for ODEs a la the adjoint sensitivity method, we might want to override the gradient computation for the forward ODE solve. More concretely, we might have an integrator function odeint that takes in a gradient field f, initial state y0, and a sequence of times ts to be evaluated at.
One specific use case where supporting control-flow in custom_transforms will be useful is for the backward integration (which might involve adaptive solvers, hence non-trivial control-flow). Ideally, we would like to write code as follows