llvm / llvm-project

The LLVM Project is a collection of modular and reusable compiler and toolchain technologies.
http://llvm.org
Other
29.4k stars 12.15k forks source link

Convert clamp into fmaxnum/fminnum pairs. #77158

Closed hiraditya closed 10 months ago

hiraditya commented 11 months ago

test case

https://godbolt.org/z/6c538vfvb

// Clamp `a` to the range [0.0, 255.0].
// Values above the upper bound yield the upper bound, values below the lower
// bound yield the lower bound, and anything else (including a NaN, for which
// both comparisons are false) is returned unchanged.
float clamp(float a)
{
  const float upper = 255.0;
  const float lower = 0.0;

  if (upper < a)
    return upper;

  return (a < lower) ? lower : a;
}

clang compiles to

// Current clang output for clamp(): two compare + conditional-select pairs.
clamp:                                  // @clamp
        movi    d1, #0000000000000000           // s1 = 0.0 (lower bound)
        fcmp    s0, #0.0                        // compare a against 0.0
        mov     w8, #1132396544                 // =0x437f0000
        fmov    s2, w8                          // s2 = 255.0 (upper bound, bits 0x437f0000)
        fcsel   s1, s1, s0, mi                  // s1 = (a < 0.0) ? 0.0 : a
        fcmp    s0, s2                          // compare a against 255.0
        fcsel   s0, s2, s1, gt                  // result = (a > 255.0) ? 255.0 : s1
        ret

but we could have something like

// Proposed output for clamp(): a single fmaxnm/fminnm pair, no compares.
clamp:                                  // @clamp
// BB#0:                                // %entry
        adrp    x8, .LCPI0_0                    // page address of constant-pool entry
        ldr     s1, [x8, :lo12:.LCPI0_0]        // s1 = pool constant (presumably 255.0; entry not shown)
        fmov    s2, wzr                         // s2 = 0.0
        fmaxnm  s0, s0, s2                      // s0 = max(a, 0.0)
        fminnm  s0, s0, s1                      // s0 = min(s0, s1)
        ret

@sebpop had a patch a while back (https://reviews.llvm.org/D24033) that would do this. Since Phabricator is deprecated, I'll move the patch here.

hiraditya commented 11 months ago

Patch based on llvm-project as of May 23, 2018:

Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -267,6 +267,7 @@
     SDValue PromoteIntShiftOp(SDValue Op);
     SDValue PromoteExtend(SDValue Op);
     bool PromoteLoad(SDValue Op);
+    SDValue combineSelectFP(SDNode *N);

     /// Call the node-specific routine that knows how to fold each
     /// particular type of node. If that doesn't do anything, try the
@@ -6637,6 +6638,87 @@
   return SDValue();
 }

+/// Try to fold select(setcc(X, C1), C1, fmin/fmax(X, C2)) into the opposite
+/// FP min/max of that inner node and C1; returns SDValue() when it cannot.
+SDValue DAGCombiner::combineSelectFP(SDNode *N) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SDValue N2 = N->getOperand(2);
+  // N0 is the select condition; N1 and N2 are the two selected values.
+  EVT VT = N->getValueType(0);
+  SDNode *CmpNode = N0.getNode();
+
+  // N2 needs two operands, and N0 must be a SETCC whose LHS is N2's LHS.
+  if (N2->getNumOperands() < 2 || CmpNode->getOpcode() != ISD::SETCC ||
+      CmpNode->getOperand(0) != N2.getOperand(0))
+    return SDValue();
+
+  // Both must be FP constants. Original open question: is it really safe here?
+  ConstantFPSDNode *CmpN1FP =
+      dyn_cast<ConstantFPSDNode>(CmpNode->getOperand(1));
+  ConstantFPSDNode *N1FP = dyn_cast<ConstantFPSDNode>(N1);
+  if (!CmpN1FP || !N1FP)
+    return SDValue();
+
+  const APFloat &CmpN1FPVal = CmpN1FP->getValueAPF();
+  const APFloat &N1FPVal = N1FP->getValueAPF();
+
+  // The SETCC constant and N1 must share FP semantics and compare equal.
+  if (&CmpN1FPVal.getSemantics() != &N1FPVal.getSemantics() ||
+      CmpN1FPVal.compare(N1FPVal) != APFloat::cmpEqual)
+    return SDValue();
+
+  // N2's second operand must also be an FP constant.
+  SDNode *N2Node = N2.getNode();
+  ConstantFPSDNode *N2Operand2 =
+      dyn_cast<ConstantFPSDNode>(N2Node->getOperand(1).getNode());
+  if (!N2Operand2)
+    return SDValue();
+
+  // That constant must have the same FP semantics as the SETCC constant.
+  APFloat N2Operand2Val = N2Operand2->getValueAPF();
+  if (&CmpN1FPVal.getSemantics() != &N2Operand2Val.getSemantics())
+    return SDValue();
+
+  unsigned RetOpcode;
+  // Pick the outer opcode from the SETCC condition code (operand 2).
+  switch (cast<CondCodeSDNode>(CmpNode->getOperand(2))->get()) {
+  case ISD::SETOLT:
+  case ISD::SETOLE:
+  case ISD::SETLT:
+  case ISD::SETLE:
+  case ISD::SETULT:
+  case ISD::SETULE: {
+    if (N1FPVal.compare(N2Operand2Val) != APFloat::cmpLessThan)
+      return SDValue();
+    // Reached only when N1 < N2's constant; N2 itself must be an FP min.
+    RetOpcode = ISD::FMAXNUM;
+    unsigned N2Opcode = N2Node->getOpcode();
+    if (N2Opcode != ISD::FMINNUM && N2Opcode != ISD::FMINNAN)
+      return SDValue();
+  }; break;
+  case ISD::SETOGT:
+  case ISD::SETOGE:
+  case ISD::SETGT:
+  case ISD::SETGE:
+  case ISD::SETUGT:
+  case ISD::SETUGE: {
+    if (N1FPVal.compare(N2Operand2Val) != APFloat::cmpGreaterThan)
+      return SDValue();
+    // Reached only when N1 > N2's constant; N2 itself must be an FP max.
+    RetOpcode = ISD::FMINNUM;
+    unsigned N2Opcode = N2Node->getOpcode();
+    if (N2Opcode != ISD::FMAXNUM && N2Opcode != ISD::FMAXNAN)
+      return SDValue();
+  }; break;
+  default:
+    return SDValue();
+    break;
+  }
+  // Wrap the existing inner min/max (N2) in the outer one against N1.
+  return DAG.getNode(RetOpcode, SDLoc(N), VT, N2, N1);
+}
+
 SDValue DAGCombiner::visitSELECT(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -6783,6 +6865,16 @@
       if (SDValue FMinMax = combineMinNumMaxNum(
               DL, VT, N0.getOperand(0), N0.getOperand(1), N1, N2, CC, TLI, DAG))
         return FMinMax;
+
+      //  t5: i1 = setcc t2, ConstantFP:f1, setgt:ch
+      //  t9: f32 = fmaxnum t2, ConstantFP:f2
+      //  t10: f32 = select t5, ConstantFP:f1, t9
+      // and f1 > f2
+      // ==> t9  = fmaxnum t2, f2
+      //     t10 = fminnum t9, f1
+      if (isa<ConstantFPSDNode>(N1) && N0->getOpcode() == ISD::SETCC)
+        if (SDValue combinedFP = combineSelectFP(N))
+          return combinedFP;
     }

     if ((!LegalOperations &&
Index: llvm/test/CodeGen/AArch64/aarch64-DAGCombine-fminmax.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/aarch64-DAGCombine-fminmax.ll
@@ -0,0 +1,113 @@
+; RUN: llc --enable-unsafe-fp-math -mtriple=aarch64-unknown-linux-gnu < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64-unknown-linux-gnu < %s | FileCheck -check-prefix=CHECK1 %s
+
+; Checks the DAGCombiner fold of a select-based float clamp into an fmin/fmax
+; pair. The CHECK prefix belongs to the --enable-unsafe-fp-math run, which
+; expects both instructions; CHECK1 belongs to the default run, which expects
+; only the first one. The last two functions must not produce fmin in either
+; run. Every function is anchored with a -LABEL directive so that its checks
+; (the -NOT ones in particular) are confined to that function's output.
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+
+declare float @llvm.minnum.f32(float, float)
+declare float @llvm.maxnum.f32(float, float)
+
+; Upper bound 255.0 selected outermost over a select-based lower bound 128.0.
+; Function Attrs: norecurse nounwind readnone
+; CHECK-LABEL: clampNUM:
+; CHECK:  fmax{{(nm)?}} {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK:  fminnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK1-LABEL: clampNUM:
+; CHECK1:  fmax{{(nm)?}} {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK1-NOT:  fminnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+define float @clampNUM(float %a) local_unnamed_addr #0 {
+entry:
+  %cmp = fcmp ogt float %a, 2.550000e+02
+  %cmp3 = fcmp olt float %a, 1.280000e+02
+  %.a = select i1 %cmp3, float 1.280000e+02, float %a
+  %retval.0 = select i1 %cmp, float 2.550000e+02, float %.a
+  ret float %retval.0
+}
+
+; Lower bound 128.0 selected outermost over a select-based upper bound 255.0
+; Function Attrs: norecurse nounwind readnone
+; CHECK-LABEL: clampNAN:
+; CHECK:  fminnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK:  fmaxnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK1-LABEL: clampNAN:
+; CHECK1:  fminnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK1-NOT:  fmaxnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+define float @clampNAN(float %a) local_unnamed_addr #0 {
+entry:
+  %cmp = fcmp olt float %a, 1.280000e+02
+  %cmp2 = fcmp fast ogt float %a, 2.550000e+02
+  %.a = select i1 %cmp2, float 2.550000e+02, float %a
+  %retval.0 = select i1 %cmp, float 1.280000e+02, float %.a
+  ret float %retval.0
+}
+
+; Upper-bound select over an explicit llvm.maxnum call.
+; Function Attrs: norecurse nounwind readnone
+; CHECK-LABEL: clampIntrinsicFmax:
+; CHECK:  fmaxnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK:  fminnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK1-LABEL: clampIntrinsicFmax:
+; CHECK1:  fmaxnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK1-NOT:  fminnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+define float @clampIntrinsicFmax(float %a) local_unnamed_addr #0 {
+entry:
+  %cmp = fcmp ogt float %a, 2.550000e+02
+  %.a = call float @llvm.maxnum.f32(float %a, float 1.280000e+02) readnone
+  %retval.0 = select i1 %cmp, float 2.550000e+02, float %.a
+  ret float %retval.0
+}
+
+; Lower-bound select over an explicit llvm.minnum call.
+; Function Attrs: norecurse nounwind readnone
+; CHECK-LABEL: clampIntrinsicFmin:
+; CHECK:  fminnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK:  fmaxnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK1-LABEL: clampIntrinsicFmin:
+; CHECK1:  fminnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK1-NOT:  fmaxnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+define float @clampIntrinsicFmin(float %a) local_unnamed_addr #0 {
+entry:
+  %cmp = fcmp olt float %a, 1.280000e+02
+  %.a = call float @llvm.minnum.f32(float %a, float 2.550000e+02) readnone
+  %retval.0 = select i1 %cmp, float 1.280000e+02, float %.a
+  ret float %retval.0
+}
+
+; Negative test: the inner select is widened by fpext, so the outer select is on double.
+; Function Attrs: norecurse nounwind readnone
+; CHECK-LABEL: clampNoConvert:
+; CHECK-NOT:  fmin
+; CHECK1-LABEL: clampNoConvert:
+; CHECK1-NOT:  fmin
+define double @clampNoConvert(float %a) local_unnamed_addr #0 {
+entry:
+  %cmp = fcmp ogt float %a, 2.550000e+02
+  %.inv = fcmp ole float %a, 3.000000e+00
+  %0 = select i1 %.inv, float 3.000000e+00, float %a
+  %1 = fpext float %0 to double
+  %retval.0 = select i1 %cmp, double 2.550000e+02, double %1
+  ret double %retval.0
+}
+
+; Negative test: the inner select compares against 0.0 but yields 30.0.
+; Function Attrs: norecurse nounwind readnone
+; CHECK-LABEL: clampNo2Convert:
+; CHECK-NOT:  fmin
+; CHECK1-LABEL: clampNo2Convert:
+; CHECK1-NOT:  fmin
+define float @clampNo2Convert(float %a) local_unnamed_addr #0 {
+entry:
+  %cmp = fcmp ogt float %a, 2.550000e+02
+  %cmp2 = fcmp fast olt float %a, 0.000000e+00
+  %.a = select i1 %cmp2, float 3.000000e+01, float %a
+  %retval.0 = select i1 %cmp, float 2.550000e+02, float %.a
+  ret float %retval.0
+}
+
+attributes #0 = { norecurse nounwind readnone "no-nans-fp-math"="true" }
vfdff commented 10 months ago

It seems LLVM already does this with -fno-signed-zeros -ffinite-math-only (or -ffast-math): https://godbolt.org/z/hczj9efoT

hiraditya commented 10 months ago

looks good. closing it then.