google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Ways to address catastrophic oscillation #335

Closed: NeilGirdhar closed this issue 2 years ago

NeilGirdhar commented 2 years ago

I am trying to replace my hand-crafted gradient descent algorithm with Jaxopt, but with Jaxopt I'm getting catastrophic oscillation.

I think the reason is that my energy function has roughly the form E(x) = f(x) + k x^2/2, where f is convex. Jaxopt uses the energy gradient

E'(x) = f'(x) + kx,

and then updates

next_x = x - E'(x) * stepsize.
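A minimal sketch of this plain update, in pure Python for a scalar x (the names `gd_step`, `run`, and `zero_grad` are hypothetical, not Jaxopt API). With f = 0 the step reduces to next_x = x (1 - k stepsize), so the sign of x flips every step once k stepsize > 1, and |x| grows once k stepsize > 2.

```python
def gd_step(x, f_prime, k, stepsize):
    """One plain gradient-descent step on E(x) = f(x) + k*x**2/2.

    E'(x) = f'(x) + k*x, so next_x = x - stepsize * (f'(x) + k*x).
    """
    return x - stepsize * (f_prime(x) + k * x)


def run(step, x0, n_steps, **kwargs):
    """Apply `step` n_steps times and return the whole trajectory."""
    xs = [x0]
    for _ in range(n_steps):
        xs.append(step(xs[-1], **kwargs))
    return xs


def zero_grad(x):
    # f(x) = 0, so f'(x) = 0 and only the k*x**2/2 term remains.
    return 0.0


# With f = 0 each step multiplies x by (1 - k*stepsize).
# k*stepsize = 0.5 -> factor 0.5: smooth decay toward 0.
stable = run(gd_step, 1.0, 5, f_prime=zero_grad, k=1.0, stepsize=0.5)
# k*stepsize = 2.5 -> factor -1.5: x flips sign every step and |x| grows.
unstable = run(gd_step, 1.0, 5, f_prime=zero_grad, k=5.0, stepsize=0.5)
print(stable)
print(unstable)
```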

In my version of gradient descent, I used a leaky integral instead, which holds f'(x) fixed over the step and integrates the k x term exactly:

next_x = x * exp(-k * stepsize) - f'(x)/k * (1 - exp(-k * stepsize))

With my version, a large k never causes oscillation, because the factor exp(-k * stepsize) stays between 0 and 1 for any k > 0 and any positive step size.
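The leaky-integral step can be sketched the same way (pure Python, scalar x; `leaky_step` is a hypothetical name). It is the exact solution of dx/dt = -f'(x) - k x over one step when f'(x) is frozen at its value at the start of the step, so with f = 0 it multiplies x by exp(-k stepsize) however large k is.

```python
import math


def leaky_step(x, f_prime, k, stepsize):
    """Leaky-integral (exponential) step for E(x) = f(x) + k*x**2/2.

    Freezes f'(x) over the step and solves dx/dt = -f'(x) - k*x exactly:
        next_x = x*exp(-k*stepsize) - f'(x)/k * (1 - exp(-k*stepsize))
    Assumes k > 0.
    """
    decay = math.exp(-k * stepsize)
    return x * decay - f_prime(x) / k * (1.0 - decay)


# f = 0 with k*stepsize = 25: plain gradient descent would multiply x by
# 1 - 25 = -24 per step, but here each step multiplies x by exp(-25) > 0.
x = 1.0
trajectory = [x]
for _ in range(5):
    x = leaky_step(x, lambda _: 0.0, k=50.0, stepsize=0.5)
    trajectory.append(x)

# With a constant gradient f'(x) = c, the fixed point is -c/k, which is the
# minimizer of c*x + k*x**2/2.
c, k = 2.0, 4.0
fixed = leaky_step(-c / k, lambda _: c, k=k, stepsize=0.3)
print(trajectory)
print(fixed)
```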

What's strange is that if I set f(x) = 0, there is no catastrophic oscillation. I should also mention that f has a custom JVP, but I don't think that matters, since both Jaxopt and my custom optimizer only use f' (afaik).
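One possible explanation for the f = 0 observation, assuming f is locally close to a quadratic a x^2/2 with curvature a > 0 (an assumption, not something measured in this issue): plain gradient descent then multiplies x by 1 - stepsize (k + a) each step, so the curvature of f adds to k and can push a step size that is fine for f = 0 past the stability limit. The sketch below (pure Python; `gd_factor` and `leaky_factor` are hypothetical helpers) also suggests that the leaky step removes the limit on k but not the one coming from f's curvature, since it still treats f' explicitly.

```python
import math


def gd_factor(a, k, stepsize):
    """Per-step multiplier of plain gradient descent when f(x) = a*x**2/2.

    E'(x) = (a + k)*x, so next_x = (1 - stepsize*(a + k)) * x.
    """
    return 1.0 - stepsize * (a + k)


def leaky_factor(a, k, stepsize):
    """Per-step multiplier of the leaky-integral step when f(x) = a*x**2/2."""
    decay = math.exp(-k * stepsize)
    return decay - (a / k) * (1.0 - decay)


k, stepsize = 3.0, 0.5
gd_no_f = gd_factor(0.0, k, stepsize)          # 1 - 0.5*3 = -0.5: damped sign flips
gd_with_f = gd_factor(2.0, k, stepsize)        # 1 - 0.5*5 = -1.5: |x| grows each step
leaky_mild = leaky_factor(2.0, k, stepsize)    # magnitude below 1: stable
leaky_stiff = leaky_factor(20.0, k, stepsize)  # below -1: f's curvature dominates
print(gd_no_f, gd_with_f, leaky_mild, leaky_stiff)
```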

Is there any way I can do something similar in Jaxopt? Should I have to? It would be much more elegant to just specify the energy function.
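A different technique that gives the same kind of stability in k is proximal gradient descent: treat f as the smooth part and k x^2/2 as the proximal part. The prox of k x^2/2 with step size stepsize is y / (1 + k stepsize), so each step is next_x = (x - stepsize f'(x)) / (1 + k stepsize). A minimal pure-Python sketch, with hypothetical names `prox_quadratic` and `prox_grad_step`:

```python
def prox_quadratic(y, k, stepsize):
    """Proximal operator of g(x) = k*x**2/2 with step size `stepsize`.

    Solves argmin_z (z - y)**2 / (2*stepsize) + k*z**2/2, whose minimizer is
    y / (1 + k*stepsize).
    """
    return y / (1.0 + k * stepsize)


def prox_grad_step(x, f_prime, k, stepsize):
    """Explicit gradient step on f, then the prox of k*x**2/2 (implicit in k)."""
    return prox_quadratic(x - stepsize * f_prime(x), k, stepsize)


# f = 0 with k*stepsize = 25: each step multiplies x by 1/26, never by a
# negative number, however large k is.
x = 1.0
trajectory = [x]
for _ in range(5):
    x = prox_grad_step(x, lambda _: 0.0, k=50.0, stepsize=0.5)
    trajectory.append(x)

# With a constant gradient f'(x) = c, the fixed point is again -c/k.
c, k = 2.0, 4.0
fixed = prox_grad_step(-c / k, lambda _: c, k=k, stepsize=0.3)
print(trajectory)
print(fixed)
```

In Jaxopt, `jaxopt.ProximalGradient` takes the smooth objective as `fun` together with a `prox` callable, so this split should map onto it; check the jaxopt documentation for the exact `prox` signature, and for whether one of the built-in operators in `jaxopt.prox` uses the same scaling of k, before relying on it.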