I am trying to replace my hand-crafted gradient descent algorithm with Jaxopt, but I'm getting catastrophic oscillation in Jaxopt.
I think the reason is that my energy function is something like E(x) = f(x) + k x^2/2, where f is convex. Jaxopt uses the energy gradient
E'(x) = f'(x) + kx,
and then updates
next_x = x - E'(x) * stepsize.
In my version of gradient descent, I use a leaky-integrator update instead:
next_x = x exp(-k stepsize) - (f'(x)/k) (1 - exp(-k stepsize))
My version can never oscillate because of k being too large: the decay factor exp(-k stepsize) stays between 0 and 1 for any positive k and stepsize.
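To make the difference concrete, here is a minimal sketch of the two update rules in plain JAX. The quadratic f, the values of k and stepsize, and the helper names plain_step, leaky_step and run are all made up for illustration. This is not Jaxopt's code, only the fixed-step update as written above, next to my leaky update.

```python
import jax
import jax.numpy as jnp

k = 100.0        # assumed stiffness of the k x^2/2 term
stepsize = 0.05  # assumed fixed step, chosen so that k * stepsize = 5

def f(x):
    # Hypothetical convex stand-in for the real f
    return 0.5 * (x - 1.0) ** 2

df = jax.grad(f)  # f'(x)

def plain_step(x):
    # next_x = x - E'(x) * stepsize, with E'(x) = f'(x) + k x
    return x - (df(x) + k * x) * stepsize

def leaky_step(x):
    # next_x = x exp(-k stepsize) - (f'(x)/k) (1 - exp(-k stepsize))
    decay = jnp.exp(-k * stepsize)
    return x * decay - df(x) / k * (1.0 - decay)

def run(step, x0, n_steps=20):
    # Apply the same update rule n_steps times
    x = x0
    for _ in range(n_steps):
        x = step(x)
    return x

x0 = jnp.asarray(1.0)
x_plain = run(plain_step, x0)
x_leaky = run(leaky_step, x0)
x_star = 1.0 / (k + 1.0)  # minimizer of E for this toy f

print("plain:", x_plain, "leaky:", x_leaky, "minimizer:", x_star)
```

With these assumed values the plain update overshoots on every step and grows in magnitude, while the leaky update settles near the minimizer of this toy energy.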
What's strange is that if I set f(x)=0, then there is no catastrophic oscillation. I should also mention that f has a custom JVP, but I don't think that should matter since both Jaxopt and my custom optimizer only look at f' (afaik).
Is there any way I can do something similar in Jaxopt? Should I have to? It would be much more elegant to just specify the energy function.