Closed pierrelux closed 4 years ago
Should explain the purpose of ``forward_solver''. It is a useful feature that otherwise goes unnoticed, and it defaults to successive approximation, which may not apply in all cases.
Can perhaps be fixed as part of #26
Can you check if the new updated docs in #26 seem good enough? I've also added an example in the readme.
I'm not sure that I understand what is described in:
The first solver is used to solve for the parametric fixed point and every subsequent solver is used recursively to compute arbitrary order derivatives. Each solver should be a callable taking in a parametric function (i.e., a callable which returns a callable) for which we seek a fixed-point, the initial guess, and the parameters to use. Except for the first solvers, the function given as first argument to solvers will not be param_func but a VJP function derived from param_func.
Does this mean that we now have something like ``(forward_solver, backward_solver)'' ? You can perhaps add a more explicit function signature as part of the doc because I'm not sure that I would be able to implement what's described without studying the code.
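For reference, here is a minimal sketch of what a solver matching that description could look like. It is written in plain Python on scalars rather than against the library itself: the names `fixed_point_iteration_solver`, `max_iter`, and `tol` are hypothetical, and returning the solution directly is an assumption rather than something the quoted docs state.

```python
def fixed_point_iteration_solver(max_iter=1000, tol=1e-10):
    """Hypothetical factory returning a solver with the described call shape."""

    def solve(param_func, init_x, params):
        # param_func is "a callable which returns a callable":
        # binding the parameters yields the map x -> f(x).
        func = param_func(params)
        x = init_x
        for _ in range(max_iter):
            x_new = func(x)
            if abs(x_new - x) < tol:
                return x_new
            x = x_new
        return x

    return solve


def param_func(params):
    # Toy contraction f(x) = a * x + b, whose fixed point is b / (1 - a).
    a, b = params
    return lambda x: a * x + b


solver = fixed_point_iteration_solver()
x_star = solver(param_func, 0.0, (0.5, 1.0))
```

Any callable with the same three arguments (the parametric function, the initial guess, the parameters) could stand in for the iteration above, for example a Newton or Anderson-style solver.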
Does this mean that we now have something like ``(forward_solver, backward_solver)'' ?
Yes. (forward_solver,), (forward_solver, backward_solver, second_order_back_solver), and (forward_solver, backward_solver, second_order_back_solver, ...) are all accepted as well. Come to think of it, we might also want to allow None as an entry, in case we want to implement a custom forward solver and a custom second-order backward solver while leaving the solver in between at its default. Also, note that I've updated the readme to include an example of using a custom solver.
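To make the None idea concrete, here is a small sketch of how such a placeholder could be interpreted. This is only an illustration of the proposal, not existing behavior: `pick_solver` and the dummy solvers are hypothetical names, and the solvers are stand-ins that return a label instead of solving anything.

```python
def pick_solver(solvers, make_default):
    """Return the solver for the current level.

    A missing entry or a None entry both fall back to the default.
    """
    if solvers and solvers[0] is not None:
        return solvers[0]
    return make_default()


def make_default():
    # Stand-in for the default fixed-point iteration solver.
    return lambda param_func, init_x, params: "default"


def custom_forward(param_func, init_x, params):
    return "custom forward"


def custom_second_order(param_func, init_x, params):
    return "custom second-order backward"


# Custom forward solver, default backward solver, custom second-order solver.
solvers = (custom_forward, None, custom_second_order)

chosen = []
remaining = solvers
for _ in range(4):  # one more level than there are entries
    chosen.append(pick_solver(remaining, make_default)(None, None, None))
    remaining = remaining[1:]
```

With this convention, a tuple that is shorter than the derivative order and a tuple with None in the middle behave the same way at the levels they leave unspecified.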
Re-reading what I wrote, I think that using the term "recursive" here may be more confusing than informative. One major change in this implementation of the two-phase method is that the backward pass is solved with a recursive call to two_phase_solve. This means we support higher-order derivatives out of the box, but it requires a more flexible API for passing in custom solvers.
This part picks the solver to use:
    # if solvers is empty, use the default fixed-point iteration solver
    if solvers:
        fwd_solver = solvers[0]
    else:
        fwd_solver = default_solver()
And this portion of the backwards pass "pops" the solver that was just used and passes through the rest to the recursive call:
    dsol = two_phase_solve(param_dfp_fn,
                           sol_bar,
                           (sol, params, sol_bar),
                           solvers[1:])
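As a rough illustration of how the two snippets fit together, the toy below reproduces only the dispatch pattern: use the head of `solvers` at the current level, then recurse with `solvers[1:]`. It does not compute any derivatives. `toy_recursive_solve`, `make_logging_solver`, and the `level` argument are hypothetical and exist only to record which solver would run at each differentiation order.

```python
def make_logging_solver(name, log):
    """Stand-in solver that only records that it was called."""
    def solve(level):
        log.append((level, name))
    return solve


def toy_recursive_solve(level, n_levels, solvers, log):
    # If solvers is empty, use the default solver (as in the snippet above).
    if solvers:
        solver = solvers[0]
    else:
        solver = make_logging_solver("default", log)
    solver(level)

    # The "backward pass" of this level is itself a solve, so it recurses
    # with the solver that was just used popped off the front.
    if level + 1 < n_levels:
        toy_recursive_solve(level + 1, n_levels, solvers[1:], log)


log = []
solvers = (make_logging_solver("forward_solver", log),
           make_logging_solver("backward_solver", log))
toy_recursive_solve(0, 3, solvers, log)
```

Here the third level has no entry left in the tuple, so it falls back to the default, which is the behavior the slicing gives for any derivative order beyond the number of solvers passed in.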
We might want to move this discussion to the PR so that we can comment inline and more conveniently track changes.
Thanks for the update in readme.md. Is it in #26? If so, I think that we can close this issue.
https://github.com/gehring/fax/blob/5b54a85d7d473aedb5fb161a6e47ad4073b9647e/fax/implicit/twophase.py#L86