Closed: TheMr33 closed this issue 6 months ago.
I think the issue is that `MultiSteps` is trying to average the `OptState` (https://github.com/google-deepmind/optax/blob/2e92d570c784de4b357dce83180021e658a9210f/optax/_src/wrappers.py#L407), but in the case of `optax.zero_nans()` the state contains a boolean mask, which gets converted to integers by the averaging operation.
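A minimal sketch of the promotion (the names `st`, `nst`, and `emit` mirror the linked wrapper code; the exact integer dtype assumes JAX's default 32-bit mode):

```python
import jax.numpy as jnp

emit = jnp.asarray(True)        # MultiSteps' scalar "emit this step" flag
st = jnp.array([False, False])  # boolean leaf, like zero_nans' found_nan mask
nst = jnp.array([True, False])

merged = (1 - emit) * st + emit * nst
print(merged.dtype)  # int32 -- the boolean mask has been promoted
```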
One potential solution would be to replace `lambda st, nst: (1 - emit) * st + emit * nst` with `lambda st, nst: ((1 - emit) * st + emit * nst).astype(st.dtype)` to make sure the dtype doesn't change, but I wonder whether that would have other unintended consequences. In particular, I think this solution will fail if the leaves are scalars instead of arrays.
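In context, the patched merge would look roughly like this (a sketch against the `tree_map` call in the linked code, not the actual patch):

```python
import jax

def merge_states(state, new_state, emit):
    # Proposed fix: cast back so boolean leaves keep their dtype.
    # Caveat: `st.dtype` only exists on array leaves, so a plain Python
    # scalar leaf would raise an AttributeError here.
    return jax.tree_util.tree_map(
        lambda st, nst: ((1 - emit) * st + emit * nst).astype(st.dtype),
        state, new_state)
```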
Ah, since `emit` is a boolean, I think there's a better solution.
@TheMr33 can you confirm that #840 solves your issue?
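For illustration, a dtype-preserving merge that exploits the boolean `emit` could use `jnp.where` (whether this matches #840 exactly is an assumption):

```python
import jax
import jax.numpy as jnp

def merge_states(state, new_state, emit):
    # A scalar boolean emit lets jnp.where select each leaf wholesale,
    # keeping the leaf's original dtype, boolean masks included.
    return jax.tree_util.tree_map(
        lambda st, nst: jnp.where(emit, nst, st), state, new_state)
```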
Hi @fabianp, thank you for the answer.
Your first answer (with the lambda) and your latest commit (6de95bf) both seem to solve my problem.
I haven't checked the consistency of the gradient, but the code no longer crashes.
Excellent, thanks! Closing this one :-)
Hello,
There seems to be an incompatibility when using `optax.zero_nans()` with `optax.MultiSteps`. I replicate my issue starting from the gradient_accumulation example notebook:
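A minimal sketch of this kind of setup (the chain, hyperparameters, and NaN gradient below are assumptions, not the notebook's code):

```python
import jax.numpy as jnp
import optax

params = {'w': jnp.ones(3)}

# Gradient accumulation over 2 steps around a chain that includes zero_nans.
opt = optax.MultiSteps(
    optax.chain(optax.zero_nans(), optax.adam(1e-3)),
    every_k_schedule=2)
state = opt.init(params)

grads = {'w': jnp.array([1.0, jnp.nan, 3.0])}  # a NaN to exercise zero_nans
for _ in range(4):
    updates, state = opt.update(grads, state, params)
    params = optax.apply_updates(params, updates)
```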