google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

fix: Coherent dtypes of updates with and without MultiSteps #1122

Closed: hlzl closed this 3 weeks ago.

hlzl commented 1 month ago

Following #1039 and discussions with @vroulet, this PR proposes a fix so that updates computed with and without MultiSteps have the same dtypes, regardless of how acc_grads and the inner_opt_state were initialized.
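To illustrate the dtype issue, here is a minimal JAX sketch of gradient accumulation. It is not optax's MultiSteps implementation. The helpers `init_acc`, `accumulate` and `emit` are hypothetical, and the sketch assumes the goal is for emitted updates to carry the gradient dtype even when the accumulation buffer was initialized from parameters of a different dtype (here float32 params with bfloat16 grads). The buffer keeps its initialization dtype across steps, because state carried through `jax.lax` control flow must keep a fixed dtype. Only the emitted mean is cast back to each gradient's dtype.

```python
import jax
import jax.numpy as jnp


def init_acc(params):
    # Hypothetical helper: the buffer is built from params, so its dtype
    # follows the params (e.g. float32), not the gradients.
    return jax.tree_util.tree_map(jnp.zeros_like, params)


def accumulate(acc, grads):
    # Add each gradient into the buffer, casting it to the buffer's dtype
    # so the buffer's dtype stays fixed from step to step.
    return jax.tree_util.tree_map(lambda a, g: a + g.astype(a.dtype), acc, grads)


def emit(acc, grads, k):
    # Average over k mini-steps and cast back to each gradient's dtype, so
    # the result has the dtype the updates would have without accumulation.
    return jax.tree_util.tree_map(lambda a, g: (a / k).astype(g.dtype), acc, grads)


# Usage: float32 params, bfloat16 grads, accumulation over k = 2 mini-steps.
params = {"w": jnp.ones((3,), dtype=jnp.float32)}
grads = {"w": jnp.full((3,), 0.5, dtype=jnp.bfloat16)}

acc = init_acc(params)
for _ in range(2):
    acc = accumulate(acc, grads)
out = emit(acc, grads, 2)
```

If `emit` skipped the final cast, `out` would stay float32 and its dtype would differ from the bfloat16 updates produced without accumulation.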