Closed: ozanarmagan closed this issue 1 month ago.
Are your parameters in bfloat16 or in float32? If it is the latter, did you try setting state_dtype=jnp.float32? (You have probably tried already, but just in case, before digging into this further.)
Also pinging the author of this code: @nullstring.
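For reference, the dtype of the schedule-free state can be forced independently of the parameter dtype. A minimal sketch (the learning rate and b1 values here are placeholders, not values from this thread):

```python
import jax.numpy as jnp
import optax

# b1=0 in the base optimizer, with momentum handled by the schedule-free
# wrapper, per the schedule_free convention.
base = optax.adam(1e-3, b1=0.)
opt = optax.contrib.schedule_free(
    base, learning_rate=1e-3, b1=0.9, state_dtype=jnp.float32)
```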
My parameters are in bfloat16, and state_dtype is float32 by default. I also tried setting it to bfloat16, as in the code snippet I shared, but it didn't make any difference.
I think it's the same issue as: https://github.com/google-deepmind/optax/issues/377#issuecomment-1662951572
I don't know if that's the exact same bug, but it also happens when using apply_if_finite:
# Note: optax.warmup_constant_schedule also requires init_value and
# warmup_steps; the values below are placeholders.
learning_rate_fn = optax.warmup_constant_schedule(
    init_value=0.0, peak_value=retuned_lr, warmup_steps=1_000)
optimizer = optax.adam(learning_rate_fn, b1=0.)
optimizer = optax.contrib.schedule_free(optimizer, learning_rate_fn, b1=b1)
optimizer = optax.apply_if_finite(optimizer, max_consecutive_errors=5)
I got an error saying NoneType can't be cast to float32. If that's a different bug, I'll open a separate issue.
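For what it's worth, here is a minimal sketch of the combination described above (shapes and hyperparameters are placeholders; this is a reconstruction, not the exact setup):

```python
import jax.numpy as jnp
import optax

params = {'w': jnp.zeros((4,), dtype=jnp.bfloat16)}
opt = optax.contrib.schedule_free(optax.adam(1e-3, b1=0.), learning_rate=1e-3)
opt = optax.apply_if_finite(opt, max_consecutive_errors=5)
state = opt.init(params)
grads = {'w': jnp.ones((4,), dtype=jnp.bfloat16)}
# Reportedly raises the "can't cast NoneType to float32" error here.
updates, state = opt.update(grads, state, params)
```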
@clementpoiret, this is a separate issue. Can you (i) sync to head (which includes the commit that fixed the original issue posted here), (ii) test whether the bug is solved, and (iii) if not, open a new issue?
Sure, I'll test that! Thanks 👍
Hello,
I am getting the following exception when I try to wrap a schedule-free optimizer with optax.MultiSteps. Can you help me?
Exception message:

```
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ozan/cloud/train/trainer.py", line 116, in train_step
    new_state = state.apply_gradients(grads=grads, train_rngs=new_train_rngs)
  File "/home/ozan/.local/lib/python3.10/site-packages/flax/training/train_state.py", line 101, in apply_gradients
    updates, new_opt_state = self.tx.update(
  File "/home/ozan/.local/lib/python3.10/site-packages/optax/transforms/_accumulation.py", line 380, in update
    new_updates, new_state = lax.cond(
TypeError: true_fun and false_fun output must have identical types, got
({'BERT_0': {
    'embedding_layer_norm': {
        'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
        'scale': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])'},
    'encoders_0': {
        'attention': {
            'key': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])',
                    'kernel': 'DIFFERENT ShapedArray(bfloat16[512,8,64]) vs. ShapedArray(float32[512,8,64])'},
            'out': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                    'kernel': 'DIFFERENT ShapedArray(bfloat16[8,64,512]) vs. ShapedArray(float32[8,64,512])'},
            'query': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])',
                      'kernel': 'DIFFERENT ShapedArray(bfloat16[512,8,64]) vs. ShapedArray(float32[512,8,64])'},
            'value': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])',
                      'kernel': 'DIFFERENT ShapedArray(bfloat16[512,8,64]) vs. ShapedArray(float32[512,8,64])'}},
        'encoder_layer_norm': {
            'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
            'scale': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])'},
        'ffn': {
            'layers_0': {'bias': 'DIFFERENT ShapedArray(bfloat16[2048]) vs. ShapedArray(float32[2048])',
                         'kernel': 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])'},
            'layers_3': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                         'kernel': 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])'}}},
    # ... the same bfloat16 vs. float32 mismatch repeats for every leaf of
    # 'encoders_1' through 'encoders_5' (truncated) ...
```
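For anyone trying to reproduce this, a minimal sketch of the failing combination (placeholder shapes and hyperparameters; a reconstruction of the setup, not the actual training code):

```python
import jax.numpy as jnp
import optax

# bfloat16 parameters with the schedule-free state left at its float32 default.
params = {'w': jnp.zeros((512,), dtype=jnp.bfloat16)}
opt = optax.contrib.schedule_free(optax.adam(1e-3, b1=0.), learning_rate=1e-3)
opt = optax.MultiSteps(opt, every_k_schedule=2)
state = opt.init(params)
grads = {'w': jnp.ones((512,), dtype=jnp.bfloat16)}
# MultiSteps dispatches through lax.cond, whose branches must return
# identical types; here they disagree on bfloat16 vs. float32, producing
# the TypeError above.
updates, state = opt.update(grads, state, params)
```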