Closed whitneywhtsang closed 5 months ago
We need to refresh the table that maps NVIDIA math functions to their Intel GPU equivalents.
To collect the math functions, I previously added a print statement to each function in libdevice.py and ran all suites on an A100. For example, for this function I added print("clz") for debugging.
Functions used by the three benchmark suites have the highest priority. I believe all of the current high-priority functions are already fully supported.
For the full op list, you can filter libdevice.py to find all the ops Triton needs. Note that some naming conventions differ between the Intel math library and NVIDIA's, but this can be resolved on the Triton side.
I also attach the old version here. libdevice Ops in Triton (old).xlsx
@whitneywhtsang can we close this issue?
@whitneywhtsang can we close this issue?
The failure still exists, so we cannot close this issue.
@ESI-SYD Please regenerate the table, and determine whether there is an equivalent of __nv_fdiv_rn on Intel GPUs.
Regenerate: Updated_libdevice Ops in Triton.xlsx
Found __imf_fdiv_rn() here.
The latest issue in this test case is an accuracy failure. It seems to be caused by FpToFpOpConversion not supporting the f32-to-f64 conversion used in the out_ref calculation.
The latest issue of this case, is accuracy failure, seems due to FpToFpOpConversion not support f32-f64 conversion which used in
out_ref
calculation
Does NV implementation support f32-f64 conversion? Why does it work for them?
The current accuracy loss occurs because the reference is not actually computed in double-precision floating point. The problem is still visible even at LLVM -O0. Bisecting the -O0 pipeline to find the offending pass shows that applying this optimization pass removes the fpext and fptrunc operations.
Before:
%30 = phi i32 [ %28, %27 ], [ undef, %22 ]
%31 = bitcast i32 %30 to <1 x float>, !dbg !18
%32 = extractelement <1 x float> %31, i32 0, !dbg !18
%33 = call spir_func noundef float @_Z8__fp_divIfET_S0_S0_i(float noundef %25, float noundef %32, i32 noundef 0) #5, !dbg !19
%34 = fpext float %25 to double, !dbg !20
%35 = fpext float %32 to double, !dbg !21
%36 = fdiv double %34, %35, !dbg !22
%37 = fptrunc double %36 to float, !dbg !23
%38 = getelementptr float, ptr addrspace(1) %2, i32 %18, !dbg !24
%39 = call i64 @_Z12get_local_idj(i32 0)
%40 = insertelement <1 x float> undef, float %33, i32 0, !dbg !25
%41 = bitcast <1 x float> %40 to i32, !dbg !25
%42 = insertelement <1 x i32> undef, i32 %41, i32 0, !dbg !25
br i1 true, label %43, label %44, !dbg !25
After:
define spir_kernel void @kernel_0d1d2d3d(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, ptr addrspace(1) %3) !dbg !11 !intel_reqd_sub_group_size !12 !max_work_group_size !13 {
%5 = call i64 @_Z12get_local_idj(i32 0)
%6 = and i64 %5, 127, !dbg !14
br i1 true, label %7, label %11, !dbg !15
7: ; preds = %4
%8 = and i64 %5, 127, !dbg !16
%9 = getelementptr float, ptr addrspace(1) %0, i64 %8, !dbg !16
%10 = load <1 x float>, ptr addrspace(1) %9, align 4, !dbg !15
br label %11, !dbg !15
11: ; preds = %7, %4
%12 = phi <1 x float> [ %10, %7 ], [ poison, %4 ]
%13 = extractelement <1 x float> %12, i64 0, !dbg !15
br i1 true, label %14, label %18, !dbg !17
14: ; preds = %11
%15 = and i64 %5, 127, !dbg !18
%16 = getelementptr float, ptr addrspace(1) %1, i64 %15, !dbg !18
%17 = load <1 x float>, ptr addrspace(1) %16, align 4, !dbg !17
br label %18, !dbg !17
18: ; preds = %14, %11
%19 = phi <1 x float> [ %17, %14 ], [ poison, %11 ]
%20 = extractelement <1 x float> %19, i64 0, !dbg !17
%21 = call spir_func noundef float @_Z8__fp_divIfET_S0_S0_i(float noundef %13, float noundef %20, i32 noundef 0) #5, !dbg !19
%22 = fdiv float %13, %20, !dbg !20
%23 = call i64 @_Z12get_local_idj(i32 0)
br i1 true, label %24, label %27, !dbg !21
Interesting find, @ESI-SYD. I think we may have to mark the kernel with the LLVM function attribute "strictfp" (https://llvm.org/docs/LangRef.html).
Interesting find @ESI-SYD. I think we may have to mark the kernel with this LLVM function attribute "strictfp" (https://llvm.org/docs/LangRef.html)
Thanks, I will give it a try.
@etiotto, it seems strictfp is not able to prevent this unexpected simplification.
; Function Attrs: alwaysinline mustprogress norecurse nounwind strictfp
define spir_kernel void @kernel_0d1d2d3d(ptr addrspace(1) nocapture readonly %0, ptr addrspace(1) nocapture readonly %1, ptr addrspace(1) nocapture writeonly %2, ptr addrspace(1) nocapture writeonly %3) local_unnamed_addr #0 !dbg !11 !max_work_group_size !12 !intel_reqd_sub_group_size !13 {
%5 = tail call i64 @_Z12get_local_idj(i32 0) #5
%6 = and i64 %5, 127, !dbg !14
%7 = getelementptr float, ptr addrspace(1) %0, i64 %6, !dbg !15
%8 = load float, ptr addrspace(1) %7, align 4, !dbg !16
%9 = getelementptr float, ptr addrspace(1) %1, i64 %6, !dbg !17
%10 = load float, ptr addrspace(1) %9, align 4, !dbg !18
%11 = tail call spir_func noundef float @_Z8__fp_divIfET_S0_S0_i(float noundef %8, float noundef %10, i32 noundef 0) #6, !dbg !19
%12 = fdiv float %8, %10, !dbg !20
%13 = getelementptr float, ptr addrspace(1) %2, i64 %6, !dbg !21
%14 = tail call i64 @_Z12get_local_idj(i32 0) #5
store float %11, ptr addrspace(1) %13, align 4, !dbg !22
%15 = getelementptr float, ptr addrspace(1) %3, i64 %6, !dbg !23
%16 = tail call i64 @_Z12get_local_idj(i32 0) #5
store float %12, ptr addrspace(1) %15, align 4, !dbg !24
ret void, !dbg !25
}
Do we accept modifying the test case as a workaround?
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index fd34a245..680997eb 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -1011,13 +1011,19 @@ def test_math_divide_op(expr, num_ctas, device):
('tl.math.div_rn(x,y)', '(x.to(tl.float64) / y.to(tl.float64)).to(tl.float32)')])
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_precise_math(expr_prec, expr_ref, num_ctas, device):
+ if is_hip():
+ pytest.skip("TODO test_precise_math (added by https://github.com/openai/triton/pull/3172) does not work on HIP")
@triton.jit
- def kernel(X, Y, OUT, OUT_REF, BLOCK: tl.constexpr):
+ def kernel(X, Y, OUT, OUT_REF, CASE: tl.constexpr, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
y = tl.load(Y + tl.arange(0, BLOCK))
prec = PREC_CALC
ref = REF_CALC
+ if CASE == 'div_rn':
+ ref_tmp = (x.to(tl.float64) / y.to(tl.float64))
+ tl.device_print("",ref_tmp)
+ ref = ref_tmp.to(tl.float32)
tl.store(OUT + tl.arange(0, BLOCK), prec)
tl.store(OUT_REF + tl.arange(0, BLOCK), ref)
@@ -1029,18 +1035,17 @@ def test_precise_math(expr_prec, expr_ref, num_ctas, device):
y = torch.randn(shape, dtype=torch.float32, device=device)
if (expr_prec.count('sqrt') > 0):
+ case = 'sqrt'
x = torch.abs(x)
if (expr_prec.count('div') > 0):
+ case = 'div_rn'
y += 1e-6
kernel = patch_kernel(kernel, {'PREC_CALC': expr_prec, 'REF_CALC': expr_ref})
- kernel[(1, )](x, y, out, out_ref, BLOCK=shape[0], num_ctas=num_ctas)
+ kernel[(1, )](x, y, out, out_ref, case, BLOCK=shape[0], num_ctas=num_ctas)
- if is_xpu() and expr_prec == 'tl.math.div_rn(x,y)':
- np.testing.assert_allclose(to_numpy(out), to_numpy(out_ref), rtol=1e-6)
- pytest.skip("FIXME: Fail accuracy")
assert torch.all(out == out_ref) # bitwise exact
Do we accept case modification as a workaround?
I don't think so; adding a print only sidesteps the actual problem.
How did you add strictfp? Did you add it to #0?
Do we accept case modification as a workaround?
I don't think so, as adding a print only avoid the actual problem.
How did you add
strictfp
? Did you add it to#0
?
Yes. I uploaded the patch and the dumped kernel here for details: #601.zip
I suggest looking into the optimization pass that removes the FP conversions to see whether it checks the strictfp function attribute, or any other attribute.
Update after an offline sync with @chengjunlu: 1. The test case aims to test the precision of the math library. 2. Analysis:
Why NVIDIA passes:
NVIDIA PTX supports both single- and double-precision floating-point division (NV's FDivOp), and "(x.to(tl.float64) / y.to(tl.float64)).to(tl.float32)" is lowered correctly to "div.rnd.f64 d, a, b; // IEEE 754 compliant rounding".
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div
Why Intel fails:
We did not find a corresponding float64 division on Intel GPUs, so an approximate float division is used, similar to NVIDIA's "div.approx{.ftz}.f32 d, a, b; // fast, approximate divide".
3. The test case passes with the double-precision approximate division, but fails with the single-precision approximate division.
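4. The LLVM InstCombine pass always rewrites the pattern "fptrunc(OpI (fpextend x), (fpextend y))" to single precision, on both the Triton side and the IGC side. InstCombine treats this rewrite as value-preserving. A double's 53-bit significand is more than twice as wide as a float's 24-bit significand, so rounding to double and then to float gives the same result as a correctly rounded single-precision division. The mismatch therefore comes from the resulting single-precision fdiv being lowered to an approximate division (see point 3).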
Running pass: PromotePass on kernel_0d1d2d3d (33 instructions)
; *** IR Dump After PromotePass on kernel_0d1d2d3d ***
; Function Attrs: strictfp
define spir_kernel void @kernel_0d1d2d3d(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, ptr addrspace(1) %3) local_unnamed_addr #0 !dbg !11 !max_work_group_size !12 !intel_reqd_sub_group_size !13 {
%5 = call i64 @_Z12get_local_idj(i32 0)
%6 = trunc i64 %5 to i32, !dbg !14
%7 = urem i32 %6, 32, !dbg !14
%8 = udiv i32 %6, 32, !dbg !14
%9 = urem i32 %8, 4, !dbg !14
%10 = mul nuw nsw i32 %9, 32, !dbg !14
%11 = add nuw nsw i32 %7, %10, !dbg !14
%12 = getelementptr float, ptr addrspace(1) %0, i32 %11, !dbg !15
%13 = load i32, ptr addrspace(1) %12, align 4, !dbg !16
%14 = bitcast i32 %13 to <1 x float>, !dbg !16
%15 = extractelement <1 x float> %14, i32 0, !dbg !16
%16 = getelementptr float, ptr addrspace(1) %1, i32 %11, !dbg !17
%17 = load i32, ptr addrspace(1) %16, align 4, !dbg !18
%18 = bitcast i32 %17 to <1 x float>, !dbg !18
%19 = extractelement <1 x float> %18, i32 0, !dbg !18
%20 = call spir_func float @__imf_fdiv_rn(float %15, float %19), !dbg !19
%21 = fpext float %15 to double, !dbg !20
%22 = fpext float %19 to double, !dbg !21
%23 = fdiv double %21, %22, !dbg !22
%24 = fptrunc double %23 to float, !dbg !23
%25 = getelementptr float, ptr addrspace(1) %2, i32 %11, !dbg !24
%26 = call i64 @_Z12get_local_idj(i32 0)
%27 = insertelement <1 x float> undef, float %20, i32 0, !dbg !25
%28 = bitcast <1 x float> %27 to i32, !dbg !25
%29 = insertelement <1 x i32> undef, i32 %28, i32 0, !dbg !25
store <1 x i32> %29, ptr addrspace(1) %25, align 4, !dbg !25
%30 = getelementptr float, ptr addrspace(1) %3, i32 %11, !dbg !26
%31 = call i64 @_Z12get_local_idj(i32 0)
%32 = insertelement <1 x float> undef, float %24, i32 0, !dbg !27
%33 = bitcast <1 x float> %32 to i32, !dbg !27
%34 = insertelement <1 x i32> undef, i32 %33, i32 0, !dbg !27
store <1 x i32> %34, ptr addrspace(1) %30, align 4, !dbg !27
ret void, !dbg !28
}
Running pass: InstCombinePass on kernel_0d1d2d3d (33 instructions)
Running analysis: OptimizationRemarkEmitterAnalysis on kernel_0d1d2d3d
Running analysis: AAManager on kernel_0d1d2d3d
Running analysis: BasicAA on kernel_0d1d2d3d
Running analysis: ScopedNoAliasAA on kernel_0d1d2d3d
Running analysis: TypeBasedAA on kernel_0d1d2d3d
Running analysis: OuterAnalysisManagerProxy<llvm::AnalysisManager<llvm::Module>, llvm::Function> on kernel_0d1d2d3d
; *** IR Dump After InstCombinePass on kernel_0d1d2d3d ***
; Function Attrs: strictfp
define spir_kernel void @kernel_0d1d2d3d(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, ptr addrspace(1) %3) local_unnamed_addr #0 !dbg !11 !max_work_group_size !12 !intel_reqd_sub_group_size !13 {
%5 = call i64 @_Z12get_local_idj(i32 0)
%6 = and i64 %5, 127, !dbg !14
%7 = and i64 %5, 127, !dbg !15
%8 = getelementptr float, ptr addrspace(1) %0, i64 %7, !dbg !15
%9 = load <1 x float>, ptr addrspace(1) %8, align 4, !dbg !16
%10 = extractelement <1 x float> %9, i64 0, !dbg !16
%11 = and i64 %5, 127, !dbg !17
%12 = getelementptr float, ptr addrspace(1) %1, i64 %11, !dbg !17
%13 = load <1 x float>, ptr addrspace(1) %12, align 4, !dbg !18
%14 = extractelement <1 x float> %13, i64 0, !dbg !18
%15 = call spir_func float @__imf_fdiv_rn(float %10, float %14), !dbg !19
%16 = fdiv float %10, %14, !dbg !20
%17 = and i64 %5, 127, !dbg !21
%18 = getelementptr float, ptr addrspace(1) %2, i64 %17, !dbg !21
%19 = call i64 @_Z12get_local_idj(i32 0)
store float %15, ptr addrspace(1) %18, align 4, !dbg !22
%20 = getelementptr float, ptr addrspace(1) %3, i64 %6, !dbg !23
%21 = call i64 @_Z12get_local_idj(i32 0)
store float %16, ptr addrspace(1) %20, align 4, !dbg !24
ret void, !dbg !25
}
What we can do:
1. The test case itself focuses on the precision of the math library.
2. Support more division variants in the Triton language spec for floating-point math operations.
Workarounds:
1. Use the CPU result as the reference so that the test passes.
2. Bypass the pattern "fptrunc(OpI (fpextend x), (fpextend y))".
After discussion, here are some proposals for the next steps.
strictfp
.
The LLVM optimization pass instCombine always re-write the pattern "fptrunc(OpI (fpextend x), (fpextend y))" to single-preciesion on both Triton side and IGC side.
On NVIDIA, it is lowered to __nv_fdiv_rn; we need to lower it to an equivalent call.