google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Error when using `optax.MultiSteps` with `optax.contrib.schedule_free` #1038

Closed: ozanarmagan closed this issue 1 month ago

ozanarmagan commented 2 months ago

Hello,

I get the following exception when I try to wrap a schedule-free optimizer with `MultiSteps`. Can you help me?

        import jax.numpy as jnp
        import optax

        # `config` holds the training hyperparameters.
        learning_rate_fn = optax.schedules.warmup_constant_schedule(peak_value=config.learning_rate, warmup_steps=config.num_warmup_steps, init_value=0.0)
        optimizer = optax.adamw(learning_rate=learning_rate_fn, b1=0.)
        optimizer = optax.contrib.schedule_free(optimizer, learning_rate=learning_rate_fn, b1=config.b1, state_dtype=jnp.bfloat16)
        optimizer = optax.MultiSteps(optimizer, every_k_schedule=config.accum_steps)
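`MultiSteps` accumulates gradients over `every_k_schedule` mini-steps and only runs the wrapped optimizer on the last one. The following is a simplified sketch of that accumulate-or-emit logic, assuming plain averaging and using `jnp.where` for readability; `accumulate_step` is a hypothetical helper, not optax's implementation, which (per the traceback) chooses between the two cases with `lax.cond`.

```python
import jax
import jax.numpy as jnp


def accumulate_step(acc, grads, mini_step, k):
    """Hypothetical gradient-accumulation step (not optax's code).

    Averages `grads` into `acc` over `k` mini-steps. Returns
    (updates, new_acc, new_mini_step); `updates` is all zeros except on
    the k-th mini-step, where it holds the averaged gradient.
    """
    # Add this mini-batch's contribution to the running average.
    new_acc = jax.tree_util.tree_map(lambda a, g: a + g / k, acc, grads)
    emit = mini_step == k - 1
    # Emit the average on the last mini-step, zeros otherwise.
    updates = jax.tree_util.tree_map(
        lambda a: jnp.where(emit, a, jnp.zeros_like(a)), new_acc)
    # Reset the accumulator after emitting.
    new_acc = jax.tree_util.tree_map(
        lambda a: jnp.where(emit, jnp.zeros_like(a), a), new_acc)
    return updates, new_acc, (mini_step + 1) % k
```

In this sketch both cases produce arrays of the same dtype, which is exactly the property `lax.cond` enforces in the real implementation.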

Exception message:

        jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

        The above exception was the direct cause of the following exception:

        Traceback (most recent call last):
          File "/home/ozan/cloud/train/trainer.py", line 116, in train_step
            new_state = state.apply_gradients(grads=grads, train_rngs=new_train_rngs)
          File "/home/ozan/.local/lib/python3.10/site-packages/flax/training/train_state.py", line 101, in apply_gradients
            updates, new_opt_state = self.tx.update(
          File "/home/ozan/.local/lib/python3.10/site-packages/optax/transforms/_accumulation.py", line 380, in update
            new_updates, new_state = lax.cond(
        TypeError: true_fun and false_fun output must have identical types, got ({'BERT_0': {'embedding_layer_norm': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])', 'scale': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])'}, 'encoders_0': {'attention': {'key': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,8,64]) vs. ShapedArray(float32[512,8,64])'}, 'out': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[8,64,512]) vs. ShapedArray(float32[8,64,512])'}, 'query': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,8,64]) vs. ShapedArray(float32[512,8,64])'}, 'value': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,8,64]) vs. ShapedArray(float32[512,8,64])'}}, 'encoder_layer_norm': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])', 'scale': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])'}, 'ffn': {'layers_0': {'bias': 'DIFFERENT ShapedArray(bfloat16[2048]) vs. ShapedArray(float32[2048])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])'}, 'layers_3': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])'}}}, 'encoders_1': {...}, 'encoders_2': {...}, 'encoders_3': {...}, 'encoders_4': {...}, 'encoders_5': {'attention': {'key': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])'....
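The last frame shows `MultiSteps.update` calling `lax.cond`, which requires both branches to return exactly the same pytree structure, shapes, and dtypes. A minimal standalone sketch in plain JAX (no optax involved; the function names are illustrative) reproduces the same kind of `TypeError` with one bfloat16 branch and one float32 branch, and shows that giving both branches a common dtype avoids it.

```python
import jax
import jax.numpy as jnp


def mismatched(pred, x):
    # One branch returns bfloat16, the other float32: lax.cond rejects this
    # while tracing, with "true_fun and false_fun output must have identical types".
    return jax.lax.cond(pred,
                        lambda x: x.astype(jnp.bfloat16),
                        lambda x: x.astype(jnp.float32),
                        x)


def matched(pred, x):
    # Casting the bfloat16 result back to the input dtype makes both
    # branches return the same type, so lax.cond accepts it.
    return jax.lax.cond(pred,
                        lambda x: x.astype(jnp.bfloat16).astype(x.dtype),
                        lambda x: x,
                        x)
```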

vroulet commented 2 months ago

Are your parameters in bfloat16 or in float32? If they are in float32, did you try setting `state_dtype=jnp.float32`?
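One quick way to answer the dtype question is to collect the dtypes of all leaves in the parameter pytree. `leaf_dtypes` below is a hypothetical helper, not part of optax or Flax:

```python
import jax
import jax.numpy as jnp


def leaf_dtypes(tree):
    """Return the set of dtypes found among the array leaves of a pytree."""
    # jnp.asarray also handles Python scalars that may appear as leaves.
    return {jnp.asarray(leaf).dtype for leaf in jax.tree_util.tree_leaves(tree)}
```

For a Flax `TrainState`, this would be called on `state.params`; a single-element result means all parameters share one dtype.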

(I know you have probably tried this, but I'm asking just in case before digging into it further.)

Also pinging the author of this code @nullstring

ozanarmagan commented 2 months ago

My parameters are in bfloat16. `state_dtype` is float32 by default; I also tried setting it to bfloat16, as in the code snippet I shared. It made no difference.

kishorenc commented 2 months ago

I think it's the same issue as: https://github.com/google-deepmind/optax/issues/377#issuecomment-1662951572

clementpoiret commented 2 months ago

I don't know if this is exactly the same bug, but it also happens when using `apply_if_finite`:

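        import optax

        # `retuned_lr` and `b1` come from the surrounding code; warmup_constant_schedule
        # also takes init_value and warmup_steps (not shown here).
        learning_rate_fn = optax.warmup_constant_schedule(peak_value=retuned_lr)
        optimizer = optax.adam(learning_rate_fn, b1=0.)
        optimizer = optax.contrib.schedule_free(optimizer, learning_rate_fn, b1=b1)

        optimizer = optax.apply_if_finite(optimizer, max_consecutive_errors=5)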

I get an error saying that `NoneType` cannot be cast to float32. If that's a different error, I'll open a separate issue.
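`apply_if_finite` skips an update when it contains NaNs or Infs, up to `max_consecutive_errors` times in a row. The sketch below is a simplified illustration of that kind of check using hypothetical helpers `all_finite` and `apply_update_if_finite`; it does not reproduce optax's implementation or how it handles the wrapped optimizer's state.

```python
import jax
import jax.numpy as jnp


def all_finite(tree):
    """Scalar bool array: True if no leaf of `tree` contains NaN or Inf."""
    leaves = jax.tree_util.tree_leaves(tree)
    if not leaves:
        return jnp.array(True)
    return jnp.all(jnp.stack([jnp.all(jnp.isfinite(x)) for x in leaves]))


def apply_update_if_finite(params, updates):
    """Apply `updates` to `params` only if every update leaf is finite."""
    ok = all_finite(updates)
    # Keep the old parameters when any update leaf is NaN/Inf.
    return jax.tree_util.tree_map(
        lambda p, u: jnp.where(ok, p + u, p), params, updates)
```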

vroulet commented 1 month ago

@clementpoiret, this is a separate issue. Can you (i) sync with HEAD (which includes the latest commit that fixed the original issue posted here), (ii) test whether the bug is solved, and (iii) if not, open a new issue?

clementpoiret commented 1 month ago

Sure, I'll test that! Thanks 👍