openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0
2.64k stars 418 forks source link

[XLA:GPU] Regression in FP8 matmul scaling fusion #17887

Open balancap opened 1 week ago

balancap commented 1 week ago

While working on large-scale LLM FP8 training (project https://github.com/graphcore-research/jax-scalify), I have been documenting how to get optimal fused FP8 matmul code from JAX (see notebook https://github.com/graphcore-research/jax-scalify/blob/main/docs/JAX%20FP8%20matmul%20tutorial.ipynb).

I have noticed a regression in the following piece of code when moving from JAX 0.4.31 to JAX 0.4.32 (0.4.33 as well):

# Largest finite value representable in float8_e4m3fn (448); used below as the
# symmetric clamp bound before re-quantizing the matmul output.
e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max
# "Dequantization" datatype (note: required to be BF16!)
dqt_dtype = jnp.bfloat16

# XLA requires a "dequantize/quantize" pattern to properly support scaled FP8 inputs/outputs.
def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):
    """Scaled FP8 matmul written as dequantize -> dot -> rescale -> clamp -> quantize.

    The statement order is deliberate: it is the op sequence the XLA FP8
    matmul rewriter is expected to recognize, so do not reorder or simplify it.

    Args:
        a_fp8: Left operand; cast to `dqt_dtype` before use
            (f8e4m3fn[32,64] in this report's example).
        b_fp8: Right operand; cast to `dqt_dtype` and transposed before the
            dot, so for a 2-D input it is contracted over its last axis
            (f8e4m3fn[128,64] in this report's example).
        a_scale: Multiplier applied to `a_fp8` after the cast
            (an f32 scalar in this report's example).
        b_scale: Multiplier applied to `b_fp8` after the cast
            (an f32 scalar in this report's example).
        d_scale: Multiplier applied to the dot result before clamping
            (an f32 scalar in this report's example).

    Returns:
        The rescaled, clamped product cast to `jnp.float8_e4m3fn`.
    """
    # Dequantize a and b: cast to the dequantization dtype, then multiply by
    # the matching scale (itself cast to the same dtype).
    a_dqt = a_fp8.astype(dqt_dtype) * a_scale.astype(dqt_dtype)
    b_dqt = b_fp8.astype(dqt_dtype) * b_scale.astype(dqt_dtype)

    # Do the matmul (NOTE: adding transpose to simplify HLO).
    d_dqt = jax.lax.dot(a_dqt, b_dqt.T)

    # Rescale & clamp to -max/+max FP8 E4M3 values.
    d_dqt = d_dqt * d_scale.astype(dqt_dtype)
    # NOTE: clamping is NOT optional for proper pattern matching!
    d_dqt = jax.lax.clamp(dqt_dtype(-e4m3_max), d_dqt, dqt_dtype(e4m3_max))
    # (Re)Quantize the scaled matmul output.
    return d_dqt.astype(jnp.float8_e4m3fn)

Previously (with JAX 0.4.31), the XLA compiler generated an optimal fused call:

ENTRY %main.22 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> f8e4m3fn[32,128] {
  %constant_1 = f32[] constant(1)
  %Arg_4.5.0 = f32[] parameter(4)
  %Arg_3.4.0 = f32[] parameter(3)
  %Arg_2.3.0 = f32[] parameter(2)
  %Arg_1.2.0 = f8e4m3fn[128,64]{1,0} parameter(1)
  %Arg_0.1.0 = f8e4m3fn[32,64]{1,0} parameter(0)
  %cublas-gemm.clone.1.0 = (f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f8e4m3fn[128,64]{1,0} %Arg_1.2.0, f32[] %Arg_2.3.0, f32[] %Arg_3.4.0, f32[] %constant_1, /*index=5*/f32[] %Arg_4.5.0), custom_call_target="__cublas$lt$matmul$f8"
  ROOT %get-tuple-element.1 = f8e4m3fn[32,128]{1,0} get-tuple-element((f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) %cublas-gemm.clone.1.0), index=0
}

But now (with JAX 0.4.32 and 0.4.33), the XLA compiler fails to recognize the pattern and generates:

HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="61aa7688288250bb61c122109f2f0609"}

%gemm_fusion_dot.19_computation (parameter_0: bf16[32,64], parameter_1: bf16[128,64], parameter_2: f32[]) -> bf16[32,128] {
  %parameter_0 = bf16[32,64]{1,0} parameter(0)
  %parameter_1 = bf16[128,64]{1,0} parameter(1)
  %dot.1 = bf16[32,128]{1,0} dot(bf16[32,64]{1,0} %parameter_0, bf16[128,64]{1,0} %parameter_1), lhs_contracting_dims={1}, rhs_contracting_dims={1}
  %parameter_2 = f32[] parameter(2)
  %convert.6 = bf16[] convert(f32[] %parameter_2)
  %broadcast.5 = bf16[32,128]{1,0} broadcast(bf16[] %convert.6), dimensions={}
  ROOT %multiply.3 = bf16[32,128]{1,0} multiply(bf16[32,128]{1,0} %dot.1, bf16[32,128]{1,0} %broadcast.5)
}

%fused_convert (param_0.15: bf16[32,128]) -> f8e4m3fn[32,128] {
  %constant_8_2 = bf16[] constant(-448)
  %convert.22.1 = f32[] convert(bf16[] %constant_8_2)
  %broadcast.14.1 = f32[32,128]{1,0} broadcast(f32[] %convert.22.1), dimensions={}
  %param_0.15 = bf16[32,128]{1,0} parameter(0)
  %convert.17.1 = f32[32,128]{1,0} convert(bf16[32,128]{1,0} %param_0.15)
  %constant_6_2 = bf16[] constant(448)
  %convert.23.1 = f32[] convert(bf16[] %constant_6_2)
  %broadcast.17.3 = f32[32,128]{1,0} broadcast(f32[] %convert.23.1), dimensions={}
  %clamp.1.3 = f32[32,128]{1,0} clamp(f32[32,128]{1,0} %broadcast.14.1, f32[32,128]{1,0} %convert.17.1, f32[32,128]{1,0} %broadcast.17.3)
  ROOT %convert.13.1 = f8e4m3fn[32,128]{1,0} convert(f32[32,128]{1,0} %clamp.1.3)
}

%fused_slice (param_0_0: f8e4m3fn[32,64], param_0_1: f32[], param_1_0: f8e4m3fn[128,64], param_1_1: f32[]) -> (bf16[2048], bf16[8192]) {
  %param_0_0 = f8e4m3fn[32,64]{1,0} parameter(0)
  %convert.7.2 = bf16[32,64]{1,0} convert(f8e4m3fn[32,64]{1,0} %param_0_0)
  %param_0_1 = f32[] parameter(1)
  %convert.8.2 = bf16[] convert(f32[] %param_0_1)
  %broadcast.8.2 = bf16[32,64]{1,0} broadcast(bf16[] %convert.8.2), dimensions={}
  %multiply.4.2 = bf16[32,64]{1,0} multiply(bf16[32,64]{1,0} %convert.7.2, bf16[32,64]{1,0} %broadcast.8.2)
  %reshape.2 = bf16[2048]{0} reshape(bf16[32,64]{1,0} %multiply.4.2)
  %param_1_0 = f8e4m3fn[128,64]{1,0} parameter(2)
  %convert.9.2 = bf16[128,64]{1,0} convert(f8e4m3fn[128,64]{1,0} %param_1_0)
  %param_1_1 = f32[] parameter(3)
  %convert.12.2 = bf16[] convert(f32[] %param_1_1)
  %broadcast.10.2 = bf16[128,64]{1,0} broadcast(bf16[] %convert.12.2), dimensions={}
  %multiply.5.2 = bf16[128,64]{1,0} multiply(bf16[128,64]{1,0} %convert.9.2, bf16[128,64]{1,0} %broadcast.10.2)
  %reshape.3 = bf16[8192]{0} reshape(bf16[128,64]{1,0} %multiply.5.2)
  %concatenate = bf16[10240]{0} concatenate(bf16[2048]{0} %reshape.2, bf16[8192]{0} %reshape.3), dimensions={0}
  %slice = bf16[2048]{0} slice(bf16[10240]{0} %concatenate), slice={[0:2048]}
  %slice.1 = bf16[8192]{0} slice(bf16[10240]{0} %concatenate), slice={[2048:10240]}
  ROOT %tuple = (bf16[2048]{0}, bf16[8192]{0}) tuple(bf16[2048]{0} %slice, bf16[8192]{0} %slice.1)
}

ENTRY %main.25 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> f8e4m3fn[32,128] {
  %Arg_4.5.0 = f32[] parameter(4)
  %Arg_3.4.0 = f32[] parameter(3)
  %Arg_2.3.0 = f32[] parameter(2)
  %Arg_1.2.0 = f8e4m3fn[128,64]{1,0} parameter(1)
  %Arg_0.1.0 = f8e4m3fn[32,64]{1,0} parameter(0)
  %input_slice_fusion = (bf16[2048]{0}, bf16[8192]{0}) fusion(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f32[] %Arg_2.3.0, f8e4m3fn[128,64]{1,0} %Arg_1.2.0, f32[] %Arg_3.4.0), kind=kInput, calls=%fused_slice
  %get-tuple-element.1 = bf16[8192]{0} get-tuple-element((bf16[2048]{0}, bf16[8192]{0}) %input_slice_fusion), index=1
  %get-tuple-element = bf16[2048]{0} get-tuple-element((bf16[2048]{0}, bf16[8192]{0}) %input_slice_fusion), index=0
  %bitcast.55 = bf16[32,64]{1,0} bitcast(bf16[2048]{0} %get-tuple-element)
  %bitcast.56 = bf16[128,64]{1,0} bitcast(bf16[8192]{0} %get-tuple-element.1)
  %gemm_fusion_dot.19.0 = bf16[32,128]{1,0} fusion(bf16[32,64]{1,0} %bitcast.55, bf16[128,64]{1,0} %bitcast.56, f32[] %Arg_4.5.0), kind=kCustom, calls=%gemm_fusion_dot.19_computation
  ROOT %loop_convert_fusion = f8e4m3fn[32,128]{1,0} fusion(bf16[32,128]{1,0} %gemm_fusion_dot.19.0), kind=kLoop, calls=%fused_convert
}

For both JAX versions, the StableHLO is the same:

module @jit_matmul_fn_with_scale attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<32x64xf8E4M3FN> {mhlo.layout_mode = "default"}, %arg1: tensor<128x64xf8E4M3FN> {mhlo.layout_mode = "default"}, %arg2: tensor<f32> {mhlo.layout_mode = "default"}, %arg3: tensor<f32> {mhlo.layout_mode = "default"}, %arg4: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<32x128xf8E4M3FN> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.convert %arg0 : (tensor<32x64xf8E4M3FN>) -> tensor<32x64xbf16>
    %1 = stablehlo.convert %arg2 : (tensor<f32>) -> tensor<bf16>
    %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<bf16>) -> tensor<32x64xbf16>
    %3 = stablehlo.multiply %0, %2 : tensor<32x64xbf16>
    %4 = stablehlo.convert %arg1 : (tensor<128x64xf8E4M3FN>) -> tensor<128x64xbf16>
    %5 = stablehlo.convert %arg3 : (tensor<f32>) -> tensor<bf16>
    %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor<bf16>) -> tensor<128x64xbf16>
    %7 = stablehlo.multiply %4, %6 : tensor<128x64xbf16>
    %8 = stablehlo.transpose %7, dims = [1, 0] : (tensor<128x64xbf16>) -> tensor<64x128xbf16>
    %9 = stablehlo.dot_general %3, %8, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<32x64xbf16>, tensor<64x128xbf16>) -> tensor<32x128xbf16>
    %10 = stablehlo.convert %arg4 : (tensor<f32>) -> tensor<bf16>
    %11 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor<bf16>) -> tensor<32x128xbf16>
    %12 = stablehlo.multiply %9, %11 : tensor<32x128xbf16>
    %cst = stablehlo.constant dense<-4.480000e+02> : tensor<bf16>
    %cst_0 = stablehlo.constant dense<4.480000e+02> : tensor<bf16>
    %13 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<bf16>) -> tensor<32x128xbf16>
    %14 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<bf16>) -> tensor<32x128xbf16>
    %15 = stablehlo.clamp %13, %12, %14 : tensor<32x128xbf16>
    %16 = stablehlo.convert %15 : (tensor<32x128xbf16>) -> tensor<32x128xf8E4M3FN>
    return %16 : tensor<32x128xf8E4M3FN>
  }
}

Note: the fusion works again when ReLU and/or abs-max capture logic is added:

# ReLU non-linearity. Note: needs to be before the scaling.
d_dqt = jax.nn.relu(d_dqt)
# Delayed rescaling: capture the raw output's abs-max (as float32) for later.
out_scale = jnp.max(jnp.abs(d_dqt)).astype(jnp.float32)

(Additionally, I can't seem to get GELU fusion working.)

JAX bug ticket: https://github.com/jax-ml/jax/issues/24051