Open Sea-Snell opened 2 years ago
Interesting, based on your description this would only happen if the dtype inference in line 383 results in the wrong type, so I could try looking into whether the dtype returned from optax.scale_by_factored_rms is correct. Do you have a minimal example of the error I could try this with?
Thanks a lot for raising this!
@Sea-Snell, as @mkunesch mentioned, it would be helpful to have a minimal example we could try this with.
@Sea-Snell I think this issue is not fixed and should be reopened.
Repro:

```python
import os; os.environ['JAX_PLATFORMS'] = 'cpu'
import jax
import jax.numpy as jnp
import optax

@jax.jit
@jax.value_and_grad
def f(params, x, labels):
    logits = params @ x
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    return loss.mean()

params = jnp.zeros((5, 18), dtype=jnp.bfloat16)
x = jnp.zeros((18, 4), dtype=jnp.bfloat16)
labels = jnp.zeros((5,), dtype=jnp.uint16)
value, grad = f(params, x, labels)

lr = 0.00005
n_accumulation_steps = 4
optimizer = optax.adafactor(learning_rate=lr)
optimizer = optax.MultiSteps(optimizer, n_accumulation_steps)
opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grad, opt_state, params)
print(updates)
```
Error:

```
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
  File "/home/ayaka/llama-2-jax/1.py", line 24, in <module>
    updates, opt_state = optimizer.update(grad, opt_state, params)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayaka/llama-2-jax/venv/lib/python3.11/site-packages/optax/_src/wrappers.py", line 423, in update
    new_updates, new_state = jax.lax.cond(
                             ^^^^^^^^^^^^^
  File "/home/ayaka/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ayaka/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/lax/control_flow/conditionals.py", line 286, in cond
    return _cond_with_per_branch_args(*ba.args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayaka/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/lax/control_flow/conditionals.py", line 307, in _cond_with_per_branch_args
    return _cond(pred,
           ^^^^^^^^^^^
  File "/home/ayaka/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/lax/control_flow/conditionals.py", line 251, in _cond
    _check_tree_and_avals("true_fun and false_fun output",
  File "/home/ayaka/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/lax/control_flow/common.py", line 202, in _check_tree_and_avals
    raise TypeError(f"{what} must have identical types, got\n{diff}.")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: true_fun and false_fun output must have identical types, got
('DIFFERENT ShapedArray(bfloat16[5,18]) vs. ShapedArray(float32[5,18])', MultiStepsState(mini_step='ShapedArray(int32[])', gradient_step='ShapedArray(int32[])', inner_opt_state=(FactoredState(count='ShapedArray(int32[])', v_row='ShapedArray(float32[1])', v_col='ShapedArray(float32[1])', v='ShapedArray(float32[5,18])'), EmptyState(), EmptyState(), EmptyState(), EmptyState()), acc_grads='ShapedArray(bfloat16[5,18])', skip_state=())).

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ayaka/llama-2-jax/1.py", line 24, in <module>
    updates, opt_state = optimizer.update(grad, opt_state, params)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayaka/llama-2-jax/venv/lib/python3.11/site-packages/optax/_src/wrappers.py", line 423, in update
    new_updates, new_state = jax.lax.cond(
                             ^^^^^^^^^^^^^
TypeError: true_fun and false_fun output must have identical types, got
('DIFFERENT ShapedArray(bfloat16[5,18]) vs. ShapedArray(float32[5,18])', MultiStepsState(mini_step='ShapedArray(int32[])', gradient_step='ShapedArray(int32[])', inner_opt_state=(FactoredState(count='ShapedArray(int32[])', v_row='ShapedArray(float32[1])', v_col='ShapedArray(float32[1])', v='ShapedArray(float32[5,18])'), EmptyState(), EmptyState(), EmptyState(), EmptyState()), acc_grads='ShapedArray(bfloat16[5,18])', skip_state=())).
```
However, these modifications make the error go away:
- changing optax.adafactor to optax.adamw
- removing optax.MultiSteps
I had the same problem. I have found that it happens because adafactor returns float32 updates despite the params and gradients being bfloat16, while MultiSteps expects them to be of the same type when applying jax.lax.cond. This happens because scale_by_factored_rms inside adafactor does not preserve the dtype of the updates propagating through it; many variables in its internal state are float32.
One quick fix is to add an explicit conversion update.astype(grad.dtype) to this line. If that sounds good, I'd be glad to submit a PR.
As @mk-0 mentioned, this is due to the jax.lax.cond inside MultiSteps. As the error also occurs when using flax.training.dynamic_scale, I think a fix inside MultiSteps would be better. The error occurs basically every time some values of the gradient have a different dtype than the corresponding parameters, i.e. often when some kind of scaling is applied which requires casting bfloat16 to float32. I'll open a PR with a possible fix.
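For context, the constraint any fix has to satisfy can be seen in isolation: both branches of jax.lax.cond must return values with identical shapes and dtypes. A minimal sketch, independent of optax:

```python
import jax
import jax.numpy as jnp

x = jnp.zeros((2,), dtype=jnp.bfloat16)
try:
    jax.lax.cond(
        True,
        lambda v: v.astype(jnp.float32),  # analogue of the float32 branch
        lambda v: v,                      # analogue of the bfloat16 branch
        x,
    )
    mismatch_raised = False
except TypeError:
    # jax rejects the cond at trace time because the branch outputs
    # have different dtypes, just as in the MultiSteps traceback above.
    mismatch_raised = True
```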
If I use Adafactor with MultiSteps on a bfloat16 model I get this strange error (note: the error is extremely long, so I truncated it to fit in the issue; the model is T5-small):
The error points to this line of optax.MultiSteps. It's essentially saying that mid_step's first return value has type fp32 but final_step's has type bfloat16. If I force-cast mid_step's return to bfloat16, the error goes away. Looking at the code, I'm not exactly sure why this would happen; it looks like it should handle the types correctly. So if anyone has an explanation or a non-hacky fix, that would be appreciated.
Note that the optimizer is being called inside a pjit on TPU v3, and I don't get this error with AdamW + MultiSteps + bfloat16.