google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Complex gradients #169

Open GeoffNN opened 2 years ago

GeoffNN commented 2 years ago

Opening this as a nota bene. When optimizing over complex parameters, the gradient must be conjugated. Because of this, all current jaxopt optimizers would give incorrect results on complex parameters.
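For example, something like the following shows the failure mode (plain JAX, not the jaxopt API; `loss`, `z` and `lr` are just illustrative names): for a real-valued loss of a complex parameter, the raw `jax.grad` output has to be conjugated before the descent update, otherwise the iterates can move in the wrong direction.

```python
import jax
import jax.numpy as jnp

def loss(z):
  # Real-valued loss of a complex parameter.
  return jnp.sum(jnp.abs(z) ** 2)

z = jnp.array([1.0 + 2.0j, -0.5j])
lr = 0.1
for _ in range(100):
  g = jax.grad(loss)(z)       # complex-valued gradient
  z = z - lr * jnp.conj(g)    # correct update: conjugate the gradient
  # z = z - lr * g            # incorrect for complex parameters (can diverge)

print(loss(z))  # should be close to 0 with the conjugated update
```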

Moreover, any optimizer that relies on second-order moments (e.g. Adam) must also use the squared complex modulus of the gradient instead of just the squared gradient. Current jaxopt solvers might be affected as well. I'm not sure what implicit diff would do with complex parameters, but perhaps we could emit a warning that it is probably incorrect at the moment.
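As a sketch of just the second-moment accumulator (not the jaxopt/Optax implementation; `update_second_moment`, `nu` and `b2` are hypothetical names), contrasting the modulus-squared version with the plain square:

```python
import jax.numpy as jnp

def update_second_moment(nu, g, b2=0.999):
  # Correct for real and complex gradients: |g|^2 = g * conj(g), which is real.
  return b2 * nu + (1.0 - b2) * jnp.real(g * jnp.conj(g))

def update_second_moment_wrong(nu, g, b2=0.999):
  # Incorrect for complex gradients: g**2 is complex, so the accumulator is
  # no longer a valid non-negative scale estimate.
  return b2 * nu + (1.0 - b2) * g ** 2
```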

I realized this while using Optax on a model with complex weights; it seemed worth addressing in the jaxopt solvers as well, since users 1) might not be aware of this and 2) it's really hard to debug on the user side.

GeoffNN commented 2 years ago

For instance, tree_l2_norm would currently give incorrect results on complex parameters.

https://github.com/google/jaxopt/blob/eb6e75dfee1d25cc2b206ad1410668c576bf6750/jaxopt/_src/tree_util.py#L84
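A quick illustration of what goes wrong (not the jaxopt source, just an assumed repro with plain `jax.numpy`): squaring complex entries instead of taking their squared modulus yields a complex value rather than a norm.

```python
import jax.numpy as jnp

x = jnp.array([1.0 + 1.0j, 2.0 - 1.0j])

wrong = jnp.sqrt(jnp.sum(x ** 2))           # complex result, not a norm
right = jnp.sqrt(jnp.sum(jnp.abs(x) ** 2))  # matches jnp.linalg.norm(x)

print(wrong, right)
```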

mblondel commented 2 years ago

+1 on fixing this, thanks for catching

mblondel commented 2 years ago

I think this is an issue we should tackle soon because, as you said, it could fail silently. Do you want to take it on?

GeoffNN commented 2 years ago

Hey! Sorry, I was interning this summer and was off GitHub. I'll start checking this out!