quattro opened this issue 7 months ago

Hi all, thanks for the phenomenal library. We're already using it in several statistical genetics methods in my group!

I've been porting over some older code of mine to use Optimistix, rather than hand-rolled inference procedures, and could use some advice. Currently, I am performing variational inference using a mix of closed-form updates for the variational parameters and gradient-based updates for some hyperparameters. It *roughly* works like the sketch below.
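(A toy stand-in for that sketch, since the exact code isn't important; the model, a single Gaussian observation with an `N(0, 1)` prior, and every name below are invented for illustration.)

```python
import jax
import jax.numpy as jnp

# Toy model: y ~ N(z, sigma^2) with prior z ~ N(0, 1) and variational
# posterior q(z) = N(mu, s2). (mu, s2) get closed-form updates; the
# hyperparameter log_sigma gets a gradient step.

def neg_elbo(log_sigma, mu, s2, y):
    sigma2 = jnp.exp(2.0 * log_sigma)
    exp_loglik = -0.5 * (jnp.log(2.0 * jnp.pi * sigma2) + ((y - mu) ** 2 + s2) / sigma2)
    kl = 0.5 * (mu**2 + s2 - jnp.log(s2) - 1.0)
    return kl - exp_loglik

def step(log_sigma, mu, s2, y, lr=0.1):
    # Closed-form coordinate update of q(z) given the current hyperparameter.
    sigma2 = jnp.exp(2.0 * log_sigma)
    s2 = sigma2 / (sigma2 + 1.0)  # posterior variance
    mu = s2 * y / sigma2          # posterior mean
    # Gradient-based update of the hyperparameter; the updated variational
    # parameters are reported alongside the loss.
    loss, grad = jax.value_and_grad(neg_elbo)(log_sigma, mu, s2, y)
    return log_sigma - lr * grad, mu, s2, loss
```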
I'd *like* to retool the above to not only report the current value, aux values (i.e. the updated variational parameters), and the gradient wrt the hyperparameters, but also return an `hvp` function that could be used in a Newton-CG-like step in Optimistix. I know of the new `minimize` function, but what isn't clear is how to set things up to not only report gradients, but also return an `hvp` function internally, without having to take two additional passes over the graph (i.e. one pass for value and grad, another two for the HVP => forward + backward). Is this doable? Apologies if this is somewhat nebulous; I'm happy to clarify.
That's great to hear, thank you!
On HVPs: if I understand you correctly, this is a general JAX question, rather than specifically a question of how to integrate a solve into Optimistix? You're looking to get both the gradient and an HVP without having to treat them both separately (which would be 3 sweeps in total). This approach will get you 2 sweeps, which I think is optimal:

```python
def to_jvp(x):
    return jax.value_and_grad(fn)(x)

# Primal output is (f, grad f); the tangent's second component is H @ v.
(f, dfdxi), (_, dfdxidxjvj) = jax.jvp(to_jvp, (x,), (v,))
```
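If you want a reusable HVP *function* rather than a single product, a sketch of the same trick via `jax.linearize` (with `fn`, `x`, and `to_jvp` as above):

```python
import jax

# jax.linearize evaluates to_jvp once and returns `lin`, which applies the
# JVP at arbitrary tangents without re-running the primal computation.
(f, dfdxi), lin = jax.linearize(to_jvp, x)

def hvp(v):
    # lin(v) = (grad . v, H @ v); keep only the Hessian-vector product.
    return lin(v)[1]
```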
Yes exactly! It isn't clear to me how to work this into a Newton-like CG solver, but I'll keep toying around. Thanks as always for your enthusiasm and detailed help. It is greatly appreciated!
Newton CG is one of the algorithms which I think could be somewhat involved to implement. It is further from the existing solvers in Optimistix, so there's more custom work that needs to be done to get it running.

There are two steps I would take if I were implementing it (which I may in the future):

1. Write a minimiser that computes the gradient, and which passes the Hessian lazily, as an `lx.FunctionLinearOperator` computing Hessian-vector products, in `FunctionInfo.EvalGradHessian`. Take a look at the abstract solvers (`AbstractGradientDescent`, `AbstractGaussNewton`, `AbstractBFGS`, etc.) for a bit on how to do this, and check out the docs for `FunctionInfo`. Implementing a top-level solver from scratch for the first time can be a little tricky, because it requires touching all of the abstractions used in Optimistix. However, the technical bit should be very simple here. If you find it difficult to figure out from looking at existing implementations, just poke me! (See the sketch at the end of this comment.)

2. Write a CG linear solver that exits early, reporting `not_converged`, when it detects negative curvature (something like checking `tree_dot(y, operator.mv(y)) > 0`. This may need tweaking, and you'll also have to disable the positive semidefinite checks.)

This should be pretty much everything though. Just use `optx.NewtonDescent` with `linear_solver=NewEarlyExitCG` for the descent, and your favorite Optimistix `Search` for the search, and it should all work! You could probably use some other descents as well, such as `DoglegDescent` or `DampedNewtonDescent`, by passing the new linear solver in to these descents. I don't know how well these would work, as they're not standard algorithms (neat!)
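To make step 1 concrete, here's a minimal sketch of the `FunctionInfo` construction; `make_newton_cg_info` is a made-up helper name, and I'm assuming `lx.FunctionLinearOperator`, `lx.symmetric_tag`, and `optx.FunctionInfo.EvalGradHessian(f, grad, hessian)` from the Lineax/Optimistix public APIs:

```python
import jax
import lineax as lx
import optimistix as optx

def make_newton_cg_info(fn, x):
    # One forward + one backward sweep for the value and gradient.
    f, grad = jax.value_and_grad(fn)(x)

    # Lazy Hessian: each matrix-vector product costs one extra JVP sweep,
    # so CG never materialises the full Hessian.
    def hvp(v):
        return jax.jvp(jax.grad(fn), (x,), (v,))[1]

    hessian = lx.FunctionLinearOperator(
        hvp, jax.eval_shape(lambda: x), tags=lx.symmetric_tag
    )
    return optx.FunctionInfo.EvalGradHessian(f, grad, hessian)
```

The minimiser's step would build this and hand it to the descent; the early-exit CG from step 2 is then just the `linear_solver` argument of that descent.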