llvm / llvm-project

The LLVM Project is a collection of modular and reusable compiler and toolchain technologies.
http://llvm.org

AArch64 loads right shifted index generated redundant code #66318

Open linzj opened 1 year ago

linzj commented 1 year ago

Sample C++ code:

extern void* bar();
void* get(unsigned* tag, void** array) {
  unsigned class_id = tag[0] >> 12;
  // if (class_id > 0x40000)
  //   return bar();
  unsigned base = (0x7a8U + (0x9U << 16U)) >> 3;
  return array[base + class_id];
}

Generates asm:

    ldr w9, [x0]
    mov w8, #1960                       // =0x7a8
    movk    w8, #9, lsl #16
    lsr x9, x9, #9
    and x9, x9, #0x7ffff8
    add x9, x9, x1
    ldr x0, [x9, x8]
    ret

I think `and x9, x9, #0x7ffff8` is redundant and should be removed. It only appears because `(tag[0] >> 12) << 3` gets rewritten into `(tag[0] >> 9) & 0x7ffff8`, which then blocks folding the `<< 3` into the load's scaled addressing mode.

I would expect something like the following:

    ldr w8, [x0]
    mov w9, #8437                       // =0x20f5
    movk    w9, #1, lsl #16
    add x8, x9, x8, lsr #12
    ldr x0, [x1, x8, lsl #3]
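
For context (my own illustration, not part of the original report): the mask comes from the shift-pair rewrite `(x >> 12) << 3` -> `(x >> 9) & 0x7ffff8`, which is bit-for-bit equivalent for a zero-extended 32-bit value but hides the `<< 3` that could otherwise fold into the `ldr`'s scaled addressing mode. A small stand-alone check of that equivalence:

```c++
// Sketch only: verifies the rewrite that introduces the `and`, assuming the
// value starts life as a zero-extended 32-bit load (as `ldr w9, [x0]` does).
#include <cassert>
#include <cstdint>

int main() {
  for (uint64_t x : {0ull, 0x1234ull, 0xdeadbeefull, 0xffffffffull}) {
    uint64_t two_shifts = (x >> 12) << 3;       // shl(srl(x, 12), 3)
    uint64_t shift_mask = (x >> 9) & 0x7ffff8;  // srl(x, 9) masked
    assert(two_shifts == shift_mask);           // identical for x < 2^32
  }
  return 0;
}
```

Keeping the two-shift form instead would let ISel emit `ldr x0, [x1, x8, lsl #3]` directly.
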
linzj commented 1 year ago

I am now using the following patch to fix this:

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c7a6dd7deb45..617ce528e485 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -15624,6 +15624,27 @@ AArch64TargetLowering::isDesirableToCommuteWithShift(const SDNode *N,
   SDValue ShiftLHS = N->getOperand(0);
   EVT VT = N->getValueType(0);

+  // If the shift is only used in the pattern
+  //   load (add t0, (shl x, c0))
+  // return false so the shift can fold into the load's addressing mode.
+  if (N->hasOneUse() && (isa<ConstantSDNode>(N->getOperand(0)) ||
+                         isa<ConstantSDNode>(N->getOperand(1)))) {
+    SDNode *User = *N->use_begin();
+    unsigned LevelCount = 0;
+    while (LevelCount++ < 2 &&
+           (User->getOpcode() == ISD::ADD || User->getOpcode() == ISD::SUB ||
+            User->getOpcode() == ISD::ZERO_EXTEND) &&
+           User->hasOneUse()) {
+      SDNode *Load = *User->use_begin();
+      if (Load->getOpcode() == ISD::LOAD || Load->getOpcode() == ISD::STORE)
+        return false;
+      User = Load;
+    }
+    // CopyToReg means the value is exported; be conservative and bail out.
+    if (User->getOpcode() == ISD::CopyToReg)
+      return false;
+  }
+
   // If ShiftLHS is unsigned bit extraction: ((x >> C) & mask), then do not
   // combine it with shift 'N' to let it be lowered to UBFX except:
   // ((x >> C) & mask) << C.
@@ -15680,6 +15701,26 @@ bool AArch64TargetLowering::shouldFoldConstantShiftPairToMask(
   if (!N->getOperand(0)->hasOneUse())
     return false;

+  // If the shift is only used in the pattern
+  //   load (add t0, (shift x, c0))
+  // return false so the shift can fold into the load's addressing mode.
+  if (N->hasOneUse() && isa<ConstantSDNode>(N->getOperand(1))) {
+    SDNode *User = *N->use_begin();
+    unsigned LevelCount = 0;
+    while (LevelCount++ < 2 &&
+           (User->getOpcode() == ISD::ADD || User->getOpcode() == ISD::SUB ||
+            User->getOpcode() == ISD::ZERO_EXTEND) &&
+           User->hasOneUse()) {
+      SDNode *Load = *User->use_begin();
+      if (Load->getOpcode() == ISD::LOAD || Load->getOpcode() == ISD::STORE)
+        return false;
+      User = Load;
+    }
+    // CopyToReg means the value is exported; be conservative and bail out.
+    if (User->getOpcode() == ISD::CopyToReg)
+      return false;
+  }
+
   // Only fold srl(shl(x,c1),c2) iff C1 >= C2 to prevent loss of UBFX patterns.
   EVT VT = N->getValueType(0);
   if (N->getOpcode() == ISD::SRL && (VT == MVT::i32 || VT == MVT::i64)) {
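
Both hunks add the same use-chain walk to the two hooks (`isDesirableToCommuteWithShift` and `shouldFoldConstantShiftPairToMask`). A minimal sketch of how that walk could be shared, assuming a file-local helper in AArch64ISelLowering.cpp; the name `shiftFeedsLoadOrStoreAddress` is mine, not part of the posted patch:

```c++
// Sketch only: mirrors the walk used in both hunks above.
// (SelectionDAGNodes.h and the llvm namespace are already available in
// AArch64ISelLowering.cpp.)
#include "llvm/CodeGen/SelectionDAGNodes.h"
using namespace llvm;

// Returns true when the shift's single chain of users (at most two
// ADD/SUB/ZERO_EXTEND nodes, each with one use) reaches a load or store, so
// the caller should return false and keep the shift foldable into the
// addressing mode.
static bool shiftFeedsLoadOrStoreAddress(const SDNode *N) {
  if (!N->hasOneUse())
    return false;
  const SDNode *User = *N->use_begin();
  for (unsigned Level = 0; Level < 2; ++Level) {
    if ((User->getOpcode() != ISD::ADD && User->getOpcode() != ISD::SUB &&
         User->getOpcode() != ISD::ZERO_EXTEND) ||
        !User->hasOneUse())
      break;
    const SDNode *Next = *User->use_begin();
    if (Next->getOpcode() == ISD::LOAD || Next->getOpcode() == ISD::STORE)
      return true;
    User = Next;
  }
  // Like the patch, also be conservative when the value is exported.
  return User->getOpcode() == ISD::CopyToReg;
}
```

Each hook would then start with `if (shiftFeedsLoadOrStoreAddress(N)) return false;`, guarded by the same constant-operand checks as in the patch.
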
llvmbot commented 1 year ago

@llvm/issue-subscribers-backend-aarch64
