google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Support for loss function with auxiliary data in linesearch #1053

Open · ro0mquy opened this issue 1 week ago

ro0mquy commented 1 week ago

I have a loss function that returns (loss_value, extra_data). JAX supports this construct natively with jax.value_and_grad(loss_fn, has_aux=True) (doc); the differentiated function then returns ((loss_value, extra_data), grad).
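For reference, a minimal example of the has_aux pattern in plain JAX (the quadratic loss and the auxiliary dictionary are made up for illustration):

```python
import jax
import jax.numpy as jnp

def loss_fn(params):
    residual = params - 1.0
    aux = {'residual_norm': jnp.linalg.norm(residual)}  # auxiliary data
    return jnp.sum(residual ** 2), aux

params = jnp.zeros(3)
# With has_aux=True the result is ((loss_value, extra_data), grad).
(loss, aux), grad = jax.value_and_grad(loss_fn, has_aux=True)(params)
```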

In optax, when using the linesearch algorithms (for example as part of L-BFGS), I can use optax.value_and_grad_from_state(loss_fn) (doc), which caches function evaluations done inside the linesearch in the optimizer state so they can be reused. Unfortunately, neither the linesearch algorithms nor optax.value_and_grad_from_state supports auxiliary data.
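For context, the aux-free pattern looks roughly like this (a sketch adapted from the optax documentation; the quadratic loss is a placeholder):

```python
import jax.numpy as jnp
import optax

def loss_fn(params):
    return jnp.sum((params - 1.0) ** 2)

opt = optax.lbfgs()
value_and_grad = optax.value_and_grad_from_state(loss_fn)

params = jnp.zeros(3)
state = opt.init(params)
for _ in range(10):
    # Reuses evaluations cached in `state` by the previous linesearch.
    value, grad = value_and_grad(params, state=state)
    updates, state = opt.update(
        grad, state, params, value=value, grad=grad, value_fn=loss_fn
    )
    params = optax.apply_updates(params, updates)
```

A loss returning (loss_value, extra_data) cannot be dropped into this loop, since both value_and_grad_from_state and the linesearch expect a scalar-valued function.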

I added support for this to the optax code, and it works for my use case. Are you interested in merging it upstream? I don't have time for proper testing, documentation, etc., though, so I'd appreciate some assistance.
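As a stopgap until such a change lands, one user-side workaround is to hide the auxiliary output from the linesearch and recompute it once per accepted step; a sketch, assuming loss_fn returns (loss_value, extra_data), at the cost of one extra loss evaluation per iteration:

```python
def loss_no_aux(params):
    # Drop the auxiliary output so the linesearch sees a scalar loss.
    value, _ = loss_fn(params)
    return value

value_and_grad = optax.value_and_grad_from_state(loss_no_aux)

# Inside the optimization loop, pass value_fn=loss_no_aux to opt.update,
# then recover the auxiliary data separately after each step:
_, extra_data = loss_fn(params)
```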

vroulet commented 1 week ago

Hello @ro0mquy,

I'd be happy to see how you handled it. I wasn't sure what the best solution would be for adding this while keeping the API light, so if you have an example, I'd be happy to look at a PR.

Thanks!

ro0mquy commented 1 week ago

Cool, I'll prepare a PR once I'm back from vacation in 1-2 weeks.
