GeoffNN opened this issue 2 years ago (status: Open)
For instance, tree_l2_norm would currently give incorrect results on complex parameters.
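To make the concern concrete, here is a minimal sketch of a complex-safe pytree L2 norm. The helper name `complex_safe_tree_l2_norm` is hypothetical and is not jaxopt's implementation; the "naive" variant only illustrates how squaring without the modulus can go wrong, and is an assumption about the failure mode rather than a quote of the library's code.

```python
import jax
import jax.numpy as jnp


def complex_safe_tree_l2_norm(tree):
    """L2 norm over all leaves of a pytree, valid for real and complex leaves.

    Hypothetical helper for illustration; not part of jaxopt.
    """
    leaves = jax.tree_util.tree_leaves(tree)
    # x * conj(x) = |x|^2 is real and non-negative for every entry.
    squared = sum(jnp.sum(jnp.real(x * jnp.conj(x))) for x in leaves)
    return jnp.sqrt(squared)


# Example parameters: one complex leaf, one real leaf.
params = {"w": jnp.array([3.0 + 4.0j]), "b": jnp.array([12.0])}

# |3 + 4j|^2 + 12^2 = 25 + 144 = 169
norm = complex_safe_tree_l2_norm(params)

# Naive sum of x * x (assumed failure mode): the result is complex,
# so it is not a valid squared norm.
naive_squared = sum(jnp.sum(x * x) for x in jax.tree_util.tree_leaves(params))
```

On real leaves both versions agree, which is why this kind of bug tends to go unnoticed until complex parameters show up.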
+1 on fixing this, thanks for catching
I think this is an issue we should tackle soon because, as you said, it could fail silently. Do you want to tackle it?
Hey! Sorry, I was interning this summer and off GitHub. I'll start checking this out!
Opening this as a nota bene. When optimizing a real-valued loss over complex parameters, the gradient returned by JAX must be conjugated before it is used as the update direction. Because of this, all jaxopt optimizers would currently be incorrect on complex parameters.
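A small sketch of the effect, using plain gradient descent rather than any jaxopt solver. The loss, `target`, step size, and the `descend` helper are all made up for illustration. For a real-valued loss, `jax.grad` at a complex input returns the conjugate of the steepest-ascent direction, so the update needs an explicit `jnp.conj`.

```python
import jax
import jax.numpy as jnp

# Hypothetical optimum for a toy quadratic loss.
target = jnp.array(1.0 + 2.0j)


def loss(z):
    d = z - target
    # |z - target|^2, written so the gradient is well defined at d = 0.
    return jnp.real(jnp.conj(d) * d)


def descend(z, conjugate, lr=0.25, steps=20):
    """Plain gradient descent; `conjugate` toggles the fix discussed above."""
    for _ in range(steps):
        g = jax.grad(loss)(z)  # equals 2 * conj(z - target) for this loss
        if conjugate:
            g = jnp.conj(g)    # now 2 * (z - target), a descent direction
        z = z - lr * g
    return z


z0 = jnp.array(0.0 + 0.0j)
z_good = descend(z0, conjugate=True)   # error halves at every step
z_bad = descend(z0, conjugate=False)   # imaginary part of the error grows
```

With the conjugate, the iterate contracts toward `target`; without it, the real part still converges while the imaginary part moves away, which is the kind of silent failure that is hard to spot from the loss curve alone when the imaginary components are small.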
Moreover, if an optimizer relies on second-order moments (e.g. Adam), it must use the squared complex modulus |g|^2 rather than the plain square g^2. Current jaxopt solvers might be affected as well. I'm unsure what implicit differentiation would do with complex parameters, but perhaps we could emit a warning that it is probably incorrect for now.
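As a rough illustration of the second-moment point, the sketch below contrasts the plain square with the squared modulus. `update_second_moment` is a hypothetical stand-alone helper written for this example; it is not the Optax or jaxopt implementation.

```python
import jax.numpy as jnp


def update_second_moment(nu, grad, b2=0.999):
    """Adam-style exponential moving average of the squared gradient.

    Uses |g|^2 = g * conj(g), which is real and non-negative.
    Hypothetical helper for illustration only.
    """
    return b2 * nu + (1.0 - b2) * jnp.real(grad * jnp.conj(grad))


g = jnp.array([3.0 + 4.0j, 1.0j])

# Plain square: complex in general, with possibly negative real parts.
naive = g * g

# Squared modulus; b2=0.0 returns |g|^2 directly for easy inspection.
nu = update_second_moment(jnp.zeros(2), g, b2=0.0)
```

Since Adam divides by the square root of this accumulator, a complex or negative value there changes both the step size and the step direction, whereas the squared modulus keeps the scaling real and positive.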
I realized this while using Optax on a model with complex weights, and thought it would be good to handle this in jaxopt solvers as well, since 1) users might not be aware of it and 2) it's really hard to debug on the user side.