gehring / fax

MIT License

Missing doc #13

Closed: pierrelux closed 4 years ago

pierrelux commented 4 years ago

https://github.com/gehring/fax/blob/5b54a85d7d473aedb5fb161a6e47ad4073b9647e/fax/implicit/twophase.py#L86

pierrelux commented 4 years ago

The docs should explain the purpose of ``forward_solver''. It is a useful feature that otherwise goes unnoticed, and it defaults to successive approximation, which may not apply in all cases.

pierrelux commented 4 years ago

Can perhaps be fixed as part of #26

gehring commented 4 years ago

Can you check whether the updated docs in #26 seem good enough? I've also added an example in the readme.

pierrelux commented 4 years ago

I'm not sure that I understand what is described in:

> The first solver is used to solve for the parametric fixed point and every subsequent solver is used recursively to compute arbitrary order derivatives. Each solver should be a callable taking in a parametric function (i.e., a callable which returns a callable) for which we seek a fixed-point, the initial guess, and the parameters to use. Except for the first solver, the function given as first argument to solvers will not be param_func but a VJP function derived from param_func.

https://github.com/gehring/fax/blob/7aa870420a3ee00157101937d04bbc6cd8ef4dc6/fax/implicit/twophase.py#L62-L73

Does this mean that we now have something like ``(forward_solver, backward_solver)''? Perhaps you could add a more explicit function signature to the docs, because I'm not sure I could implement what's described without studying the code.

gehring commented 4 years ago

> Does this mean that we now have something like ``(forward_solver, backward_solver)''?

Yes. (forward_solver,), (forward_solver, backward_solver), (forward_solver, backward_solver, second_order_back_solver), and (forward_solver, backward_solver, second_order_back_solver, ...) are all accepted. Come to think of it, we might also want to allow None, in case we want to implement only a custom forward solver and a custom 2nd-order backward solver. Also, note that I've updated the readme to include an example of using a custom solver.
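For readers following along, here is a minimal sketch of a solver with the call shape described in the quoted docs: a callable taking a parametric function, an initial guess, and the parameters. It is plain Python and does not import fax. The names make_damped_solver and param_cos are hypothetical, and returning the bare fixed-point value is an assumption that this thread does not confirm.

```python
import math


def make_damped_solver(damping=0.5, max_iter=1000, tol=1e-10):
    """Build a solver callable of the form solve(param_func, init_x, params).

    Hypothetical helper: damped fixed-point iteration on scalars.
    """
    def solve(param_func, init_x, params):
        # param_func is "parametric": calling it with params returns the
        # function whose fixed point we are looking for.
        func = param_func(params)
        x = init_x
        for _ in range(max_iter):
            # Blend the old iterate with the new one to damp oscillations.
            x_new = (1.0 - damping) * x + damping * func(x)
            if abs(x_new - x) < tol:
                return x_new
            x = x_new
        return x

    return solve


def param_cos(scale):
    """Hypothetical parametric function: returns x -> scale * cos(x)."""
    return lambda x: scale * math.cos(x)


# A one-element tuple, i.e. the (forward_solver,) case.
solvers = (make_damped_solver(),)
fixed_point = solvers[0](param_cos, 0.0, 1.0)
```

Damping is only one possible choice here; the point is the three-argument call shape, which every entry of the tuple would share under this reading of the docs.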

Re-reading what I wrote, I think that using the term "recursive" here may be more confusing than informative. One major change in this implementation of the two-phase method is that the backward pass is solved with a recursive call to two_phase_solve. This means we support higher-order derivatives out of the box, but we need a more flexible API for passing in custom solvers.

This part picks the solver to use:

    # if solvers is empty, use the default fixed-point iteration solver
    if solvers:
        fwd_solver = solvers[0]
    else:
        fwd_solver = default_solver()

And this portion of the backwards pass "pops" the solver that was just used and passes through the rest to the recursive call:

    dsol = two_phase_solve(param_dfp_fn,
                           sol_bar,
                           (sol, params, sol_bar),
                           solvers[1:])
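The combined effect of the two snippets above can be illustrated with a small stand-alone sketch. It does not call fax; pick_solver and solver_per_order are hypothetical helpers, and strings stand in for solver callables. It only mirrors the "use solvers[0] if present, then pass solvers[1:] down" pattern to show which solver would end up handling each derivative order.

```python
def pick_solver(solvers, default):
    """Use the first solver if one was given, else fall back to the default."""
    return solvers[0] if solvers else default


def solver_per_order(solvers, default, max_order):
    """List which solver handles each level of the recursion.

    Index 0 is the forward solve, index 1 the first backward pass, and so on.
    """
    used = []
    remaining = tuple(solvers)
    for _ in range(max_order + 1):
        used.append(pick_solver(remaining, default))
        # "Pop" the solver that was just used, as solvers[1:] does above.
        remaining = remaining[1:]
    return used


# Two custom solvers given, derivatives up to third order requested.
assignment = solver_per_order(("fwd", "bwd"), "default", 3)
# assignment == ["fwd", "bwd", "default", "default"]
```

Once the tuple is exhausted, every deeper level falls back to the default solver, which is why a one-element tuple such as (forward_solver,) is enough when only the forward solve needs customizing.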

We might want to move this discussion to the PR so that we can comment inline and more conveniently track changes.

pierrelux commented 4 years ago

Thanks for the update in readme.md! Is it in #26? If so, I think that we can close this issue.

gehring commented 4 years ago

Yes, you can find the current version of the new readme here. I've also updated the docs for solver after reading your previous reply.