jax-ml / ml_dtypes

A stand-alone implementation of several NumPy dtype extensions used in machine learning.
Apache License 2.0

Finite precision error calculations always 0 under JIT with bfloat16 #167

Closed: colehaus closed this issue 1 month ago

colehaus commented 1 month ago

I have some stochastic rounding code, and I uncovered a bug when using code like the following:

# Assumed imports for the snippet; the ndarray[*Shape, Float] annotations are
# the poster's own typing aliases and are not defined here.
import jax
import jax.numpy as jnp
from ml_dtypes import bfloat16  # equivalently jnp.bfloat16

def _error(x: ndarray[*Shape, Float], y: ndarray[*Shape, Float], result: ndarray[*Shape, Float]):
    # Recover the rounding error of result = x + y from the operands.
    y2 = result - x
    x2 = result - y2
    error_y = y - y2
    error_x = x - x2
    return error_x + error_y

def add(x: ndarray[*Shape, Float], y: ndarray[*Shape, Float]):
    result = x + y
    return _error(x, y, result)

dtype = bfloat16
op1 = jax.random.normal(jax.random.key(0), (1000, 4), dtype=dtype)
op2 = jax.random.normal(jax.random.key(1), (1000, 4), dtype=dtype)
print(jax.vmap(add)(op1, op2))
print(jnp.all(jax.jit(jax.vmap(add))(op1, op2) == 0))

With bfloat16, the final line prints True even though it's clear from the preceding line that not all errors ought to be 0. np.float32 does not have this behavior.
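
For context, _error looks like the classic TwoSum error-free transformation: when every operation is rounded in the same format, it recovers the exact rounding error of x + y. Here is a minimal sketch of what the non-JIT path effectively computes, using plain NumPy with ml_dtypes.bfloat16 (illustrative values chosen to mirror the arange example below):

import numpy as np
from ml_dtypes import bfloat16

x = np.array([4, 5, 6], dtype=bfloat16)
y = np.array([0, 0.001, 0.002], dtype=bfloat16)

s = x + y                        # rounded to bfloat16
y2 = s - x
x2 = s - y2
err_bf16 = (x - x2) + (y - y2)   # the bf16 rounding error of x + y
print(err_bf16)                  # nonzero, matching the eager output shown below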

Here are the lowering and compilation outputs, in case they're helpful: first bfloat16, then float32.

dtype = bfloat16
args = (jnp.arange(4, 7, dtype=dtype), jnp.arange(3, dtype=dtype) / 1000)
print(add(*args))
print(jax.jit(add)(*args))
print(jax.jit(add).lower(*args).as_text())
print(jax.jit(add).lower(*args).compile().as_text())
[0 0.000999451 0.0019989]
[0 0 0]
module @jit_add attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xbf16> {mhlo.layout_mode = "default"}, %arg1: tensor<3xbf16> {mhlo.layout_mode = "default"}) -> (tensor<3xbf16> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<3xbf16>
    %1 = stablehlo.subtract %0, %arg0 : tensor<3xbf16>
    %2 = stablehlo.subtract %0, %1 : tensor<3xbf16>
    %3 = stablehlo.subtract %arg1, %1 : tensor<3xbf16>
    %4 = stablehlo.subtract %arg0, %2 : tensor<3xbf16>
    %5 = stablehlo.add %4, %3 : tensor<3xbf16>
    return %5 : tensor<3xbf16>
  }
}

HloModule jit_add, is_scheduled=true, entry_computation_layout={(bf16[3]{0}, bf16[3]{0})->bf16[3]{0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="82694c3355091a0097f584dec86f3d57"}

%fused_convert (param_0.3: bf16[3], param_1.5: bf16[3]) -> bf16[3] {
  %param_0.3 = bf16[3]{0} parameter(0)
  %convert.9.1 = f32[3]{0} convert(bf16[3]{0} %param_0.3)
  %param_1.5 = bf16[3]{0} parameter(1)
  %convert.1.1 = f32[3]{0} convert(bf16[3]{0} %param_1.5)
  %add.2.1 = f32[3]{0} add(f32[3]{0} %convert.9.1, f32[3]{0} %convert.1.1), metadata={op_name="jit(add)/jit(main)/add" source_file="/tmp/ipykernel_47449/771070407.py" source_line=787}
  %subtract.8.1 = f32[3]{0} subtract(f32[3]{0} %add.2.1, f32[3]{0} %convert.9.1), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=779}
  %subtract.9.1 = f32[3]{0} subtract(f32[3]{0} %add.2.1, f32[3]{0} %subtract.8.1), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=780}
  %subtract.10.1 = f32[3]{0} subtract(f32[3]{0} %convert.9.1, f32[3]{0} %subtract.9.1), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=782}
  %subtract.11.1 = f32[3]{0} subtract(f32[3]{0} %convert.1.1, f32[3]{0} %subtract.8.1), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=781}
  %add.4.1 = f32[3]{0} add(f32[3]{0} %subtract.10.1, f32[3]{0} %subtract.11.1), metadata={op_name="jit(add)/jit(main)/add" source_file="/tmp/ipykernel_47449/771070407.py" source_line=783}
  ROOT %convert.17.1 = bf16[3]{0} convert(f32[3]{0} %add.4.1)
}

ENTRY %main.9 (Arg_0.1.0: bf16[3], Arg_1.2.0: bf16[3]) -> bf16[3] {
  %Arg_1.2.0 = bf16[3]{0} parameter(1), metadata={op_name="y"}
  %Arg_0.1.0 = bf16[3]{0} parameter(0), metadata={op_name="x"}
  ROOT %loop_convert_fusion = bf16[3]{0} fusion(bf16[3]{0} %Arg_0.1.0, bf16[3]{0} %Arg_1.2.0), kind=kLoop, calls=%fused_convert
}
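
The %fused_convert kernel above suggests why the jitted bfloat16 result is all zeros: both bf16 inputs are upcast to f32, the whole error computation runs in f32 (where adding these particular bf16 values is exact), and only the final sum of error terms is rounded back to bf16. A rough NumPy re-enactment of that fusion, under the same illustrative values as the sketch above:

import numpy as np
from ml_dtypes import bfloat16

x = np.array([4, 5, 6], dtype=bfloat16).astype(np.float32)
y = np.array([0, 0.001, 0.002], dtype=bfloat16).astype(np.float32)

s = x + y                        # exact in f32: bf16 significands are only 8 bits
y2 = s - x                       # recovers y exactly
x2 = s - y2                      # recovers x exactly
err_f32 = (x - x2) + (y - y2)    # exactly zero
print(err_f32.astype(bfloat16))  # [0 0 0], matching the jitted output above
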
dtype = np.float32
args = (jnp.arange(4, 7, dtype=dtype), jnp.arange(3, dtype=dtype) / 1000)
print(add(*args))
print(jax.jit(add)(*args))
print(jax.jit(add).lower(*args).as_text())
print(jax.jit(add).lower(*args).compile().as_text())
[0.0000000e+00 7.2526745e-08 1.4505349e-07]
[0.0000000e+00 7.2526745e-08 1.4505349e-07]
module @jit_add attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3xf32> {mhlo.layout_mode = "default"}) -> (tensor<3xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
    %1 = stablehlo.subtract %0, %arg0 : tensor<3xf32>
    %2 = stablehlo.subtract %0, %1 : tensor<3xf32>
    %3 = stablehlo.subtract %arg1, %1 : tensor<3xf32>
    %4 = stablehlo.subtract %arg0, %2 : tensor<3xf32>
    %5 = stablehlo.add %4, %3 : tensor<3xf32>
    return %5 : tensor<3xf32>
  }
}

HloModule jit_add, is_scheduled=true, entry_computation_layout={(f32[3]{0}, f32[3]{0})->f32[3]{0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="9d06ea6507421c8754deb14690ff8cd9"}

%fused_add (param_0.1: f32[3], param_1.3: f32[3]) -> f32[3] {
  %param_1.3 = f32[3]{0} parameter(1)
  %param_0.1 = f32[3]{0} parameter(0)
  %add.2.1 = f32[3]{0} add(f32[3]{0} %param_1.3, f32[3]{0} %param_0.1), metadata={op_name="jit(add)/jit(main)/add" source_file="/tmp/ipykernel_47449/771070407.py" source_line=787}
  %subtract.8.1 = f32[3]{0} subtract(f32[3]{0} %add.2.1, f32[3]{0} %param_1.3), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=779}
  %subtract.9.1 = f32[3]{0} subtract(f32[3]{0} %add.2.1, f32[3]{0} %subtract.8.1), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=780}
  %subtract.10.1 = f32[3]{0} subtract(f32[3]{0} %param_1.3, f32[3]{0} %subtract.9.1), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=782}
  %subtract.11.1 = f32[3]{0} subtract(f32[3]{0} %param_0.1, f32[3]{0} %subtract.8.1), metadata={op_name="jit(add)/jit(main)/sub" source_file="/tmp/ipykernel_47449/771070407.py" source_line=781}
  ROOT %add.4.1 = f32[3]{0} add(f32[3]{0} %subtract.10.1, f32[3]{0} %subtract.11.1), metadata={op_name="jit(add)/jit(main)/add" source_file="/tmp/ipykernel_47449/771070407.py" source_line=783}
}

ENTRY %main.9 (Arg_0.1.0: f32[3], Arg_1.2.0: f32[3]) -> f32[3] {
  %Arg_1.2.0 = f32[3]{0} parameter(1), metadata={op_name="y"}
  %Arg_0.1.0 = f32[3]{0} parameter(0), metadata={op_name="x"}
  ROOT %loop_add_fusion = f32[3]{0} fusion(f32[3]{0} %Arg_1.2.0, f32[3]{0} %Arg_0.1.0), kind=kLoop, calls=%fused_add, metadata={op_name="jit(add)/jit(main)/add" source_file="/tmp/ipykernel_47449/771070407.py" source_line=783}
}

Here's the info from jax.print_environment_info():

jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.26.1
python: 3.11.9 (main, Apr  6 2024, 17:59:24) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='npjfe11cq9', release='5.19.0-45-generic', version='#46~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Wed Jun 7 15:06:04 UTC 20', machine='x86_64')

$ nvidia-smi
Sun Aug 11 01:03:45 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX A6000    Off  | 00000000:00:05.0 Off |                  Off |
| 30%   42C    P8    27W /  300W |  36785MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

(Let me know if this is a better fit for the main JAX repo.)

jakevdp commented 1 month ago

Hi - thanks for the report! This looks like a JAX issue, and would be better reported at http://github.com/google/jax/. There is nothing that can be done in this repository to affect the behavior of JAX's JIT compiler. Thanks!

colehaus commented 1 month ago

Okay, thanks, I'll move it over there. I wasn't sure whether some particularity of the bfloat16 implementation was affecting which optimizations the JIT compiler considered safe (since the issue doesn't occur with, e.g., float16).