gehring / fax

MIT License

Allow the two phase method to support higher order differentiation #25

Open gehring opened 4 years ago

gehring commented 4 years ago

With the (somewhat) recent changes to how jax handles custom VJPs, it is now possible to define derivatives using the function for which we are defining the derivative. Since the VJP for the fixed point can be computed by solving for another fixed point, a recursive implementation would allow higher order differentiation with little to no extra complexity.
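The recursive idea can be sketched as follows. This is a minimal illustration with a naive fixed-iteration solver, not fax's actual implementation; `two_phase_solve`, `fixed_point_iteration`, and the example map are all placeholder names. The key point is that the backward pass solves the adjoint fixed-point equation by calling the same solver again:

```python
import jax
import jax.numpy as jnp
from functools import partial

def fixed_point_iteration(f, params, x0, num_steps=100):
    # Naive successive approximation; assumes f(params, .) is a contraction.
    x = x0
    for _ in range(num_steps):
        x = f(params, x)
    return x

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def two_phase_solve(f, params, x0):
    return fixed_point_iteration(f, params, x0)

def _fwd(f, params, x0):
    x_star = fixed_point_iteration(f, params, x0)
    return x_star, (params, x_star)

def _bwd(f, res, x_bar):
    params, x_star = res
    _, vjp_x = jax.vjp(lambda x: f(params, x), x_star)
    # The adjoint u satisfies u = x_bar + u D_x f(x*, params): itself a
    # fixed-point problem, so we solve it by recursing into two_phase_solve.
    u = two_phase_solve(lambda b, v: b + vjp_x(v)[0], x_bar, x_bar)
    _, vjp_params = jax.vjp(lambda p: f(p, x_star), params)
    # No gradient flows through the initial guess x0 at convergence.
    return vjp_params(u)[0], jnp.zeros_like(x_star)

two_phase_solve.defvjp(_fwd, _bwd)
```

For a contraction like `f(p, x) = 0.5 * jnp.cos(x) + p`, the gradient of the solved fixed point with respect to `p` matches the implicit-differentiation formula `1 / (1 + 0.5 * sin(x*))`.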

pierrelux commented 4 years ago

There might be a conceptual issue with the introduction of the "solvers" argument in #26.

https://github.com/gehring/fax/blob/9701c79b64830dbc3c8b250a77394dfdc8ab6bca/fax/implicit/twophase.py#L134-L152

It looks like the only role played by a potential custom "backward" solver is to solve the associated adjoint fixed-point equation with a (potentially different) successive-approximation method. My previous understanding of this feature was that it would let you specify any solver for $\bar{x} (I - D_x f(x^\star, \theta))^{-1}$ for an arbitrary adjoint vector $\bar{x}$, irrespective of the particular underlying algorithm (CG, GMRES).

gehring commented 4 years ago

As far as two_phase_solve is concerned, you are always solving fixed-point problems, regardless of forwards/backwards/other. It formulates reverse differentiation as a fixed-point problem and recursively calls itself with one fewer solver. Each solver should return a fixed point of whatever function it is given. You could give it a solver based on CG or GMRES if you wanted. I'm not quite sure I understand the issue you are trying to raise.

pierrelux commented 4 years ago

From the perspective of Christianson, and the view where two_phase_solve is a FP thing only, then yes, I think imposing that the backward solver operate on the backward FP problem makes sense. However, we could also avoid imposing that the backward problem be solved as a FP problem. At a high level, you have a linear operator fp_vjp_fn and you want to find a dout such that fp_vjp_fn(dout) = dvalue. You can view that as a FP problem if you want, but it is also just a linear system of equations with the linear operator fp_vjp_fn and right-hand side equal to `dvalue`.
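This linear-system view can be sketched directly with a matrix-free solver. The function below is illustrative (the name `solve_adjoint_linear` and its signature are made up for this sketch); it builds the operator `fp_vjp_fn` from a VJP of the fixed-point map and hands it to GMRES, with no fixed-point iteration involved:

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

def solve_adjoint_linear(f, params, x_star, dvalue):
    # Treat the backward problem as a linear system rather than a FP problem:
    # find dout such that fp_vjp_fn(dout) = dvalue, where
    # fp_vjp_fn(u) = u - u D_x f(x*, params), i.e. (I - D_x f)^T applied to u.
    _, vjp_x = jax.vjp(lambda x: f(params, x), x_star)
    fp_vjp_fn = lambda u: u - vjp_x(u)[0]
    dout, _ = gmres(fp_vjp_fn, dvalue)
    return dout
```

Since GMRES only needs matrix-vector products, the operator never has to be materialized; CG would work the same way when the operator is symmetric positive definite.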

gehring commented 4 years ago

I think I understand what you want. I don't know if I want to change the whole "differentiate a fixed point by solving a fixed point" paradigm of Christianson's two-phase method. Although far from elegant, you can still bake your parametric function into your backward solver and ignore the function that is passed, while still using the other arguments provided (e.g., solved fixed point, parameters, cotangent). It's awkward, but the solvers argument still lets you exploit linearity.

However, I definitely agree that limiting our implicit methods to fixed-point problems is inconvenient if not outright limiting. To address this, I will be implementing a proper implicit differentiation backend which will be built around root finding. Once that's done, I'll refactor two_phase_solve to use the root finding backend and we can revisit the solvers argument. I'm also contemplating just deprecating two_phase_solve entirely in favor of more descriptive and explicit names, e.g., root_solve, fixed_point_solve. That would also address the issue of possibly misleading users by referencing Christianson's method but actually solving derivatives using a root finding formulation which I can't guarantee will be identical once XLA does its thing.
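For reference, the root-finding formulation of implicit differentiation amounts to the implicit function theorem: at a root $g(x^\star, \theta) = 0$, we have $dx^\star/d\theta = -(D_x g)^{-1} D_\theta g$. The sketch below is hypothetical (the planned root-finding backend doesn't exist yet, and `implicit_root_jacobian` is not a fax name); it uses a dense solve purely for illustration:

```python
import jax
import jax.numpy as jnp

def implicit_root_jacobian(g, x_star, theta):
    # Implicit function theorem at a root g(x*, theta) = 0:
    #   dx*/dtheta = -(D_x g)^{-1} D_theta g, evaluated at (x*, theta).
    dg_dx = jax.jacobian(g, argnums=0)(x_star, theta)
    dg_dtheta = jax.jacobian(g, argnums=1)(x_star, theta)
    return -jnp.linalg.solve(dg_dx, dg_dtheta)
```

A fixed point of `f` is recovered as a root of `g(x, theta) = f(theta, x) - x`, which is what would let a `fixed_point_solve` be a thin wrapper over a `root_solve`.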