colehaus opened this issue 3 months ago
Hi - thanks for the question! I spent some time making a more concise reproduction here:
import jax

def check_err(x, y):
    result = x + y
    y2 = result - x
    return y - y2

op1 = jax.random.normal(jax.random.key(0), (5,), dtype='bfloat16')
op2 = jax.random.normal(jax.random.key(1), (5,), dtype='bfloat16')

print(check_err(op1, op2))
# [0 -0.00244141 0 0.000488281 0.00390625]
print(jax.jit(check_err)(op1, op2))
# [0 0 0 0 0]
Since it looks like the compiler is doing something unexpected here, it will help to print the optimized HLO for the function:
print(jax.jit(check_err).lower(op1, op2).compile().as_text())
HloModule jit_check_err, entry_computation_layout={(bf16[5]{0}, bf16[5]{0})->bf16[5]{0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}
%fused_computation (param_0.2: bf16[5], param_1.4: bf16[5]) -> bf16[5] {
  %param_1.4 = bf16[5]{0} parameter(1)
  %convert.11 = f32[5]{0} convert(bf16[5]{0} %param_1.4)
  %param_0.2 = bf16[5]{0} parameter(0)
  %convert.10 = f32[5]{0} convert(bf16[5]{0} %param_0.2)
  %add.0 = f32[5]{0} add(f32[5]{0} %convert.10, f32[5]{0} %convert.11), metadata={op_name="jit(check_err)/jit(main)/add" source_file="<ipython-input-4-c332ca662f3d>" source_line=4}
  %subtract.1 = f32[5]{0} subtract(f32[5]{0} %add.0, f32[5]{0} %convert.10), metadata={op_name="jit(check_err)/jit(main)/sub" source_file="<ipython-input-4-c332ca662f3d>" source_line=5}
  %subtract.0 = f32[5]{0} subtract(f32[5]{0} %convert.11, f32[5]{0} %subtract.1), metadata={op_name="jit(check_err)/jit(main)/sub" source_file="<ipython-input-4-c332ca662f3d>" source_line=6}
  ROOT %convert.9 = bf16[5]{0} convert(f32[5]{0} %subtract.0)
}

ENTRY %main.6 (Arg_0.1: bf16[5], Arg_1.2: bf16[5]) -> bf16[5] {
  %Arg_0.1 = bf16[5]{0} parameter(0)
  %Arg_1.2 = bf16[5]{0} parameter(1)
  ROOT %fusion = bf16[5]{0} fusion(bf16[5]{0} %Arg_0.1, bf16[5]{0} %Arg_1.2), kind=kLoop, calls=%fused_computation
}
This shows what the problem is: the line %convert.11 = f32[5]{0} convert(bf16[5]{0} %param_1.4) converts the input to float32 before any of the operations run, and %convert.9 = bf16[5]{0} convert(f32[5]{0} %subtract.0) converts the result back to bfloat16 only at the very end. The rounding error is therefore being measured at float32 precision, where it is far smaller (here effectively zero), so casting the result back to bfloat16 yields zeros. Essentially, the JIT-compiled version is effectively doing this:
def check_err(x, y):
    x, y = x.astype('float32'), y.astype('float32')
    result = x + y
    y2 = result - x
    return (y - y2).astype('bfloat16')
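As a quick sanity check (this snippet is not from the original thread; check_err_f32 is just an illustrative name), running that float32 version eagerly should reproduce the all-zero output of the jitted function:

import jax

def check_err_f32(x, y):
    # Mirrors the snippet above: upcast, compute, downcast at the end.
    x, y = x.astype('float32'), y.astype('float32')
    result = x + y
    y2 = result - x
    return (y - y2).astype('bfloat16')

op1 = jax.random.normal(jax.random.key(0), (5,), dtype='bfloat16')
op2 = jax.random.normal(jax.random.key(1), (5,), dtype='bfloat16')
print(check_err_f32(op1, op2))
# expected: [0 0 0 0 0], matching the jitted output above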
I'm not aware of any way to prevent the compiler from doing this kind of casting – it's probably due to the fact that the hardware (CPU in my case) does not support native bfloat16 operations. I'll ask around to see if others have ideas.
Via @apaszke, it seems the xla_allow_excess_precision flag controls this behavior. If you set it to false, the compiler won't do this sort of internal upcasting:
import os
os.environ['XLA_FLAGS'] = "--xla_allow_excess_precision=false"

import jax

def check_err(x, y):
    result = x + y
    y2 = result - x
    return y - y2

op1 = jax.random.normal(jax.random.key(0), (5,), dtype='bfloat16')
op2 = jax.random.normal(jax.random.key(1), (5,), dtype='bfloat16')

print(check_err(op1, op2))
# [0 -0.00244141 0 0.000488281 0.00390625]
print(jax.jit(check_err)(op1, op2))
# [0 -0.00244141 0 0.000488281 0.00390625]
Note that XLA flag values are only read at the time the backend is initialized, so be sure to set them either as an environment variable outside your script, or in your script via os.environ before running any jax commands.
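If you want to confirm the flag took effect (a sanity check, not part of the original thread), you can re-inspect the optimized HLO as before; with --xla_allow_excess_precision=false, intermediate values should be rounded back to bfloat16 between operations rather than carried in float32 across the whole fused computation:

# Assumes the flag-setting snippet above has already run in this process,
# so the backend was initialized with --xla_allow_excess_precision=false.
print(jax.jit(check_err).lower(op1, op2).compile().as_text())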
That seems to work. Thanks!
Description
I have some stochastic rounding code and uncovered a bug when trying to use code like the following:
With bfloat16, the final line prints True even though it's clear from the preceding line that not all errors ought to be 0. np.float32 does not have this behavior. Here are some lowering and compilation outputs, if that happens to be helpful. First bfloat16, then float32:
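(The outputs themselves are not reproduced here. For reference, outputs like these can be generated along the following lines, a sketch using the concise check_err reproduction from above rather than the original stochastic rounding code.)

import jax

def check_err(x, y):
    result = x + y
    y2 = result - x
    return y - y2

for dtype in ('bfloat16', 'float32'):
    a = jax.random.normal(jax.random.key(0), (5,), dtype=dtype)
    b = jax.random.normal(jax.random.key(1), (5,), dtype=dtype)
    lowered = jax.jit(check_err).lower(a, b)
    print(lowered.as_text())            # lowered (pre-optimization) HLO
    print(lowered.compile().as_text())  # compiled (optimized) HLO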
(Originally reported at: https://github.com/jax-ml/ml_dtypes/issues/167)
System info (python version, jaxlib version, accelerator, etc.)