jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Feature request: Support python control flow in `custom_transforms` functions #1275

Closed: lxuechen closed this issue 4 years ago

lxuechen commented 5 years ago

For fitting parameter values for ODEs à la the adjoint sensitivity method, we might want to override the gradient computation for the forward ODE solve. More concretely, we might have an integrator function `odeint` that takes in a vector field `f`, an initial state `y0`, and a sequence of times `ts` at which to evaluate the solution.
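
For reference, these are the standard adjoint sensitivity equations that such a custom backward rule would integrate. The statement assumes, for simplicity, an autonomous vector field $f(y)$ and a scalar loss $L$ that depends only on the final state $y(t_1)$; if $L$ also depends on intermediate times in `ts`, the adjoint $a$ additionally jumps by $\partial L / \partial y(t_i)$ at each of those times.

$$
a(t) = \frac{\partial L}{\partial y(t)}, \qquad
\frac{\mathrm{d}a(t)}{\mathrm{d}t} = -\left(\frac{\partial f}{\partial y}\bigg|_{y(t)}\right)^{\top} a(t), \qquad
a(t_1) = \frac{\partial L}{\partial y(t_1)}, \qquad
\frac{\partial L}{\partial y_0} = a(t_0).
$$

Solving for $a$ backward from $t_1$ to $t_0$ is the step where an adaptive solver's step-size logic would live.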

One specific use case where supporting control flow in `custom_transforms` would be useful is the backward integration, which might involve adaptive solvers and hence non-trivial control flow. Ideally, we would like to write code as follows:

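```python
from jax import custom_transforms, defvjp  # JAX API at the time of this issue

@custom_transforms
def odeint(y0, ts):
  pass  # Some procedure integrating the vector field `f`.

def vjp_y0(g, ans, y0, ts):
  pass  # A while loop and some if statements used to determine integration step size.

# One VJP rule per positional argument; None gives `ts` no rule.
defvjp(odeint, vjp_y0, None)
```
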
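
For comparison, here is a minimal sketch of the same pattern written against the `jax.custom_vjp` API that later replaced `custom_transforms`. The toy problem is an assumption for illustration only: the linear ODE dy/dt = LAM * y on [0, t1], solved with fixed-step forward Euler, where `LAM`, `H`, and the reduced signature `odeint(y0, t1)` are hypothetical. The backward rule integrates the adjoint with a plain Python `while` loop and `if` statement; this works when the function is differentiated outside `jax.jit`, because the loop conditions are then evaluated on concrete values.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy problem: dy/dt = LAM * y, integrated from 0 to t1.
LAM = -0.5   # assumed decay rate
H = 0.25     # assumed maximum step size

@jax.custom_vjp
def odeint(y0, t1):
    # Forward solve with fixed-step forward Euler in a plain Python loop.
    y, t = jnp.asarray(y0), 0.0
    while t < t1:
        h = min(H, t1 - t)  # shorten the final step if it would overshoot t1
        y = y + h * LAM * y
        t += h
    return y

def odeint_fwd(y0, t1):
    # Residuals: only t1, since this linear ODE's adjoint ignores the trajectory.
    return odeint(y0, t1), jnp.asarray(t1)

def odeint_bwd(t1, g):
    # Adjoint solve: integrate da/dt = -LAM * a backward from t1 to 0,
    # choosing each step size with ordinary Python control flow.
    a, t = g, t1
    while t > 0:
        if t < H:
            h = t   # final, shorter step
        else:
            h = H
        a = a + h * LAM * a
        t = t - h
    # One cotangent per primal input; None is treated as a zero cotangent for t1.
    return a, None

odeint.defvjp(odeint_fwd, odeint_bwd)

print(odeint(2.0, 1.0))
print(jax.grad(odeint)(2.0, 1.0))
```

For this linear toy problem, stepping the adjoint backward on the same step schedule reproduces the gradient of the unrolled Euler loop, so `jax.grad(odeint)(2.0, 1.0)` should equal `(1 + H * LAM) ** 4`; a genuinely adaptive adjoint solver would generally only approximate the discrete gradient.
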
jacobjinkelly commented 5 years ago

Does #1255 fix this?

lxuechen commented 5 years ago

Correct me if I'm wrong, but the custom VJP functions here still don't use Python control flow. I understand that there are workarounds using `lax`-based control flow, but it would be nice if we could just write Python loops and `if` statements.
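
As a point of comparison, the `lax`-based workaround can be sketched roughly as follows, assuming current JAX and a hypothetical toy setup (the linear ODE dy/dt = LAM * y on [0, T1] with Euler step size H; all three names are assumptions for illustration). The step-size logic moves into `jax.lax.while_loop`, with `jnp.where` in place of the Python `if`. It is more verbose than Python control flow, but the rule can then be traced under `jax.jit`.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical toy setup: dy/dt = LAM * y on [0, T1], Euler step size H.
LAM = -0.5
H = 0.25
T1 = 1.0

@jax.custom_vjp
def odeint(y0):
    # Forward Euler with a fixed number of steps via lax.fori_loop.
    n_steps = int(round(T1 / H))
    return lax.fori_loop(0, n_steps, lambda i, y: y + H * LAM * y, y0)

def odeint_fwd(y0):
    # No residuals needed: the adjoint of this linear ODE ignores the trajectory.
    return odeint(y0), None

def odeint_bwd(_, g):
    # Integrate da/dt = -LAM * a backward from T1 to 0 with lax.while_loop.
    def cond(state):
        _, t = state
        return t > 0.0

    def body(state):
        a, t = state
        h = jnp.where(t < H, t, H)  # stands in for the Python `if` on the step size
        return a + h * LAM * a, t - h

    a, _ = lax.while_loop(cond, body, (g, jnp.full_like(g, T1)))
    return (a,)

odeint.defvjp(odeint_fwd, odeint_bwd)

# Unlike a rule with Python branches on traced values, this one works under jit.
print(jax.jit(jax.grad(odeint))(jnp.float32(2.0)))
```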

This is perhaps not necessary for the ODE adjoint, where the example code doesn't seem that complicated. But it could add overhead when the research we're trying to implement has more complicated control flow.

lxuechen commented 5 years ago

Unsurprisingly, it seems that unless you use `tf.function` or graph-mode execution, writing a few Python `if` statements and loops in the gradient function passed to `tf.custom_gradient` doesn't cause any trouble in TF 2.0.

Being able to do this in JAX would be quite helpful for researchers, whose goal in many cases is to test out ideas quickly.

shoyer commented 5 years ago

One option for using arbitrary control flow is to write new JAX primitives: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html

mattjj commented 5 years ago

We’re working on this. The reason it isn’t as simple as it was for Autograd is that JAX uses a new autodiff design in which we only have forward mode and derive reverse mode automatically (by composing forward mode with other transformations). That confers several advantages, but a disadvantage is that, since the system itself doesn’t work in terms of VJPs, supporting custom VJPs is tricky. (You can already write custom JVPs, i.e. forward-mode rules, with arbitrary Python control flow.)
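
In current JAX, forward-mode rules like these are written with `jax.custom_jvp`. Below is a minimal sketch with a hypothetical toy function, `abs_with_rule`, whose JVP rule branches with a Python `if` on the primal value. `jax.grad` then derives the reverse-mode derivative from that JVP rule alone, which illustrates the forward-mode-first design described above. Because the `if` needs a concrete value, this sketch assumes the function is differentiated outside `jax.jit`.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def abs_with_rule(x):
    # Hypothetical toy function: absolute value with a hand-written JVP rule.
    return jnp.abs(x)

@abs_with_rule.defjvp
def abs_with_rule_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # Plain Python control flow on the primal input.
    if x >= 0:
        slope = 1.0
    else:
        slope = -1.0
    # The tangent output is linear in x_dot, so reverse mode can transpose it.
    return abs_with_rule(x), slope * x_dot

print(jax.jvp(abs_with_rule, (2.0,), (1.0,)))  # forward mode
print(jax.grad(abs_with_rule)(-3.0))           # reverse mode, derived from the JVP rule
```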

mattjj commented 4 years ago

#2026 finally landed and added support for Python control flow in custom derivative rules! (The API also changed, so take a look at the tutorial notebook.)
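
As a rough illustration of the changed API, assuming current JAX: a function decorated with `jax.custom_vjp` gets its rule through `defvjp(fwd, bwd)`, where `fwd` returns the output plus residuals and `bwd` maps the residuals and the output cotangent to a tuple with one cotangent per input. The toy function `damped_sin` and its halved gradient for non-positive inputs are hypothetical, chosen only to show a Python `if` inside `bwd`; as with custom JVPs, branching on values assumes the function is not under `jax.jit`.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def damped_sin(x):
    return jnp.sin(x)

def damped_sin_fwd(x):
    # Primal output plus the residual saved for the backward pass.
    return damped_sin(x), x

def damped_sin_bwd(x, g):
    # Python control flow on the residual; one cotangent per primal input.
    if x > 0:
        return (g * jnp.cos(x),)
    else:
        return (0.5 * g * jnp.cos(x),)  # hypothetical damping for x <= 0

damped_sin.defvjp(damped_sin_fwd, damped_sin_bwd)

print(jax.grad(damped_sin)(1.0))
print(jax.grad(damped_sin)(-1.0))
```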

lxuechen commented 4 years ago

Thanks for consistently pushing this forward and the amazing work!