google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

unitwise_norm fails for 3D convolutions #906

Open froody opened 6 months ago

froody commented 6 months ago

unitwise_norm, used by adaptive_grad_clip, supports only a few values of ndim and raises a ValueError when applied to a conv3d kernel, since ndim=5 (HWDIO). Would it be acceptable to add an optional axis kwarg to adaptive_grad_clip and unitwise_norm? This would allow specifying the reduction axes at the call site instead of baking every possible combination into the implementation of unitwise_norm.

I'm happy to submit a PR.

vroulet commented 6 months ago

Hello @froody,

Good catch. The behavior of adaptive_grad_clip does indeed hide some logic that could mislead users. If you are willing to open a PR that lets this function handle ndim=5, that would be great. I don't know exactly how you could add an axis argument while keeping the current default behavior, so I'll let you try and see :)

Thank you!
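
For illustration, the proposal above could look roughly like the following JAX sketch. It is not Optax's implementation: the function name unitwise_norm_with_axis is hypothetical, and the fallback used when axis is None (reduce over every axis except the last) is an assumption standing in for whatever ndim-based rules the library currently applies. In an actual patch, the axis=None branch would presumably call the existing logic unchanged so that the default behavior is preserved.

```python
import jax.numpy as jnp


def unitwise_norm_with_axis(x, axis=None):
    """Hypothetical sketch of a unit-wise norm with an optional `axis` kwarg.

    Not Optax's implementation. When `axis` is None, an assumed default is
    used; a real patch would keep the library's existing ndim-based rules here.
    """
    x = jnp.asarray(x)
    if axis is None:
        # Assumed default: treat the last axis as the output-unit axis and
        # reduce over all others; scalars and vectors are reduced entirely.
        axis = tuple(range(x.ndim - 1)) if x.ndim > 1 else None
    squared_norm = jnp.sum(jnp.square(jnp.abs(x)), axis=axis, keepdims=True)
    # Broadcast back so the result has the same shape as the parameter.
    return jnp.broadcast_to(jnp.sqrt(squared_norm), x.shape)


# A conv3d kernel in HWDIO layout has ndim=5; the caller names the axes to
# reduce over (everything except the output-channel axis).
kernel = jnp.ones((3, 3, 3, 4, 8))
norms = unitwise_norm_with_axis(kernel, axis=(0, 1, 2, 3))
```

With this shape of API, the caller chooses the reduction axes for unusual layouts, while passing no axis leaves existing call sites untouched.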