google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Updates dtype does not need to match params dtype, only grads dtype a priori #1098

Open vroulet opened 1 month ago

vroulet commented 1 month ago

Revise #1060 in light of the discussion in #1039. Namely, make the tests ensure that the dtype of the grads is preserved in the updates (it is not necessarily the same as the dtype of the params, e.g. in mixed-precision training). Also revisit the patch in #1060 to check whether the initialization of the dtypes in the states is too stringent.
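To make the intended check concrete, here is a minimal sketch of the kind of test described above. It does not use optax's actual code: `momentum_init`, `momentum_update`, and `dtypes_match` are hypothetical helpers written in plain JAX, and the choice to keep the state in the params dtype while casting updates back to the grads dtype is only one possible design, assumed here for illustration.

```python
import jax
import jax.numpy as jnp


def momentum_init(params):
    # Hypothetical init: the state mirrors the params' shapes and dtypes.
    return jax.tree_util.tree_map(jnp.zeros_like, params)


def momentum_update(grads, state, decay=0.9):
    # Accumulate in the state's dtype (float32 here) ...
    new_state = jax.tree_util.tree_map(
        lambda g, m: decay * m + g.astype(m.dtype), grads, state
    )
    # ... then cast the updates back to the dtype of the incoming grads.
    updates = jax.tree_util.tree_map(
        lambda g, m: m.astype(g.dtype), grads, new_state
    )
    return updates, new_state


def dtypes_match(tree_a, tree_b):
    # True if every pair of corresponding leaves has the same dtype.
    leaves_a = jax.tree_util.tree_leaves(tree_a)
    leaves_b = jax.tree_util.tree_leaves(tree_b)
    return all(a.dtype == b.dtype for a, b in zip(leaves_a, leaves_b))


# Mixed-precision setup: float32 params, bfloat16 grads.
params = {"w": jnp.ones((3,), dtype=jnp.float32)}
grads = {"w": jnp.ones((3,), dtype=jnp.bfloat16)}

state = momentum_init(params)
updates, state = momentum_update(grads, state)

# The property under discussion: updates follow the grads dtype,
# which need not be the params dtype.
assert dtypes_match(updates, grads)
assert not dtypes_match(updates, params)
```

A test of this shape only constrains the updates. Whether the state should be initialized with the params dtype, the grads dtype, or something else is exactly the open question about #1060 and is left unanswered by this sketch.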