intel / intel-xpu-backend-for-triton

OpenAI Triton backend for Intel® GPUs
MIT License

[language/test_core.py::test_precise_math] `tl.math.div_rn(x,y)` fails on XPU #601

Closed. whitneywhtsang closed this issue 5 months ago.

whitneywhtsang commented 7 months ago

On NVIDIA, it is lowered to __nv_fdiv_rn; we need to lower it to an equivalent call on XPU.
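
For reference, a minimal repro sketch of the failing pattern (the kernel and tensor setup below are mine, mirroring test_precise_math, and assume an XPU device):

import torch
import triton
import triton.language as tl

@triton.jit
def div_rn_kernel(X, Y, OUT, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    y = tl.load(Y + offs)
    # On NVIDIA this lowers to __nv_fdiv_rn (correctly rounded f32 divide);
    # XPU needs an equivalent call.
    tl.store(OUT + offs, tl.math.div_rn(x, y))

x = torch.randn(128, dtype=torch.float32, device="xpu")
y = torch.randn(128, dtype=torch.float32, device="xpu") + 1e-6  # avoid div by ~0
out = torch.empty_like(x)
div_rn_kernel[(1,)](x, y, out, BLOCK=128)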

tdeng5 commented 7 months ago

We need to refresh the table that maps the NV math functions to their Intel GPU equivalents.

Stonepia commented 7 months ago

To collect the math functions, I previously added a print line in libdevice.py and ran all the suites on an A100. For example, for this function I would add print("clz") for debugging.

https://github.com/intel/intel-xpu-backend-for-triton/blob/llvm-target/python/triton/language/extra/cuda/libdevice.py#L4-L9
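
A hedged sketch of that instrumentation (the wrapper below is illustrative, not the actual libdevice.py code at the link):

def log_call(name, fn):
    # Print the op name on every dispatch, so running the benchmark
    # suites yields the list of math functions actually exercised.
    def inner(*args, **kwargs):
        print(name)  # e.g. "clz" is printed whenever clz is called
        return fn(*args, **kwargs)
    return inner

# Usage (hypothetical): wrap each op before running the suites, e.g.
# libdevice.clz = log_call("clz", libdevice.clz)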

If a function is used in the three benchmark suites, it gets the highest priority; I believe the current high-priority set is already fully supported.

For the full op list, you can filter libdevice.py to find all the ops Triton needs (a sketch follows). Note that there are some naming-convention inconsistencies between the Intel math library and NV's, but those can be resolved on the Triton side.
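
A rough sketch of that filtering, assuming every op is exposed as a top-level function of the module:

import inspect
import triton.language.extra.cuda.libdevice as libdevice

# Collect every public function defined in libdevice.py; these are the
# ops that may need an Intel mapping.
ops = sorted(name for name, obj in inspect.getmembers(libdevice, inspect.isfunction)
             if not name.startswith("_"))
print(ops)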

I also attach the old version here: libdevice Ops in Triton (old).xlsx

vlad-penkin commented 7 months ago

@whitneywhtsang can we close this issue?

whitneywhtsang commented 7 months ago

> @whitneywhtsang can we close this issue?

The failure still exists, so we cannot close this issue.

whitneywhtsang commented 7 months ago

@ESI-SYD Please regenerate the table, and determine whether there is an equivalent function to __nv_fdiv_rn on Intel GPU.

ESI-SYD commented 7 months ago

Regenerated: Updated_libdevice Ops in Triton.xlsx

Found __imf_fdiv_rn() here.
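
If that mapping holds up, the XPU-side wrapper might look roughly like the CUDA libdevice ones; this is a sketch, and the library name and path below are placeholders, not actual backend code:

from triton.language import core

@core.extern
def div_rn(arg0, arg1, _builder=None):
    # Map Triton's div_rn to the Intel math function __imf_fdiv_rn,
    # analogous to the CUDA mapping onto __nv_fdiv_rn.
    return core.extern_elementwise(
        "imf", "IMF_LIB_PATH",  # placeholder library name and path
        [arg0, arg1],
        {(core.dtype("fp32"), core.dtype("fp32")): ("__imf_fdiv_rn", core.dtype("fp32"))},
        is_pure=True, _builder=_builder)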

ESI-SYD commented 6 months ago

The latest issue with this case is an accuracy failure, which seems to be due to FpToFpOpConversion not supporting f32-to-f64 conversion, which is used in the out_ref calculation.

whitneywhtsang commented 6 months ago

> The latest issue with this case is an accuracy failure, which seems to be due to FpToFpOpConversion not supporting f32-to-f64 conversion, which is used in the out_ref calculation.

Does the NV implementation support f32-to-f64 conversion? Why does it work for them?

ESI-SYD commented 6 months ago

The current accuracy loss is because the reference is not computed in double-precision floating point. The problem can still be seen even at LLVM O0. A binary search for the problematic pass at O0 shows that applying this optimization pass removes the fpext and fptrunc operations (compare the Before and After dumps below).

Before:

  %30 = phi i32 [ %28, %27 ], [ undef, %22 ]
  %31 = bitcast i32 %30 to <1 x float>, !dbg !18
  %32 = extractelement <1 x float> %31, i32 0, !dbg !18
  %33 = call spir_func noundef float @_Z8__fp_divIfET_S0_S0_i(float noundef %25, float noundef %32, i32 noundef 0) #5, !dbg !19
  %34 = fpext float %25 to double, !dbg !20
  %35 = fpext float %32 to double, !dbg !21
  %36 = fdiv double %34, %35, !dbg !22
  %37 = fptrunc double %36 to float, !dbg !23
  %38 = getelementptr float, ptr addrspace(1) %2, i32 %18, !dbg !24
  %39 = call i64 @_Z12get_local_idj(i32 0)
  %40 = insertelement <1 x float> undef, float %33, i32 0, !dbg !25
  %41 = bitcast <1 x float> %40 to i32, !dbg !25
  %42 = insertelement <1 x i32> undef, i32 %41, i32 0, !dbg !25
  br i1 true, label %43, label %44, !dbg !25

After:

define spir_kernel void @kernel_0d1d2d3d(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, ptr addrspace(1) %3) !dbg !11 !intel_reqd_sub_group_size !12 !max_work_group_size !13 {
  %5 = call i64 @_Z12get_local_idj(i32 0)
  %6 = and i64 %5, 127, !dbg !14
  br i1 true, label %7, label %11, !dbg !15

7:                                                ; preds = %4
  %8 = and i64 %5, 127, !dbg !16
  %9 = getelementptr float, ptr addrspace(1) %0, i64 %8, !dbg !16
  %10 = load <1 x float>, ptr addrspace(1) %9, align 4, !dbg !15
  br label %11, !dbg !15

11:                                               ; preds = %7, %4
  %12 = phi <1 x float> [ %10, %7 ], [ poison, %4 ]
  %13 = extractelement <1 x float> %12, i64 0, !dbg !15
  br i1 true, label %14, label %18, !dbg !17

14:                                               ; preds = %11
  %15 = and i64 %5, 127, !dbg !18
  %16 = getelementptr float, ptr addrspace(1) %1, i64 %15, !dbg !18
  %17 = load <1 x float>, ptr addrspace(1) %16, align 4, !dbg !17
  br label %18, !dbg !17

18:                                               ; preds = %14, %11
  %19 = phi <1 x float> [ %17, %14 ], [ poison, %11 ]
  %20 = extractelement <1 x float> %19, i64 0, !dbg !17
  %21 = call spir_func noundef float @_Z8__fp_divIfET_S0_S0_i(float noundef %13, float noundef %20, i32 noundef 0) #5, !dbg !19
  %22 = fdiv float %13, %20, !dbg !20
  %23 = call i64 @_Z12get_local_idj(i32 0)
  br i1 true, label %24, label %27, !dbg !21
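
In Triton-source terms, the rewrite amounts to the following (a paraphrase of the test's REF_CALC, not generated code):

import triton
import triton.language as tl

@triton.jit
def ref_kernel(X, Y, OUT, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    y = tl.load(Y + offs)
    # Intended reference: divide in f64, then round once to f32.
    ref = (x.to(tl.float64) / y.to(tl.float64)).to(tl.float32)
    # After the pass folds fptrunc(fdiv(fpext x, fpext y)), the IR
    # effectively computes a plain f32 divide instead:
    # ref = x / y
    tl.store(OUT + offs, ref)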
etiotto commented 6 months ago

Interesting find @ESI-SYD. I think we may have to mark the kernel with this LLVM function attribute "strictfp" (https://llvm.org/docs/LangRef.html)

ESI-SYD commented 6 months ago

> Interesting find @ESI-SYD. I think we may have to mark the kernel with this LLVM function attribute "strictfp" (https://llvm.org/docs/LangRef.html)

Thanks, I'll give it a try.

ESI-SYD commented 6 months ago

@etiotto It seems strictfp is not able to prevent this unexpected simplification.

; Function Attrs: alwaysinline mustprogress norecurse nounwind strictfp
define spir_kernel void @kernel_0d1d2d3d(ptr addrspace(1) nocapture readonly %0, ptr addrspace(1) nocapture readonly %1, ptr addrspace(1) nocapture writeonly %2, ptr addrspace(1) nocapture writeonly %3) local_unnamed_addr #0 !dbg !11 !max_work_group_size !12 !intel_reqd_sub_group_size !13 {
  %5 = tail call i64 @_Z12get_local_idj(i32 0) #5
  %6 = and i64 %5, 127, !dbg !14
  %7 = getelementptr float, ptr addrspace(1) %0, i64 %6, !dbg !15
  %8 = load float, ptr addrspace(1) %7, align 4, !dbg !16
  %9 = getelementptr float, ptr addrspace(1) %1, i64 %6, !dbg !17
  %10 = load float, ptr addrspace(1) %9, align 4, !dbg !18
  %11 = tail call spir_func noundef float @_Z8__fp_divIfET_S0_S0_i(float noundef %8, float noundef %10, i32 noundef 0) #6, !dbg !19
  %12 = fdiv float %8, %10, !dbg !20
  %13 = getelementptr float, ptr addrspace(1) %2, i64 %6, !dbg !21
  %14 = tail call i64 @_Z12get_local_idj(i32 0) #5
  store float %11, ptr addrspace(1) %13, align 4, !dbg !22
  %15 = getelementptr float, ptr addrspace(1) %3, i64 %6, !dbg !23
  %16 = tail call i64 @_Z12get_local_idj(i32 0) #5
  store float %12, ptr addrspace(1) %15, align 4, !dbg !24
  ret void, !dbg !25
}

Do we accept modifying the test case as a workaround?

diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index fd34a245..680997eb 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -1011,13 +1011,19 @@ def test_math_divide_op(expr, num_ctas, device):
                           ('tl.math.div_rn(x,y)', '(x.to(tl.float64) / y.to(tl.float64)).to(tl.float32)')])
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
 def test_precise_math(expr_prec, expr_ref, num_ctas, device):
+    if is_hip():
+        pytest.skip("TODO test_precise_math (added by https://github.com/openai/triton/pull/3172) does not work on HIP")

     @triton.jit
-    def kernel(X, Y, OUT, OUT_REF, BLOCK: tl.constexpr):
+    def kernel(X, Y, OUT, OUT_REF, CASE: tl.constexpr, BLOCK: tl.constexpr):
         x = tl.load(X + tl.arange(0, BLOCK))
         y = tl.load(Y + tl.arange(0, BLOCK))
         prec = PREC_CALC
         ref = REF_CALC
+        if CASE == 'div_rn':
+            ref_tmp = (x.to(tl.float64) / y.to(tl.float64))
+            tl.device_print("",ref_tmp)
+            ref = ref_tmp.to(tl.float32)
         tl.store(OUT + tl.arange(0, BLOCK), prec)
         tl.store(OUT_REF + tl.arange(0, BLOCK), ref)

@@ -1029,18 +1035,17 @@ def test_precise_math(expr_prec, expr_ref, num_ctas, device):
     y = torch.randn(shape, dtype=torch.float32, device=device)

     if (expr_prec.count('sqrt') > 0):
+        case = 'sqrt'
         x = torch.abs(x)

     if (expr_prec.count('div') > 0):
+        case = 'div_rn'
         y += 1e-6

     kernel = patch_kernel(kernel, {'PREC_CALC': expr_prec, 'REF_CALC': expr_ref})

-    kernel[(1, )](x, y, out, out_ref, BLOCK=shape[0], num_ctas=num_ctas)
+    kernel[(1, )](x, y, out, out_ref, case, BLOCK=shape[0], num_ctas=num_ctas)

-    if is_xpu() and expr_prec == 'tl.math.div_rn(x,y)':
-        np.testing.assert_allclose(to_numpy(out), to_numpy(out_ref), rtol=1e-6)
-        pytest.skip("FIXME: Fail accuracy")
     assert torch.all(out == out_ref)  # bitwise exact
whitneywhtsang commented 6 months ago

> Do we accept modifying the test case as a workaround?

I don't think so; adding a print only avoids the actual problem.

How did you add strictfp? Did you add it to #0?

ESI-SYD commented 6 months ago

> Do we accept modifying the test case as a workaround?
>
> I don't think so; adding a print only avoids the actual problem.
>
> How did you add strictfp? Did you add it to #0?

Yes. I uploaded the patch and the dumped kernels for details here: #601.zip

whitneywhtsang commented 6 months ago

I suggest looking into the optimization pass that removes the fp conversions, to check whether it respects the strictfp function attribute, or any other attribute.

ESI-SYD commented 6 months ago

Update after offline sync with @chengjunlu:

1. The case aims to test the precision of the math library.

2. Analysis:

   - Why NV passes: NV PTX supports both single- and double-precision float div (NV's FDivOp), and "(x.to(tl.float64) / y.to(tl.float64)).to(tl.float32)" is lowered correctly to "div.rnd.f64 d, a, b; // IEEE 754 compliant rounding". https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div
   - Why Intel fails: we have not found a corresponding float64 div on Intel GPU, and we use an approximate float div like NV's "div.approx{.ftz}.f32 d, a, b; // fast, approximate divide".

3. The test case can pass with the double-precision float div, but fails with the single-precision approximate div.

4. The LLVM InstCombine optimization pass always rewrites the pattern "fptrunc(OpI (fpext x), (fpext y))" to single precision, on both the Triton side and the IGC side, as the pass dump below shows.

Running pass: PromotePass on kernel_0d1d2d3d (33 instructions)
; *** IR Dump After PromotePass on kernel_0d1d2d3d ***
; Function Attrs: strictfp
define spir_kernel void @kernel_0d1d2d3d(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, ptr addrspace(1) %3) local_unnamed_addr #0 !dbg !11 !max_work_group_size !12 !intel_reqd_sub_group_size !13 {
  %5 = call i64 @_Z12get_local_idj(i32 0)
  %6 = trunc i64 %5 to i32, !dbg !14
  %7 = urem i32 %6, 32, !dbg !14
  %8 = udiv i32 %6, 32, !dbg !14
  %9 = urem i32 %8, 4, !dbg !14
  %10 = mul nuw nsw i32 %9, 32, !dbg !14
  %11 = add nuw nsw i32 %7, %10, !dbg !14
  %12 = getelementptr float, ptr addrspace(1) %0, i32 %11, !dbg !15
  %13 = load i32, ptr addrspace(1) %12, align 4, !dbg !16
  %14 = bitcast i32 %13 to <1 x float>, !dbg !16
  %15 = extractelement <1 x float> %14, i32 0, !dbg !16
  %16 = getelementptr float, ptr addrspace(1) %1, i32 %11, !dbg !17
  %17 = load i32, ptr addrspace(1) %16, align 4, !dbg !18
  %18 = bitcast i32 %17 to <1 x float>, !dbg !18
  %19 = extractelement <1 x float> %18, i32 0, !dbg !18
  %20 = call spir_func float @__imf_fdiv_rn(float %15, float %19), !dbg !19
  %21 = fpext float %15 to double, !dbg !20
  %22 = fpext float %19 to double, !dbg !21
  %23 = fdiv double %21, %22, !dbg !22
  %24 = fptrunc double %23 to float, !dbg !23
  %25 = getelementptr float, ptr addrspace(1) %2, i32 %11, !dbg !24
  %26 = call i64 @_Z12get_local_idj(i32 0)
  %27 = insertelement <1 x float> undef, float %20, i32 0, !dbg !25
  %28 = bitcast <1 x float> %27 to i32, !dbg !25
  %29 = insertelement <1 x i32> undef, i32 %28, i32 0, !dbg !25
  store <1 x i32> %29, ptr addrspace(1) %25, align 4, !dbg !25
  %30 = getelementptr float, ptr addrspace(1) %3, i32 %11, !dbg !26
  %31 = call i64 @_Z12get_local_idj(i32 0)
  %32 = insertelement <1 x float> undef, float %24, i32 0, !dbg !27
  %33 = bitcast <1 x float> %32 to i32, !dbg !27
  %34 = insertelement <1 x i32> undef, i32 %33, i32 0, !dbg !27
  store <1 x i32> %34, ptr addrspace(1) %30, align 4, !dbg !27
  ret void, !dbg !28
}
Running pass: InstCombinePass on kernel_0d1d2d3d (33 instructions)
Running analysis: OptimizationRemarkEmitterAnalysis on kernel_0d1d2d3d
Running analysis: AAManager on kernel_0d1d2d3d
Running analysis: BasicAA on kernel_0d1d2d3d
Running analysis: ScopedNoAliasAA on kernel_0d1d2d3d
Running analysis: TypeBasedAA on kernel_0d1d2d3d
Running analysis: OuterAnalysisManagerProxy<llvm::AnalysisManager<llvm::Module>, llvm::Function> on kernel_0d1d2d3d
; *** IR Dump After InstCombinePass on kernel_0d1d2d3d ***
; Function Attrs: strictfp
define spir_kernel void @kernel_0d1d2d3d(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, ptr addrspace(1) %3) local_unnamed_addr #0 !dbg !11 !max_work_group_size !12 !intel_reqd_sub_group_size !13 {
  %5 = call i64 @_Z12get_local_idj(i32 0)
  %6 = and i64 %5, 127, !dbg !14
  %7 = and i64 %5, 127, !dbg !15
  %8 = getelementptr float, ptr addrspace(1) %0, i64 %7, !dbg !15
  %9 = load <1 x float>, ptr addrspace(1) %8, align 4, !dbg !16
  %10 = extractelement <1 x float> %9, i64 0, !dbg !16
  %11 = and i64 %5, 127, !dbg !17
  %12 = getelementptr float, ptr addrspace(1) %1, i64 %11, !dbg !17
  %13 = load <1 x float>, ptr addrspace(1) %12, align 4, !dbg !18
  %14 = extractelement <1 x float> %13, i64 0, !dbg !18
  %15 = call spir_func float @__imf_fdiv_rn(float %10, float %14), !dbg !19
  %16 = fdiv float %10, %14, !dbg !20
  %17 = and i64 %5, 127, !dbg !21
  %18 = getelementptr float, ptr addrspace(1) %2, i64 %17, !dbg !21
  %19 = call i64 @_Z12get_local_idj(i32 0)
  store float %15, ptr addrspace(1) %18, align 4, !dbg !22
  %20 = getelementptr float, ptr addrspace(1) %3, i64 %6, !dbg !23
  %21 = call i64 @_Z12get_local_idj(i32 0)
  store float %16, ptr addrspace(1) %20, align 4, !dbg !24
  ret void, !dbg !25
}

What we can do:

1. Keep the test case focused on the precision of the math library.
2. Support more division variants in the Triton language spec for floating-point math operations.

Workarounds:

1. Use the CPU result as the reference so the test passes (a rough sketch follows below).
2. Bypass the pattern "fptrunc(OpI (fpext x), (fpext y))".
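
A rough sketch of workaround 1, reusing the test's x, y, and out tensors (assumption: the reference is computed host-side with torch):

# Compute the reference on the CPU in float64, outside the kernel, so the
# GPU-side rewrite cannot touch it; then round once to f32.
out_ref = (x.cpu().double() / y.cpu().double()).float()
assert torch.all(out.cpu() == out_ref)  # bitwise exact, as in the test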

tdeng5 commented 6 months ago

After discussion, some proposals for the next step:

1. File an LLVM GitHub issue to disable this optimization when a function is marked with an attribute like strictfp. (The LLVM InstCombine pass always rewrites the pattern "fptrunc(OpI (fpext x), (fpext y))" to single precision, on both the Triton side and the IGC side.)
2. Change the test case to use the CPU's result as the reference. This CI change needs to be upstreamed to the OpenAI repo in the future, since the same problem is encountered by AMD.
ESI-SYD commented 6 months ago

Issue: https://github.com/llvm/llvm-project/issues/88222