jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Finite precision error calculations always 0 under JIT with bfloat16 #23007

Open colehaus opened 3 months ago

colehaus commented 3 months ago

Description

I have some stochastic rounding code and uncovered a bug when using code like the following:

import jax
import jax.numpy as jnp
from jax.numpy import bfloat16  # the same type as ml_dtypes.bfloat16

# Recover the rounding error of result = x + y (two-sum style).
# (The ndarray[*Shape, Float] annotations are type aliases defined elsewhere in my code.)
def _error(x: ndarray[*Shape, Float], y: ndarray[*Shape, Float], result: ndarray[*Shape, Float]):
    y2 = result - x
    x2 = result - y2
    error_y = y - y2
    error_x = x - x2
    return error_x + error_y

def add(x: ndarray[*Shape, Float], y: ndarray[*Shape, Float]):
    result = x + y
    return _error(x, y, result)
dtype = bfloat16
op1 = jax.random.normal(jax.random.key(0), (1000, 4), dtype=dtype)
op2 = jax.random.normal(jax.random.key(1), (1000, 4), dtype=dtype)
print(jax.vmap(add)(op1, op2))
print(jnp.all(jax.jit(jax.vmap(add))(op1, op2) == 0))

With bfloat16, the final line prints True even though it's clear from the preceding line that not all errors ought to be 0. np.float32 does not have this behavior.
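For intuition, here is a minimal eager-mode sketch (using 5 + 0.001 as an example) of why nonzero errors are expected: bfloat16 carries only 8 significand bits, so an addend well below half an ulp of the larger operand is lost entirely by the rounded sum, and the error computation above recovers it.

import jax.numpy as jnp

x = jnp.asarray(5.0, dtype=jnp.bfloat16)
y = jnp.asarray(0.001, dtype=jnp.bfloat16)  # stored as ~0.000999

s = x + y           # one ulp of 5 in bfloat16 is 0.03125, so s rounds to exactly 5
print(s)            # 5
print(y - (s - x))  # ~0.000999: the recovered rounding error is the whole addend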

Here are the lowering and compilation outputs, in case they're helpful: first bfloat16, then float32.

dtype = bfloat16
args = (jnp.arange(4, 7, dtype=dtype), jnp.arange(3, dtype=dtype) / 1000)
print(add(*args))
print(jax.jit(add)(*args))
print(jax.jit(add).lower(*args).as_text())
print(jax.jit(add).lower(*args).compile().as_text())
[0 0.000999451 0.0019989]
[0 0 0]
module @jit_add attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xbf16> {mhlo.layout_mode = "default"}, %arg1: tensor<3xbf16> {mhlo.layout_mode = "default"}) -> (tensor<3xbf16> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<3xbf16>
    %1 = stablehlo.subtract %0, %arg0 : tensor<3xbf16>
    %2 = stablehlo.subtract %0, %1 : tensor<3xbf16>
    %3 = stablehlo.subtract %arg1, %1 : tensor<3xbf16>
    %4 = stablehlo.subtract %arg0, %2 : tensor<3xbf16>
    %5 = stablehlo.add %4, %3 : tensor<3xbf16>
    return %5 : tensor<3xbf16>
  }
}

HloModule jit_add, is_scheduled=true, entry_computation_layout={(bf16[3]{0}, bf16[3]{0})->bf16[3]{0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="82694c3355091a0097f584dec86f3d57"}

%fused_convert (param_0.3: bf16[3], param_1.5: bf16[3]) -> bf16[3] {
  %param_0.3 = bf16[3]{0} parameter(0)
  %convert.9.1 = f32[3]{0} convert(bf16[3]{0} %param_0.3)
  %param_1.5 = bf16[3]{0} parameter(1)
  %convert.1.1 = f32[3]{0} convert(bf16[3]{0} %param_1.5)
  %add.2.1 = f32[3]{0} add(f32[3]{0} %convert.9.1, f32[3]{0} %convert.1.1), metadata={op_name="jit(add)/jit(main)/add" source_file="/tmp/ipykernel_47449/771070407.py" source_line=787}
  %subtract.8.1 = f32[3]{0} subtract(f32[3]{0} %add.2.1, f32[3]{0} %convert.9.1), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=779}
  %subtract.9.1 = f32[3]{0} subtract(f32[3]{0} %add.2.1, f32[3]{0} %subtract.8.1), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=780}
  %subtract.10.1 = f32[3]{0} subtract(f32[3]{0} %convert.9.1, f32[3]{0} %subtract.9.1), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=782}
  %subtract.11.1 = f32[3]{0} subtract(f32[3]{0} %convert.1.1, f32[3]{0} %subtract.8.1), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=781}
  %add.4.1 = f32[3]{0} add(f32[3]{0} %subtract.10.1, f32[3]{0} %subtract.11.1), metadata={op_name="jit(add)/jit(main)/add" source_file="/tmp/ipykernel_47449/771070407.py" source_line=783}
  ROOT %convert.17.1 = bf16[3]{0} convert(f32[3]{0} %add.4.1)
}

ENTRY %main.9 (Arg_0.1.0: bf16[3], Arg_1.2.0: bf16[3]) -> bf16[3] {
  %Arg_1.2.0 = bf16[3]{0} parameter(1), metadata={op_name="y"}
  %Arg_0.1.0 = bf16[3]{0} parameter(0), metadata={op_name="x"}
  ROOT %loop_convert_fusion = bf16[3]{0} fusion(bf16[3]{0} %Arg_0.1.0, bf16[3]{0} %Arg_1.2.0), kind=kLoop, calls=%fused_convert
}
dtype = np.float32
args = (jnp.arange(4, 7, dtype=dtype), jnp.arange(3, dtype=dtype) / 1000)
print(add(*args))
print(jax.jit(add)(*args))
print(jax.jit(add).lower(*args).as_text())
print(jax.jit(add).lower(*args).compile().as_text())
[0.0000000e+00 7.2526745e-08 1.4505349e-07]
[0.0000000e+00 7.2526745e-08 1.4505349e-07]
module @jit_add attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3xf32> {mhlo.layout_mode = "default"}) -> (tensor<3xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
    %1 = stablehlo.subtract %0, %arg0 : tensor<3xf32>
    %2 = stablehlo.subtract %0, %1 : tensor<3xf32>
    %3 = stablehlo.subtract %arg1, %1 : tensor<3xf32>
    %4 = stablehlo.subtract %arg0, %2 : tensor<3xf32>
    %5 = stablehlo.add %4, %3 : tensor<3xf32>
    return %5 : tensor<3xf32>
  }
}

HloModule jit_add, is_scheduled=true, entry_computation_layout={(f32[3]{0}, f32[3]{0})->f32[3]{0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="9d06ea6507421c8754deb14690ff8cd9"}

%fused_add (param_0.1: f32[3], param_1.3: f32[3]) -> f32[3] {
  %param_1.3 = f32[3]{0} parameter(1)
  %param_0.1 = f32[3]{0} parameter(0)
  %add.2.1 = f32[3]{0} add(f32[3]{0} %param_1.3, f32[3]{0} %param_0.1), metadata={op_name="jit(add)/jit(main)/add" source_file="/tmp/ipykernel_47449/771070407.py" source_line=787}
  %subtract.8.1 = f32[3]{0} subtract(f32[3]{0} %add.2.1, f32[3]{0} %param_1.3), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=779}
  %subtract.9.1 = f32[3]{0} subtract(f32[3]{0} %add.2.1, f32[3]{0} %subtract.8.1), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=780}
  %subtract.10.1 = f32[3]{0} subtract(f32[3]{0} %param_1.3, f32[3]{0} %subtract.9.1), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=782}
  %subtract.11.1 = f32[3]{0} subtract(f32[3]{0} %param_0.1, f32[3]{0} %subtract.8.1), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=781}
  ROOT %add.4.1 = f32[3]{0} add(f32[3]{0} %subtract.10.1, f32[3]{0} %subtract.11.1), metadata={op_name="jit(add)/jit(main)/add" source_file="/tmp/ipykernel_47449/771070407.py" source_line=783}
}

ENTRY %main.9 (Arg_0.1.0: f32[3], Arg_1.2.0: f32[3]) -> f32[3] {
  %Arg_1.2.0 = f32[3]{0} parameter(1), metadata={op_name="y"}
  %Arg_0.1.0 = f32[3]{0} parameter(0), metadata={op_name="x"}
  ROOT %loop_add_fusion = f32[3]{0} fusion(f32[3]{0} %Arg_1.2.0, f32[3]{0} %Arg_0.1.0), kind=kLoop, calls=%fused_add, metadata={op_name="jit(add)/jit(main)/add" source_file="/tmp/ipykernel_47449/771070407.py" source_line=783}
}

(Originally reported at: https://github.com/jax-ml/ml_dtypes/issues/167)

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.26.1
python: 3.11.9 (main, Apr  6 2024, 17:59:24) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='npjfe11cq9', release='5.19.0-45-generic', version='#46~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Wed Jun 7 15:06:04 UTC 20', machine='x86_64')

$ nvidia-smi
Sun Aug 11 01:03:45 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX A6000    Off  | 00000000:00:05.0 Off |                  Off |
| 30%   42C    P8    27W / 300W |  36785MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+
jakevdp commented 3 months ago

Hi - thanks for the question! I spent some time making a more concise reproduction here:

import jax

def check_err(x, y):
  result = x + y
  y2 = result - x
  return y - y2

op1 = jax.random.normal(jax.random.key(0), (5,), dtype='bfloat16')
op2 = jax.random.normal(jax.random.key(1), (5,), dtype='bfloat16')

print(check_err(op1, op2))
# [0 -0.00244141 0 0.000488281 0.00390625]

print(jax.jit(check_err)(op1, op2))
# [0 0 0 0 0]

Since it looks like the compiler is doing something unexpected here, it will help to print the optimized HLO for the function:

print(jax.jit(check_err).lower(op1, op2).compile().as_text())
HloModule jit_check_err, entry_computation_layout={(bf16[5]{0}, bf16[5]{0})->bf16[5]{0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}

%fused_computation (param_0.2: bf16[5], param_1.4: bf16[5]) -> bf16[5] {
  %param_1.4 = bf16[5]{0} parameter(1)
  %convert.11 = f32[5]{0} convert(bf16[5]{0} %param_1.4)
  %param_0.2 = bf16[5]{0} parameter(0)
  %convert.10 = f32[5]{0} convert(bf16[5]{0} %param_0.2)
  %add.0 = f32[5]{0} add(f32[5]{0} %convert.10, f32[5]{0} %convert.11), metadata={op_name="jit(check_err)/jit(main)/add" source_file="<ipython-input-4-c332ca662f3d>" source_line=4}
  %subtract.1 = f32[5]{0} subtract(f32[5]{0} %add.0, f32[5]{0} %convert.10), metadata={op_name="jit(check_err)/jit(main)/sub" source_file="<ipython-input-4-c332ca662f3d>" source_line=5}
  %subtract.0 = f32[5]{0} subtract(f32[5]{0} %convert.11, f32[5]{0} %subtract.1), metadata={op_name="jit(check_err)/jit(main)/sub" source_file="<ipython-input-4-c332ca662f3d>" source_line=6}
  ROOT %convert.9 = bf16[5]{0} convert(f32[5]{0} %subtract.0)
}

ENTRY %main.6 (Arg_0.1: bf16[5], Arg_1.2: bf16[5]) -> bf16[5] {
  %Arg_0.1 = bf16[5]{0} parameter(0)
  %Arg_1.2 = bf16[5]{0} parameter(1)
  ROOT %fusion = bf16[5]{0} fusion(bf16[5]{0} %Arg_0.1, bf16[5]{0} %Arg_1.2), kind=kLoop, calls=%fused_computation
}

This shows what the problem is: the line %convert.11 = f32[5]{0} convert(bf16[5]{0} %param_1.4) converts the input to float32 before doing any of the operations, and %convert.9 = bf16[5]{0} convert(f32[5]{0} %subtract.0) converts the result back to bfloat16 only at the end. The error is therefore computed at float32 precision, and float32 has more than enough precision to add two bfloat16 values of similar magnitude exactly, so the computed error term is exactly zero and stays zero when it is cast back to bfloat16. Essentially, the JIT-compiled version is effectively doing this:

def check_err(x, y):
  x, y = x.astype('float32'), y.astype('float32')
  result = x + y
  y2 = result - x
  return (y - y2).astype('bfloat16')
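A quick check of this (a sketch, using the small inputs from the original report) confirms that the float32 computation really does yield exactly zero for bfloat16-valued inputs, since float32 can add two bfloat16 numbers of similar magnitude without rounding:

import jax.numpy as jnp

x = jnp.asarray([4.0, 5.0, 6.0], dtype=jnp.bfloat16).astype('float32')
y = jnp.asarray([0.0, 0.001, 0.002], dtype=jnp.bfloat16).astype('float32')

y2 = (x + y) - x
print(jnp.array_equal(y, y2))       # True: the float32 sums are exact
print((y - y2).astype('bfloat16'))  # [0 0 0]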

I'm not aware of any way to prevent the compiler from doing this kind of casting – it's probably due to the fact that the hardware (CPU in my case) does not support native bfloat16 operations. I'll ask around to see if others have ideas.

jakevdp commented 3 months ago

Via @apaszke, it seems the xla_allow_excess_precision flag controls this behavior. If you set it to False, then the compiler won't do this sort of internal upcasting:

import os
os.environ['XLA_FLAGS'] = "--xla_allow_excess_precision=false"

import jax

def check_err(x, y):
  result = x + y
  y2 = result - x
  return y - y2

op1 = jax.random.normal(jax.random.key(0), (5,), dtype='bfloat16')
op2 = jax.random.normal(jax.random.key(1), (5,), dtype='bfloat16')

print(check_err(op1, op2))
# [0 -0.00244141 0 0.000488281 0.00390625]

print(jax.jit(check_err)(op1, op2))
# [0 -0.00244141 0 0.000488281 0.00390625]

Note that XLA flag values are only read when the backend is initialized, so be sure to set them either as an environment variable outside your script, or via os.environ inside your script before running any JAX commands.
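For example, a minimal in-script sketch that avoids clobbering any flags that may already be set (the key point is simply that it runs before JAX initializes its backend):

import os

# Append to any existing XLA_FLAGS rather than overwriting them.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_allow_excess_precision=false"
).strip()

import jax  # the backend reads XLA_FLAGS the first time it is initialized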

colehaus commented 3 months ago

That seems to work. Thanks!