ro0mquy opened 2 months ago
Hello @ro0mquy,
I'd be happy to see how you handled it. I wasn't sure what the best solution would be for adding this while keeping the API light, so if you have an example, I'd be happy to look at a PR.
Thanks!
Cool, I'll prepare a PR once I'm back from vacation in 1-2 weeks.
I have a loss function that returns `(loss_value, extra_data)`. Native JAX supports this kind of construct with `jax.value_and_grad(loss_fn, has_aux=True)` (doc). The differentiated function returns `((loss_value, extra_data), grad)`.
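For illustration, a minimal sketch of the `has_aux` pattern (the quadratic loss and the contents of `extra_data` here are made up for the example):

```python
import jax
import jax.numpy as jnp

def loss_fn(params):
    # Hypothetical loss: a simple quadratic, used only to illustrate the pattern.
    residuals = params - jnp.arange(3.0)
    loss_value = jnp.sum(residuals ** 2)
    extra_data = {"residual_norm": jnp.linalg.norm(residuals)}
    return loss_value, extra_data

params = jnp.zeros(3)
# has_aux=True tells JAX the second output is auxiliary data, not part of the loss.
(loss_value, extra_data), grad = jax.value_and_grad(loss_fn, has_aux=True)(params)
```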
In optax, when using the linesearch algorithms (for example as part of L-BFGS), I can use `optax.value_and_grad_from_state(loss_fn)` (doc), which uses the optimizer state to reuse function evaluations already done inside the linesearch. Unfortunately, the linesearch algorithms and `optax.value_and_grad_from_state` don't support auxiliary data.

I added support for this to the optax code, and it works for my use case. Are you interested in merging this upstream? I don't have time for proper testing, documentation, etc., though, so I would appreciate getting some assistance.
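For context, a sketch of the current aux-free usage, following the pattern from the optax docs (the loss and iteration count are placeholders); the request is for an equivalent where `loss_fn` could also return `extra_data`:

```python
import jax.numpy as jnp
import optax

def loss_fn(params):
    # Placeholder loss without auxiliary outputs.
    return jnp.sum(params ** 2)

solver = optax.lbfgs()
params = jnp.array([1.0, 2.0, 3.0])
opt_state = solver.init(params)
# Wraps loss_fn so that (value, grad) pairs cached in the optimizer state
# by the linesearch are reused instead of recomputed.
value_and_grad = optax.value_and_grad_from_state(loss_fn)

for _ in range(10):
    loss_value, grad = value_and_grad(params, state=opt_state)
    updates, opt_state = solver.update(
        grad, opt_state, params,
        value=loss_value, grad=grad, value_fn=loss_fn,
    )
    params = optax.apply_updates(params, updates)
```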