gehring opened this issue 4 years ago
There might be a conceptual issue with the introduction of the "solvers" argument in #26.
It looks like the only role played by a potential custom "backward" solver is to solve the associated adjoint fixed-point equation using a (potentially different) successive approximation method. My previous understanding of this feature was that it would let you specify any solver for computing $\bar{x} (I - D_x f(x^\star, \theta))^{-1}$ for an arbitrary adjoint vector $\bar{x}$, irrespective of the particular underlying algorithm (CG, GMRES, etc.).
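For concreteness, here is a rough sketch of the kind of backward step I have in mind; the names and the choice of GMRES below are purely illustrative, not an actual API:

```python
import jax


def adjoint_solve(f, x_star, theta, vbar):
    """Illustration only: compute vbar (I - D_x f(x_star, theta))^{-1}.

    `vbar` plays the role of the arbitrary adjoint vector; any linear solver
    (CG, GMRES, ...) could be plugged in here.
    """
    # vjp_x(u) computes u D_x f(x_star, theta).
    _, vjp_x = jax.vjp(lambda x: f(x, theta), x_star)

    def matvec(u):
        # Left-multiplication by (I - D_x f(x_star, theta)).
        return u - vjp_x(u)[0]

    # GMRES is just one possible choice here.
    ubar, _ = jax.scipy.sparse.linalg.gmres(matvec, vbar)
    return ubar
```

The $\theta$-gradient would then be obtained by pushing `ubar` through $D_\theta f(x^\star, \theta)$.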
As far as `two_phase_solve` is concerned, you are always solving fixed-point problems, regardless of forward/backward/other. It formulates reverse differentiation as a fixed-point problem and recursively calls itself with one fewer solver. Each solver should return a fixed point of whatever function it is given. You could give it a solver based on CG or GMRES if you wanted. I'm not quite sure I understand the issue you are trying to raise.
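To make that contract concrete, a solver is (roughly) something like the sketch below; the exact signature in the code may differ, this is just an illustration:

```python
import jax.numpy as jnp


def naive_solver(g, x0, max_iter=1000, tol=1e-8):
    """Return (an approximation of) a point x with g(x) == x.

    Plain successive approximation, for illustration; under jit you would use
    lax.while_loop, and you could just as well swap in an accelerated or
    CG/GMRES-based solver.
    """
    x = x0
    for _ in range(max_iter):
        x_next = g(x)
        if jnp.linalg.norm(x_next - x) < tol:
            break
        x = x_next
    return x
```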
From the perspective of Christianson, and the view where `two_phase_solve` is purely a fixed-point (FP) routine, then yes, I think imposing that the backward solver operates on the backward FP problem makes sense. However, we could also avoid requiring the backward problem to be solved as an FP problem. At a high level, you have a linear operator `fp_vjp_fn` and you want to find a `dout` such that `fp_vjp_fn(dout) = dvalue`. You can view that as an FP problem if you want, but it is also just a linear system of equations with the linear operator `fp_vjp_fn` and right-hand side equal to `dvalue`.
I think I understand what you want. I don't know if I want to change the whole "differentiate a fixed point by solving a fixed point" paradigm of Christianson's two-phase method. Although far from elegant, you can still bake your parametric function into your backward solver and ignore the function that is passed in, while still using the other arguments provided (e.g., the solved fixed point, parameters, cotangent). It's awkward, but the `solvers` argument still lets you exploit linearity.
However, I definitely agree that restricting our implicit methods to fixed-point problems is inconvenient, if not outright limiting. To address this, I will be implementing a proper implicit differentiation backend built around root finding. Once that's done, I'll refactor `two_phase_solve` to use the root-finding backend, and we can revisit the `solvers` argument. I'm also contemplating deprecating `two_phase_solve` entirely in favor of more descriptive and explicit names, e.g., `root_solve` and `fixed_point_solve`. That would also address the issue of possibly misleading users by referencing Christianson's method while actually computing derivatives with a root-finding formulation, which I can't guarantee will be identical once XLA does its thing.
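Roughly, the root-finding backward pass I have in mind would follow the implicit function theorem; the sketch below uses placeholder names and signatures, not a committed API:

```python
import jax


def root_vjp(F, x_star, theta, vbar, linear_solver=jax.scipy.sparse.linalg.gmres):
    """Sketch: if F(x_star, theta) == 0 defines x_star(theta), then
    d x_star / d theta = -(D_x F)^{-1} D_theta F, so the cotangent for theta
    is -ubar D_theta F, where ubar solves ubar D_x F = vbar."""
    _, vjp_x = jax.vjp(lambda x: F(x, theta), x_star)
    # Solve ubar D_x F(x_star, theta) = vbar for ubar.
    ubar, _ = linear_solver(lambda u: vjp_x(u)[0], vbar)
    # Push -ubar through D_theta F to get the theta cotangent.
    _, vjp_theta = jax.vjp(lambda t: F(x_star, t), theta)
    return jax.tree_util.tree_map(lambda g: -g, vjp_theta(ubar)[0])
```

A fixed-point problem $x = f(x, \theta)$ fits this formulation by taking $F(x, \theta) = f(x, \theta) - x$.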
With the (somewhat) recent changes to how `jax` handles custom VJPs, it is now possible to define derivatives using the function for which we are defining the derivative. Since the VJP for the fixed point can be computed by solving for another fixed point, a recursive implementation would allow higher-order differentiation with little to no extra complexity.