itk22 opened 3 months ago:

Dear optimistix team,

First of all, thank you for your effort in developing optimistix. I have recently transitioned from JAXOpt, and I love it!

I was wondering whether it is possible to extract the loss/function-value history from an optimistix solve? In the code example below, it is easy to evaluate the intermediate losses when using the `multi_step_solve` method, but it is much less efficient than the `single_step_solve` approach. Using a `jax.lax.scan` would definitely improve the performance over using a `for` loop, but I was wondering whether there is a simpler way to extract this information in optimistix.
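For concreteness, here is roughly the scan-based pattern I mean (just a sketch: plain gradient descent with a made-up quadratic `loss_fn`, standing in for a single step of the real solve):

```python
import jax
import jax.numpy as jnp


def loss_fn(y):
    # Made-up quadratic objective, standing in for the real loss.
    return jnp.sum((y - 1.0) ** 2)


def step(y, _):
    # One plain gradient-descent step, standing in for a single solver step.
    loss, grad = jax.value_and_grad(loss_fn)(y)
    return y - 0.1 * grad, loss


y_final, loss_history = jax.lax.scan(step, jnp.zeros(3), None, length=100)
# `loss_history` now holds the objective value at every step.
```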
Hmm. I don't think we offer that at the moment!
FWIW, if this is just for debugging purposes then you could add `jax.debug.print` statements to the input or output of your function.
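For example, something like this (a sketch with a stand-in quadratic objective):

```python
import jax
import jax.numpy as jnp
import optimistix as optx


def fn(y, args):
    loss = jnp.sum((y - 1.0) ** 2)  # stand-in objective
    # Printed on every evaluation of the objective, even under jit.
    jax.debug.print("loss = {loss}", loss=loss)
    return loss


solver = optx.BFGS(rtol=1e-6, atol=1e-6)
sol = optx.minimise(fn, solver, jnp.zeros(3))
```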
If you really want to interact with the history programmatically then (a) I'm quite curious what the use case is, but also (b) we could probably add an additional `optx.minimise(..., saveat=...)` argument without too much difficulty.
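Purely as a sketch of what I mean (none of this exists today; `SaveAt` and `sol.losses` are made-up names):

```python
# Hypothetical future API, for illustration only -- not implemented:
# sol = optx.minimise(fn, solver, y0, saveat=optx.SaveAt(steps=True))
# sol.losses  # imagined: the objective value recorded at every step
```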
Hi,
Thanks for the quick response. I am currently trying to use optimistix to implement Model-Agnostic Meta-Learning and its implicit version (https://arxiv.org/abs/1909.04630). I was considering using the `multi_step_solve` approach from above for the outer meta-learning loop because it makes the training very quick. However, I need to be able to monitor the meta-losses, which is why I was looking into the less efficient `single_step_solve`. The `saveat` option would be very useful!
On a side note, I had a look at the interactive stepping example, and I thought it could be useful for solvers to have an `update` method for performing a single optimisation step, similar to JAXOpt. What do you think?
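For reference, the JAXOpt pattern I have in mind (a sketch with a made-up quadratic objective):

```python
import jax.numpy as jnp
import jaxopt


def loss_fn(params):
    return jnp.sum((params - 1.0) ** 2)


opt = jaxopt.GradientDescent(fun=loss_fn)
params = jnp.zeros(3)
state = opt.init_state(params)
for _ in range(10):
    # One optimisation step at a time, so intermediate values are easy to log.
    params, state = opt.update(params, state)
    print(loss_fn(params))
```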
Makes sense. I'll mark this as a feature request for a `saveat` option. (And I'm sure we'd be happy to take a PR on this.)
For performing a single optimisation step: I think we already have this, as the `step` method on the solver?
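Roughly, following the interactive stepping example, something like this (a sketch; I'm assuming `fn` returns a `(value, aux)` pair and the usual `init`/`step`/`terminate` signatures, so double-check against the docs):

```python
import jax
import jax.numpy as jnp
import optimistix as optx


def fn(y, args):
    return jnp.sum((y - 1.0) ** 2), None  # (value, aux) pair


solver = optx.BFGS(rtol=1e-6, atol=1e-6)
y = jnp.zeros(3)
args = None
options = {}
f_struct = jax.ShapeDtypeStruct((), jnp.float32)
aux_struct = None
tags = frozenset()

state = solver.init(fn, y, args, options, f_struct, aux_struct, tags)
done, result = solver.terminate(fn, y, args, options, state, tags)
losses = []
while not done:
    y, state, aux = solver.step(fn, y, args, options, state, tags)
    losses.append(fn(y, args)[0])  # record the loss after each step
    done, result = solver.terminate(fn, y, args, options, state, tags)
```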
Thanks Patrick, I will try to make a PR on this :)