gehring / fax

MIT License

Fixed-point vs zero form #19

Open pierrelux opened 4 years ago

pierrelux commented 4 years ago

I'm not sure what the best design decision is, but it may be confusing for users to have to express a "zero problem" as a "fixed-point" one. Namely, if you want to specify $F(x, \theta) = 0$, you would have to define a dummy operator $T(x, \theta) = x + F(x, \theta)$, whose fixed points $x = T(x, \theta)$ are exactly the roots of $F$, to make it compatible with the two-phase interface. We already have a lot of flags, so I'm not sure that adding one more is the right solution. If there is no programmatic solution, we should at least highlight this in the docs or with examples.

gehring commented 4 years ago

I'll have to think a bit more carefully about this, but is using the two-phase method on non-fixed-point problems advisable? At this point, it might be better to use jax.lax.custom_root with a possibly more efficient gradient solver. WDYT?
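
For comparison, here is a minimal sketch of the jax.lax.custom_root route on the square-root example, assuming a recent JAX release. The Newton forward solver and the scalar tangent solver are illustrative choices, not fax code. custom_root computes the gradient through the implicit function theorem (via tangent_solve) rather than by unrolling the Newton loop.

```python
import jax
import jax.numpy as jnp


def newton_solve(f, x0, num_iters=20):
    # Illustrative forward solver: plain Newton iterations on a scalar f.
    # custom_root does not backpropagate through these iterations.
    x = x0
    for _ in range(num_iters):
        x = x - f(x) / jax.grad(f)(x)
    return x


def scalar_tangent_solve(g, y):
    # g is the linear map dx -> (df/dx) * dx; for a scalar unknown,
    # g(1.0) equals df/dx, so the solution of g(dx) = y is y / g(1.0).
    return y / g(1.0)


def sqrt_root(a):
    # Zero form: find x such that a - x**2 = 0, starting from x = 1.
    f = lambda x: a - x**2
    return jax.lax.custom_root(f, jnp.array(1.0), newton_solve, scalar_tangent_solve)


print(sqrt_root(4.0))            # approximately 2.0
print(jax.grad(sqrt_root)(4.0))  # approximately 0.25, i.e. 1 / (2 * sqrt(4))
```

Note that custom_root traces the forward solver with JAX, which matters for the external, non-jittable solver case raised later in this thread.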

pierrelux commented 4 years ago

I tend to put the fixed-point (FP) and root-finding perspectives in the same basket. I think the only difference is that our interface defaults to successive approximation if no forward solver is passed.
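
One way to see why the two perspectives coincide, at least for the implicit gradients: with the dummy operator $T(x, \theta) = x + F(x, \theta)$, differentiating the fixed-point condition gives the same Jacobian as applying the implicit function theorem to $F$ directly (all derivatives evaluated at $(x^\star, \theta)$).

$$
\begin{aligned}
x^\star = T(x^\star, \theta) \;&\iff\; F(x^\star, \theta) = 0, \\
\frac{\partial x^\star}{\partial \theta}
  &= \left(I - \partial_x T\right)^{-1} \partial_\theta T
   = \left(I - (I + \partial_x F)\right)^{-1} \partial_\theta F
   = -\left(\partial_x F\right)^{-1} \partial_\theta F .
\end{aligned}
$$

What does differ is the forward pass: successive approximation $x_{k+1} = T(x_k, \theta)$ is only guaranteed to converge locally when the spectral radius of $\partial_x T = I + \partial_x F$ at $x^\star$ is below 1.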

pierrelux commented 4 years ago

We can perhaps provide a decorator. Instead of having to write:

```python
def sqrt_zero(a):
  def _sqrt_zero(x):
    return x + (a - x**2)
  return _sqrt_zero
```

We could provide a decorator, zero_problem, that you would use as follows:

```python
def sqrt_zero(a):
  @zero_problem
  def _sqrt_zero(x):
    return a - x**2
  return _sqrt_zero
```
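
A minimal sketch of what the proposed zero_problem decorator could look like. It is hypothetical (not an existing fax function) and assumes the decorated operator takes a single argument x.

```python
import functools


def zero_problem(fn):
    """Hypothetical decorator: turn a zero-form operator F(x) into T(x) = x + F(x).

    Any fixed point x* = T(x*) is then a root F(x*) = 0.
    """
    @functools.wraps(fn)
    def _fixed_point_operator(x):
        return x + fn(x)
    return _fixed_point_operator


def sqrt_zero(a):
    @zero_problem
    def _sqrt_zero(x):
        return a - x**2
    return _sqrt_zero


op = sqrt_zero(4.0)
print(op(2.0))  # 2.0: x = 2 is a fixed point of T, hence a root of a - x**2
```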

pierrelux commented 4 years ago

Here's a more complete example:

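```python
import jax.numpy as jnp
from fax import implicit

def root_finding(param_func):
  # Wrap a parametrized zero-form operator F so that the two-phase solver sees
  # the fixed-point operator T(x) = x + F(x); fixed points of T are roots of F.
  def _fp_param_func(params):
    fn_op = param_func(params)
    def _fp_operator(x):
      return x + fn_op(x)
    return _fp_operator
  return _fp_param_func

@root_finding
def square_root(a):
  def _square_root(x):
    return a - x**2  # zero form: a - x**2 = 0
  return _square_root

print(implicit.two_phase_solve(
    square_root,
    init_xs=jnp.array(1.),
    params=jnp.array(4.),
    # Supplying a forward solver replaces the default successive approximation;
    # this one ignores f and x0 and returns the closed-form answer.
    solvers=(lambda f, x0, a: jnp.sqrt(a),),
))
# Returns 2.0
```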

gehring commented 4 years ago

The more I think about it, the more I feel that transforming a root-finding problem into a fixed-point problem is impractical, if only because our two-phase fixed-point implementation assumes we are dealing with an attractive fixed point. If we provide wrappers to solve root problems, it won't be obvious to the user when to expect the default solvers to work and when not to, without a good understanding of what is happening behind the scenes.
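
To illustrate the concern about attractive fixed points, here is a plain-Python sketch (no fax or JAX involved) that runs successive approximation on the square-root operator from the example above, plus a hypothetical rescaled variant with step size 0.1. Both have the same fixed point x* = 2, but only the second is attractive, since successive approximation converges locally when |T'(x*)| < 1.

```python
def successive_approximation(T, x0, num_iters):
    # Iterate x <- T(x) and record every iterate.
    xs = [x0]
    for _ in range(num_iters):
        xs.append(T(xs[-1]))
    return xs


a = 4.0
# Fixed-point form of a - x**2 = 0 from the example: T(x) = x + (a - x**2).
# At x* = 2, T'(x*) = 1 - 2 * x* = -3, so x* is a repelling fixed point.
T = lambda x: x + (a - x**2)
# Hypothetical rescaled form T(x) = x + 0.1 * (a - x**2):
# T'(x*) = 1 - 0.4 = 0.6, so x* is attractive. Same roots, different behaviour.
T_scaled = lambda x: x + 0.1 * (a - x**2)

xs = successive_approximation(T, 2.01, 5)
ys = successive_approximation(T_scaled, 2.01, 50)
print([round(x, 4) for x in xs])  # moves away from 2.0 despite starting at 2.01
print(ys[-1])                     # very close to 2.0
```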

We could either 1) implement a general function for handling implicit differentiation of root-finding problems, with no default solvers (like jax.lax.custom_root), and use the existing default solvers to re-implement two_phase_solve on top of that API, or 2) maintain both a fixed-point-specific method and a separate custom-root method, which would duplicate some code but would preserve a one-to-one correspondence between two_phase_solve and Christianson's method.
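
A rough sketch of what option 1) could look like, assuming jax.custom_vjp for the backward pass. root_solve, fixed_point_solve, successive_approximation and newton are hypothetical names, not fax APIs. The backward pass forms a dense Jacobian for brevity, and the sketch makes no attempt at higher-order derivatives or non-jittable solvers.

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical root-finding API with a *required* forward solver (sketch only).
@functools.partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def root_solve(f, solver, x0, params):
    # solver(g, x0) must return x with g(x) = 0, where g(x) = f(x, params).
    return solver(lambda x: f(x, params), x0)


def _root_solve_fwd(f, solver, x0, params):
    x_star = root_solve(f, solver, x0, params)
    return x_star, (x_star, params)


def _root_solve_bwd(f, solver, residuals, x_bar):
    x_star, params = residuals
    # Implicit function theorem: dx*/dp = -(df/dx)^{-1} df/dp at (x*, p).
    # Dense Jacobian for brevity; a matrix-free solver would avoid forming it.
    jac_x = jax.jacobian(lambda x: f(x, params))(x_star)
    w = jnp.linalg.solve(jac_x.T, x_bar)
    _, vjp_params = jax.vjp(lambda p: f(x_star, p), params)
    (params_bar,) = vjp_params(-w)
    # The root does not depend on the initial guess: zero cotangent for x0.
    return jnp.zeros_like(x_star), params_bar


root_solve.defvjp(_root_solve_fwd, _root_solve_bwd)


def successive_approximation(g, x0, num_iters=200):
    # Default kept for the fixed-point case only: x <- x + g(x) = T(x).
    x = x0
    for _ in range(num_iters):
        x = x + g(x)
    return x


def fixed_point_solve(T, x0, params, solver=successive_approximation):
    # two_phase_solve-like wrapper: fixed points of T are roots of T(x, p) - x.
    return root_solve(lambda x, p: T(x, p) - x, solver, x0, params)


def newton(g, x0, num_iters=20):
    # Forward solver the user passes explicitly in the root-finding case.
    x = x0
    for _ in range(num_iters):
        x = x - jnp.linalg.solve(jax.jacobian(g)(x), g(x))
    return x


a = jnp.array([4.0, 9.0])
x0 = jnp.ones(2)
sqrt_zero = lambda x, p: p - x**2
sqrt_fp = lambda x, p: x + 0.1 * (p - x**2)  # attractive at sqrt(p) for these p

print(root_solve(sqrt_zero, newton, x0, a))  # approximately [2. 3.]
print(jax.grad(lambda p: root_solve(sqrt_zero, newton, x0, p).sum())(a))
print(jax.grad(lambda p: fixed_point_solve(sqrt_fp, x0, p).sum())(a))
```

Both gradients should be close to 1 / (2 * sqrt(a)), i.e. about [0.25, 0.1667]: the fixed-point wrapper reuses the same implicit backward pass, and only the forward solver changes.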

I think I am leaning towards 1), as long as we can keep the two_phase_solve API unchanged (or nearly unchanged) and keep supporting higher-order derivatives of fixed points. WDYT?

Also, while we're on the topic, jax.lax.custom_root expects JAX-transformable solvers (as far as I can tell), which makes it ill-suited for non-jittable cases where the solver might be some external program. Until JAX natively supports XLA's custom_call op, we'll want to use our own implementation.

pierrelux commented 4 years ago

You make a good point regarding the use of successive approximation as a default "forward" solver, and how this is ill-suited and confusing for the user under the root-finding perspective. I made that mistake myself! If I understand correctly, your proposal is to ask users to specify their own forward solver as a non-optional argument to two_phase_solve. I think that makes sense. The drawback, of course, is that it puts more burden on the user when invoking two_phase_solve, but it also makes things safer. If necessary, we can always provide a small abstraction or wrapper around two_phase_solve, specialized for each case.

I don't think the code-duplication route is the way to go here (and performance-wise, I'm confident the XLA magic will close any potential gap between the fixed-point backend and a pure root-finding one).

pierrelux commented 4 years ago

It would be good to refactor out default_solver (https://github.com/gehring/fax/blob/be058d17fb7d650ba1ebc093179a0f735a738054/fax/implicit/twophase.py#L13-L48) in this case. This piece of code could also be used in other contexts where we need to solve linear systems in a matrix-free fashion.

pierrelux commented 4 years ago

*What I meant was more the part about the solver derived from the Neumann-series perspective: https://github.com/gehring/fax/blob/be058d17fb7d650ba1ebc093179a0f735a738054/fax/implicit/twophase.py#L112-L120
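
For readers unfamiliar with the Neumann-series view, here is a minimal matrix-free sketch, assuming the system has the form (I - A) x = b and that only products with A are available. neumann_solve is a hypothetical name and this is not the linked fax code; it only shows the underlying idea.

```python
import jax
import jax.numpy as jnp


def neumann_solve(matvec, b, num_iters=200):
    """Hypothetical matrix-free solver for (I - A) x = b.

    Uses only products A @ v (through `matvec`) and returns the truncated
    Neumann series x = sum_{k=0}^{num_iters} A^k b, which converges when the
    spectral radius of A is below 1.
    """
    # x_{k+1} = b + A x_k with x_0 = b, i.e. the partial sums of the series.
    return jax.lax.fori_loop(0, num_iters, lambda _, x: b + matvec(x), b)


A = jnp.array([[0.5, 0.1], [0.2, 0.3]])  # spectral radius is about 0.57
b = jnp.array([1.0, 2.0])
x = neumann_solve(lambda v: A @ v, b)
print(x)
print(jnp.linalg.solve(jnp.eye(2) - A, b))  # dense reference; should match
```

In the implicit-differentiation setting, matvec would typically be a JVP or VJP of the fixed-point operator with respect to x at the solution (for instance built with jax.vjp), so the Jacobian is never formed explicitly.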