google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Adafactor + MultiStep with bfloat16 model doesn't work #377

Open Sea-Snell opened 2 years ago

Sea-Snell commented 2 years ago

When I use Adafactor with MultiSteps on a bfloat16 model (T5-small), I get the following strange error. The full error is extremely long, so I truncated it to fit in the issue:

Traceback (most recent call last):
  File "/home/charliesnell/jax_v_pytorch/large_lm_finetune/flax/main.py", line 135, in <module>
    train.unroll(metaconfig)
  File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/micro_config.py", line 39, in new_unroll
    result = unroll(self, metaconfig)
  File "/home/charliesnell/jax_v_pytorch/large_lm_finetune/flax/train_loop.py", line 372, in unroll
    logs, params, opt_state = p_step_fn(params, opt_state, new_rng, items)
  File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/experimental/pjit.py", line 352, in wrapped
    args_flat, _, params, _, out_tree, _ = infer_params(*args, **kwargs)
  File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/experimental/pjit.py", line 330, in infer_params
    jaxpr, canonicalized_out_axis_resources_flat = _pjit_jaxpr(
  File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/experimental/pjit.py", line 490, in _pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, global_in_avals)
  File "/home/charliesnell/jax_v_pytorch/large_lm_finetune/flax/train_loop.py", line 337, in t5_step_fn
    updates, opt_state = optim.update(grads, opt_state, params)
  File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/optax/_src/wrappers.py", line 413, in update
    new_updates, new_state = jax.lax.cond(
  File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/_src/lax/control_flow/conditionals.py", line 252, in cond
    return _cond_with_per_branch_args(*ba.args)
  File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/_src/lax/control_flow/conditionals.py", line 273, in _cond_with_per_branch_args
    return _cond(pred,
  File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/_src/lax/control_flow/conditionals.py", line 223, in _cond
    _check_tree_and_avals("true_fun and false_fun output",
  File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/_src/lax/control_flow/common.py", line 105, in _check_tree_and_avals
    raise TypeError(f"{what} must have identical types, got\n{diff}.")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: true_fun and false_fun output must have identical types, got
(FrozenDict({
    decoder: {
        block: {
            0: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            relative_attention_bias: {
                                embedding: 'DIFFERENT ShapedArray(bfloat16[32,8]) vs. ShapedArray(float32[32,8])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        EncDecAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    2: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
            1: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        EncDecAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    2: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
            2: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        EncDecAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    2: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
            3: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        EncDecAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    2: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
            4: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        EncDecAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    2: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
            5: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        EncDecAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    2: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
        },
        final_layer_norm: {
            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
        },
    },
    encoder: {
        block: {
            0: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            relative_attention_bias: {
                                embedding: 'DIFFERENT ShapedArray(bfloat16[32,8]) vs. ShapedArray(float32[32,8])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
            1: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
            2: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
            3: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
            4: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
            5: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
        },
        final_layer_norm: {
            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
        },
    },
    shared: {
        embedding: 'DIFFERENT ShapedArray(bfloat16[32128,512]) vs. ShapedArray(float32[32128,512])',
    },
}), MultiStepsState(mini_step='ShapedArray(int32[])', gradient_step='ShapedArray(int32[])', inner_opt_state=(FactoredState(count='ShapedArray(int32[])', v_row=FrozenDict({
    decoder: {
        block: {
            0: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'ShapedArray(float32[512])',
                            },
                            o: {
                                kernel: 'ShapedArray(float32[512])',
                            },
                            q: {
                                kernel: 'ShapedArray(float32[512])',
                            },
                            relative_attention_bias: {
                                embedding: 'ShapedArray(float32[1])',
                            },
                            v: {
                                kernel: 'ShapedArray(float32[512])',
                            },
                        },
                        layer_norm: {
                            weight: 'ShapedArray(float32[1])',
                        },
                    },
                    1: {
                        EncDecAttention: {
                            k: {
                                kernel: 'ShapedArray(float32[512])',
                            },
                            o: {
                                kernel: 'ShapedArray(float32[512])',
                            },
                            q: {
                                kernel: 'ShapedArray(float32[512])',
                            },
                            v: {
                                kernel: 'ShapedArray(float32[512])',
                            },

The error points to this line of optax.MultiSteps. It's essentially saying that mid_step's first return value has type fp32 but final_step has type bfloat16. If I force-cast mid_step's return to bfloat16, the error goes away. And looking at the code, I'm not exactly sure why this would happen; the code looks like it should handle the types correctly. So if anyone has an explanation or a non-hacky fix that would be appreciated.

Note that the optimizer is being called inside a pjit on a TPUv3. I don't get this error with AdamW + MultiStep + bfloat16.
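The underlying constraint can be reproduced without optax. The sketch below assumes only JAX. `mid_step` and `final_step` here are toy stand-ins, not optax's actual internals. It builds a `jax.lax.cond` whose two branches return the same shape but different dtypes, which JAX rejects with a `TypeError`. It then shows that casting one branch back to the other's dtype lets the cond trace.

```python
import jax
import jax.numpy as jnp

# Toy stand-ins for the two branches MultiSteps chooses between
# (illustrative only; these are not optax's internal functions).
def mid_step(acc_grads):
    # Mid-accumulation branch: emit zero updates in the accumulator's dtype.
    return jnp.zeros_like(acc_grads)

def final_step(acc_grads):
    # Final branch: mimics an inner optimizer that upcasts to float32.
    return acc_grads.astype(jnp.float32) * 0.5

acc = jnp.ones((5, 18), dtype=jnp.bfloat16)
pred = jnp.array(True)

# lax.cond traces both branches, so the dtype mismatch is caught even though
# only one branch would actually run.
try:
    jax.lax.cond(pred, mid_step, final_step, acc)
    mismatch_raised = False
except TypeError:
    mismatch_raised = True

def final_step_cast(acc_grads):
    # Cast back to the accumulator dtype so both branches agree.
    return final_step(acc_grads).astype(acc_grads.dtype)

out = jax.lax.cond(pred, mid_step, final_step_cast, acc)
print(mismatch_raised, out.dtype)
```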

mkunesch commented 2 years ago

Interesting. Based on your description, this would only happen if the dtype inference in line 383 produces the wrong type, so I could look into whether the dtype returned by optax.scale_by_factored_rms is correct. Do you have a minimal example of the error I could try this with?

Thanks a lot for raising this!
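To check which dtypes come out of an optimizer such as optax.scale_by_factored_rms, a small diagnostic can list every leaf whose update dtype differs from its parameter dtype. This is a sketch of a hypothetical helper: `dtype_mismatches` is not part of optax or JAX. It assumes the updates and params trees have identical structure. The usage below runs on a made-up parameter tree, not the T5 model from the report.

```python
import jax
import jax.numpy as jnp

def dtype_mismatches(updates, params):
    """Hypothetical helper: list (path, update_dtype, param_dtype) for each
    leaf whose update dtype differs from the matching parameter dtype."""
    if jax.tree_util.tree_structure(updates) != jax.tree_util.tree_structure(params):
        raise ValueError("updates and params must have the same tree structure")
    update_leaves, _ = jax.tree_util.tree_flatten_with_path(updates)
    param_leaves = jax.tree_util.tree_leaves(params)
    mismatches = []
    for (path, u), p in zip(update_leaves, param_leaves):
        if u.dtype != p.dtype:
            mismatches.append((jax.tree_util.keystr(path), u.dtype, p.dtype))
    return mismatches

# Made-up example trees: one leaf was upcast to float32, one was not.
params = {
    "dense": {"kernel": jnp.zeros((4, 3), dtype=jnp.bfloat16)},
    "norm": jnp.zeros((3,), dtype=jnp.bfloat16),
}
updates = {
    "dense": {"kernel": jnp.zeros((4, 3), dtype=jnp.float32)},
    "norm": jnp.zeros((3,), dtype=jnp.bfloat16),
}
mismatches = dtype_mismatches(updates, params)
for path, u_dtype, p_dtype in mismatches:
    print(f"{path}: update {u_dtype} vs param {p_dtype}")
```

In practice, you would pass it the updates returned by the inner optimizer's `update` call together with the model params.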

mtthss commented 2 years ago

@Sea-Snell, as @mkunesch mentioned, it would be helpful to have a minimal example we could try this with.

ayaka14732 commented 1 year ago

@Sea-Snell I think this issue is not fixed and should be reopened.

ayaka14732 commented 1 year ago

Repro:

import os; os.environ['JAX_PLATFORMS'] = 'cpu'
import jax
import jax.numpy as jnp
import optax

@jax.jit
@jax.value_and_grad
def f(params, x, labels):
    logits = params @ x
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    return loss.mean()

params = jnp.zeros((5, 18), dtype=jnp.bfloat16)
x = jnp.zeros((18, 4), dtype=jnp.bfloat16)
labels = jnp.zeros((5,), dtype=jnp.uint16)
value, grad = f(params, x, labels)

lr = 0.00005
n_accumulation_steps = 4

optimizer = optax.adafactor(learning_rate=lr)
optimizer = optax.MultiSteps(optimizer, n_accumulation_steps)
opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grad, opt_state, params)
print(updates)

Error:

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
  File "/home/ayaka/llama-2-jax/1.py", line 24, in <module>
    updates, opt_state = optimizer.update(grad, opt_state, params)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayaka/llama-2-jax/venv/lib/python3.11/site-packages/optax/_src/wrappers.py", line 423, in update
    new_updates, new_state = jax.lax.cond(
                             ^^^^^^^^^^^^^
  File "/home/ayaka/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ayaka/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/lax/control_flow/conditionals.py", line 286, in cond
    return _cond_with_per_branch_args(*ba.args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayaka/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/lax/control_flow/conditionals.py", line 307, in _cond_with_per_branch_args
    return _cond(pred,
           ^^^^^^^^^^^
  File "/home/ayaka/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/lax/control_flow/conditionals.py", line 251, in _cond
    _check_tree_and_avals("true_fun and false_fun output",
  File "/home/ayaka/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/lax/control_flow/common.py", line 202, in _check_tree_and_avals
    raise TypeError(f"{what} must have identical types, got\n{diff}.")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: true_fun and false_fun output must have identical types, got
('DIFFERENT ShapedArray(bfloat16[5,18]) vs. ShapedArray(float32[5,18])', MultiStepsState(mini_step='ShapedArray(int32[])', gradient_step='ShapedArray(int32[])', inner_opt_state=(FactoredState(count='ShapedArray(int32[])', v_row='ShapedArray(float32[1])', v_col='ShapedArray(float32[1])', v='ShapedArray(float32[5,18])'), EmptyState(), EmptyState(), EmptyState(), EmptyState()), acc_grads='ShapedArray(bfloat16[5,18])', skip_state=())).

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ayaka/llama-2-jax/1.py", line 24, in <module>
    updates, opt_state = optimizer.update(grad, opt_state, params)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayaka/llama-2-jax/venv/lib/python3.11/site-packages/optax/_src/wrappers.py", line 423, in update
    new_updates, new_state = jax.lax.cond(
                             ^^^^^^^^^^^^^
TypeError: true_fun and false_fun output must have identical types, got
('DIFFERENT ShapedArray(bfloat16[5,18]) vs. ShapedArray(float32[5,18])', MultiStepsState(mini_step='ShapedArray(int32[])', gradient_step='ShapedArray(int32[])', inner_opt_state=(FactoredState(count='ShapedArray(int32[])', v_row='ShapedArray(float32[1])', v_col='ShapedArray(float32[1])', v='ShapedArray(float32[5,18])'), EmptyState(), EmptyState(), EmptyState(), EmptyState()), acc_grads='ShapedArray(bfloat16[5,18])', skip_state=())).

However, either of these modifications makes the error go away:

  1. Change the optimiser from optax.adafactor to optax.adamw
  2. Remove optax.MultiSteps

mk-0 commented 1 year ago

I had the same problem. It happens because adafactor returns float32 updates even though the params and gradients are bfloat16, while MultiSteps requires both branches of its jax.lax.cond to return the same types. The cause is that scale_by_factored_rms inside adafactor does not preserve the dtype of the updates passing through it; much of its internal state is float32.

One quick fix is to add an explicit conversion, update.astype(grad.dtype), to this line. If that sounds good, I'd be glad to submit a PR.

hlzl commented 6 days ago

As @mk-0 mentioned, this is caused by the jax.lax.cond inside MultiSteps. Since the error also occurs when using flax.training.dynamic_scale, I think a fix inside MultiSteps would be better. The error occurs essentially whenever some gradient values have a different dtype from the corresponding parameters, which often happens when some kind of scaling is applied that requires casting bfloat16 to float32. I'll open a PR with a possible fix.
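A generic fix inside the wrapper, rather than in each inner optimizer, makes both branches of the cond agree on dtypes before calling jax.lax.cond. The sketch below shows one way to do that. `cond_matching_dtypes` is a hypothetical helper, not a JAX or optax API, and not necessarily what the PR does. It uses jax.eval_shape to obtain the true branch's output dtypes without running it, then casts the false branch's outputs to match. It assumes both branches already return the same tree structure and shapes.

```python
import jax
import jax.numpy as jnp

def cond_matching_dtypes(pred, true_fun, false_fun, *operands):
    """Hypothetical helper: like jax.lax.cond, but casts every output leaf of
    false_fun to the dtype true_fun produces for the same leaf."""
    # Abstract evaluation only: returns ShapeDtypeStructs, no computation.
    true_out = jax.eval_shape(true_fun, *operands)

    def false_fun_cast(*ops):
        out = false_fun(*ops)
        return jax.tree_util.tree_map(
            lambda x, ref: x.astype(ref.dtype), out, true_out)

    return jax.lax.cond(pred, true_fun, false_fun_cast, *operands)

def zeros_branch(g):
    # Plays the "mid step" role: zeros in the parameter/accumulator dtype.
    return jnp.zeros_like(g)

def upcast_branch(g):
    # Plays the "final step" role: an inner optimizer that upcasts to float32.
    return g.astype(jnp.float32) * 0.5

acc = jnp.ones((5, 18), dtype=jnp.bfloat16)
out = cond_matching_dtypes(jnp.array(False), zeros_branch, upcast_branch, acc)
print(out.dtype)
```

Which branch serves as the dtype reference is a design choice. Taking it from the branch that emits parameter-dtype zeros keeps updates in the parameters' dtype, whereas the opposite choice would keep the inner optimizer's float32 updates.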