llvm / llvm-project

The LLVM Project is a collection of modular and reusable compiler and toolchain technologies.
http://llvm.org
Other
28.03k stars 11.58k forks source link

[MIPS] miscompile of 64-bit shift with masked shift amount #64794

Closed jacobly0 closed 10 months ago

jacobly0 commented 1 year ago
; Reproducer for the 32-bit MIPS miscompile of a 64-bit shift whose shift
; amount is masked (llvm/llvm-project#64794). The IR is kept exactly as
; reported so it still triggers the bug; only comments were added.
target triple = "mips"

; Returns all-ones (i64 -1) logically shifted right by (%0 & 63).
; The `and` keeps the shift amount in [0, 63], so the lshr below is always
; well-defined IR; the miscompile appears only once the backend splits this
; i64 shift into two 32-bit halves.
define i64 @f(i64 %0) {
  %2 = and i64 %0, 63
  %3 = lshr i64 -1, %2
  ret i64 %3
}

; Returns 0 when @f(12) equals the constant expression lshr (i64 -1, i64 12)
; (0x000FFFFFFFFFFFFF) and 1 when it differs, so a return value of 1 signals
; the miscompile.
; NOTE(review): newer LLVM releases may no longer accept the lshr constant
; expression; if it fails to parse, substitute the literal 4503599627370495
; (2^52 - 1). Confirm against the LLVM version in use.
define i32 @main() {
  %1 = call i64 @f(i64 12)
  %2 = icmp ne i64 %1, lshr (i64 -1, i64 12)
  %3 = zext i1 %2 to i32
  ret i32 %3
}

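With version 16.0.0 (https://github.com/llvm/llvm-project.git 08d094a0e457360ad8b94b017d2dc277e697ca76) the program returns 0; with version 17.x (https://github.com/llvm/llvm-project.git 8f4dd44097c9ae25dd203d5ac87f3b48f854bba8) it returns 1. The diff below shows how the generated code for @f changed between the two versions: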

        andi    $1, $5, 63
        addiu   $2, $zero, -1
-       srlv    $2, $2, $1
-       not     $1, $1
-       addiu   $3, $zero, -2
-       sllv    $1, $3, $1
-       or      $3, $1, $2
+       srlv    $3, $2, $1
        andi    $1, $5, 32
-       movn    $3, $2, $1
+       move    $2, $3
        jr      $ra
        movn    $2, $zero, $1
llvmbot commented 1 year ago

@llvm/issue-subscribers-backend-mips

jacobly0 commented 1 year ago

These changes get downstream tests passing:

diff --git a/llvm/lib/Target/Mips/MipsISelLowering.cpp b/llvm/lib/Target/Mips/MipsISelLowering.cpp
index 18d7773067f1..d92e94a353bf 100644
--- a/llvm/lib/Target/Mips/MipsISelLowering.cpp
+++ b/llvm/lib/Target/Mips/MipsISelLowering.cpp
@@ -2593,16 +2593,20 @@ SDValue MipsTargetLowering::lowerShiftLeftParts(SDValue Op,
   SDValue Shamt = Op.getOperand(2);
   // if shamt < (VT.bits):
   //  lo = (shl lo, shamt)
-  //  hi = (or (shl hi, shamt) (srl (srl lo, 1), ~shamt))
+  //  hi = (or (shl hi, shamt) (srl (srl lo, 1), (xor shamt, VT.bits-1)))
   // else:
   //  lo = 0
   //  hi = (shl lo, shamt[4:0])
-  SDValue Not = DAG.getNode(ISD::XOR, DL, MVT::i32, Shamt,
-                            DAG.getConstant(-1, DL, MVT::i32));
+  SDValue Not =
+      DAG.getNode(ISD::XOR, DL, MVT::i32, Shamt,
+                  DAG.getConstant(VT.getSizeInBits() - 1, DL, MVT::i32));
   SDValue ShiftRight1Lo = DAG.getNode(ISD::SRL, DL, VT, Lo,
                                       DAG.getConstant(1, DL, VT));
   SDValue ShiftRightLo = DAG.getNode(ISD::SRL, DL, VT, ShiftRight1Lo, Not);
-  SDValue ShiftLeftHi = DAG.getNode(ISD::SHL, DL, VT, Hi, Shamt);
+  SDValue ShamtMasked =
+      DAG.getNode(ISD::AND, DL, MVT::i32, Shamt,
+                  DAG.getConstant(VT.getSizeInBits() - 1, DL, MVT::i32));
+  SDValue ShiftLeftHi = DAG.getNode(ISD::SHL, DL, VT, Hi, ShamtMasked);
   SDValue Or = DAG.getNode(ISD::OR, DL, VT, ShiftLeftHi, ShiftRightLo);
   SDValue ShiftLeftLo = DAG.getNode(ISD::SHL, DL, VT, Lo, Shamt);
   SDValue Cond = DAG.getNode(ISD::AND, DL, MVT::i32, Shamt,
@@ -2623,7 +2627,7 @@ SDValue MipsTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG,
   MVT VT = Subtarget.isGP64bit() ? MVT::i64 : MVT::i32;

   // if shamt < (VT.bits):
-  //  lo = (or (shl (shl hi, 1), ~shamt) (srl lo, shamt))
+  //  lo = (or (shl (shl hi, 1), (xor shamt, VT.bits-1)) (srl lo, shamt))
   //  if isSRA:
   //    hi = (sra hi, shamt)
   //  else:
@@ -2635,15 +2639,19 @@ SDValue MipsTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG,
   //  else:
   //   lo = (srl hi, shamt[4:0])
   //   hi = 0
-  SDValue Not = DAG.getNode(ISD::XOR, DL, MVT::i32, Shamt,
-                            DAG.getConstant(-1, DL, MVT::i32));
+  SDValue Not =
+      DAG.getNode(ISD::XOR, DL, MVT::i32, Shamt,
+                  DAG.getConstant(VT.getSizeInBits() - 1, DL, MVT::i32));
   SDValue ShiftLeft1Hi = DAG.getNode(ISD::SHL, DL, VT, Hi,
                                      DAG.getConstant(1, DL, VT));
   SDValue ShiftLeftHi = DAG.getNode(ISD::SHL, DL, VT, ShiftLeft1Hi, Not);
   SDValue ShiftRightLo = DAG.getNode(ISD::SRL, DL, VT, Lo, Shamt);
   SDValue Or = DAG.getNode(ISD::OR, DL, VT, ShiftLeftHi, ShiftRightLo);
-  SDValue ShiftRightHi = DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL,
-                                     DL, VT, Hi, Shamt);
+  SDValue ShamtMasked =
+      DAG.getNode(ISD::AND, DL, MVT::i32, Shamt,
+                  DAG.getConstant(VT.getSizeInBits() - 1, DL, MVT::i32));
+  SDValue ShiftRightHi =
+      DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL, DL, VT, Hi, ShamtMasked);
   SDValue Cond = DAG.getNode(ISD::AND, DL, MVT::i32, Shamt,
                              DAG.getConstant(VT.getSizeInBits(), DL, MVT::i32));
   SDValue Ext = DAG.getNode(ISD::SRA, DL, VT, Hi,
tru commented 1 year ago

Has this fix been posted to Phabricator anywhere? Should we still try to get this fix into 17.x?

brad0 commented 1 year ago

@wzssyqa

brad0 commented 1 year ago

@FlyGoat

FlyGoat commented 1 year ago

@wzssyqa Would you mind submitting it?

yingopq commented 1 year ago

@jacobly0 I could not reproduce this issue on mips64el. What are your steps? How did you get the result of the .ll file above? I added printf calls to the .ll file and used clang to build a.out. Thanks.

jacobly0 commented 1 year ago
; Runnable variant of the reproducer: targets mips-pc-linux so it can be
; linked statically without libc and executed under qemu-mips.
target triple = "mips-pc-linux"

; Returns all-ones (i64 -1) logically shifted right by (%0 & 63); the mask
; keeps the shift amount in [0, 63].
define i64 @f(i64 %0) {
  %2 = and i64 %0, 63
  %3 = lshr i64 -1, %2
  ret i64 %3
}

; Program entry when linking with -nostdlib (no libc, hence no @main);
; presumably __start is the linker's default entry symbol here -- confirm.
; Computes 1 if @f(12) differs from lshr (i64 -1, i64 12) and 0 otherwise,
; then issues a raw syscall with 4001 in $2 and that flag in $4.
; NOTE(review): 4001 is presumably the MIPS o32 Linux exit syscall, which
; never returns (hence `unreachable`); the flag becomes the process exit code.
define void @__start() {
  %1 = call i64 @f(i64 12)
  %2 = icmp ne i64 %1, lshr (i64 -1, i64 12)
  %3 = zext i1 %2 to i32
  call void asm sideeffect "syscall", "{$2},{$4}"(i32 4001, i32 %3)
  unreachable
}
$ llc-16 repro.ll -o 16.s
$ llc-17 repro.ll -o 17.s
$ diff -U2 16.s 17.s
--- 16.s
+++ 17.s
@@ -23,11 +23,7 @@
    andi    $1, $5, 63
    addiu   $2, $zero, -1
-   srlv    $2, $2, $1
-   not $1, $1
-   addiu   $3, $zero, -2
-   sllv    $1, $3, $1
-   or  $3, $1, $2
+   srlv    $3, $2, $1
    andi    $1, $5, 32
-   movn    $3, $2, $1
+   move    $2, $3
    jr  $ra
    movn    $2, $zero, $1
$ clang-16 -nostdlib -static -fuse-ld=lld -target mips-pc-linux repro.ll && qemu-mips ./a.out; echo $?
0
$ clang-17 -nostdlib -static -fuse-ld=lld -target mips-pc-linux repro.ll && qemu-mips ./a.out; echo $?
1
yingopq commented 1 year ago

@jacobly0 I used mips64el and could not reproduce it; I will try mips.

$ sudo ./install-ninja/bin/clang-17 -static -fuse-ld=/usr/bin/mips64el-linux-gnuabi64-ld -target mips64el-unknown-linux-gnuabi64 1.ll -o main3 && qemu-mips64el ./main3; echo $?
fffffffffffff
fffffffffffff
0
0
$ cat 1.ll
; mips64el variant used while trying to reproduce the bug. It cannot show the
; miscompile: on a 64-bit target the i64 shift is not split into 32-bit halves.
target triple = "mips64el-unknown-linux-gnuabi64"

; printf format strings: "%d\n" for the i32 result, "%lx\n" for i64 values.
@.str = private unnamed_addr constant [4 x i8] c"%d\0A\00", align 1
@.str64 = private unnamed_addr constant [5 x i8] c"%lx\0A\00", align 1

; Like the original @f (all-ones shifted right by %0 & 63), but also prints
; the masked result (%3) and the result of shifting by the unmasked amount
; (%5) before returning %3.
; NOTE(review): lshr by the unmasked %0 yields poison when %0 >= 64; it is
; only meaningful here because the caller passes 12.
define i64 @f(i64 %0) {
  %2 = and i64 %0, 63
  %3 = lshr i64 -1, %2
  %4 = call signext i32 (i8*, ...) @printf(i8* getelementptr inbounds ([5 x i8], [5 x i8]* @.str64, i64 0, i64 0), i64 signext %3)
  %5 = lshr i64 -1, %0
  %6 = call signext i32 (i8*, ...) @printf(i8* getelementptr inbounds ([5 x i8], [5 x i8]* @.str64, i64 0, i64 0), i64 signext %5)
  ret i64 %3
}

; Calls @f(12), prints 1 if the result differs from lshr (i64 -1, i64 12)
; and 0 if it matches, then returns that same value as the exit status.
define i32 @main() {
  %1 = call i64 @f(i64 12)
  %2 = icmp ne i64 %1, lshr (i64 -1, i64 12)
  %3 = zext i1 %2 to i32
  %4 = call signext i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @.str, i64 0, i64 0), i32 signext %3)
  ret i32 %3
}

; Variadic libc printf, used by the diagnostic prints in this module.
declare signext i32 @printf(i8*, ...)
jacobly0 commented 1 year ago

There's no way my repro can exhibit the issue on mips64: the bug only occurs after a 64-bit shift has been split into two 32-bit halves, and a 64-bit target never needs to do that split.

yingopq commented 12 months ago

@jacobly0 Yes, I reproduced it on mips32el, made a small change to the diff, and the result was also OK. What do you think of my diff? If it looks good, I will submit it. Thanks.

diff --git a/llvm/lib/Target/Mips/MipsISelLowering.cpp b/llvm/lib/Target/Mips/MipsISelLowering.cpp
index 18d7773067f1..a0bfeb1dc3f0 100644
--- a/llvm/lib/Target/Mips/MipsISelLowering.cpp
+++ b/llvm/lib/Target/Mips/MipsISelLowering.cpp
@@ -2593,23 +2593,29 @@ SDValue MipsTargetLowering::lowerShiftLeftParts(SDValue Op,
   SDValue Shamt = Op.getOperand(2);
   // if shamt < (VT.bits):
   //  lo = (shl lo, shamt)
-  //  hi = (or (shl hi, shamt) (srl (srl lo, 1), ~shamt))
+  //  hi = (or (shl hi, shamt) (srl (srl lo, 1), (xor shamt, VT.bits-1))))
   // else:
   //  lo = 0
   //  hi = (shl lo, shamt[4:0])
   SDValue Not = DAG.getNode(ISD::XOR, DL, MVT::i32, Shamt,
-                            DAG.getConstant(-1, DL, MVT::i32));
+                            DAG.getConstant(VT.getSizeInBits()-1, DL, MVT::i32));
   SDValue ShiftRight1Lo = DAG.getNode(ISD::SRL, DL, VT, Lo,
                                       DAG.getConstant(1, DL, VT));
   SDValue ShiftRightLo = DAG.getNode(ISD::SRL, DL, VT, ShiftRight1Lo, Not);
+  SDValue ShamtMasked =
+      DAG.getNode(ISD::AND, DL, MVT::i32, Shamt,
+                  DAG.getConstant(VT.getSizeInBits() - 1, DL, MVT::i32));
+  SDValue HiTrue =
+      DAG.getNode(ISD::SHL, DL, VT, Hi, ShamtMasked);
   SDValue ShiftLeftHi = DAG.getNode(ISD::SHL, DL, VT, Hi, Shamt);
   SDValue Or = DAG.getNode(ISD::OR, DL, VT, ShiftLeftHi, ShiftRightLo);
   SDValue ShiftLeftLo = DAG.getNode(ISD::SHL, DL, VT, Lo, Shamt);
   SDValue Cond = DAG.getNode(ISD::AND, DL, MVT::i32, Shamt,
                              DAG.getConstant(VT.getSizeInBits(), DL, MVT::i32));
+
   Lo = DAG.getNode(ISD::SELECT, DL, VT, Cond,
                    DAG.getConstant(0, DL, VT), ShiftLeftLo);
-  Hi = DAG.getNode(ISD::SELECT, DL, VT, Cond, ShiftLeftLo, Or);
+  Hi = DAG.getNode(ISD::SELECT, DL, VT, Cond, HiTrue, Or);

   SDValue Ops[2] = {Lo, Hi};
   return DAG.getMergeValues(Ops, DL);
@@ -2623,7 +2629,7 @@ SDValue MipsTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG,
   MVT VT = Subtarget.isGP64bit() ? MVT::i64 : MVT::i32;

   // if shamt < (VT.bits):
-  //  lo = (or (shl (shl hi, 1), ~shamt) (srl lo, shamt))
+  //  lo = (or (shl (shl hi, 1), (xor shamt, VT.bits-1))) (srl lo, shamt))
   //  if isSRA:
   //    hi = (sra hi, shamt)
   //  else:
@@ -2636,12 +2642,17 @@ SDValue MipsTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG,
   //   lo = (srl hi, shamt[4:0])
   //   hi = 0
   SDValue Not = DAG.getNode(ISD::XOR, DL, MVT::i32, Shamt,
-                            DAG.getConstant(-1, DL, MVT::i32));
+                            DAG.getConstant(VT.getSizeInBits()-1, DL, MVT::i32));
   SDValue ShiftLeft1Hi = DAG.getNode(ISD::SHL, DL, VT, Hi,
                                      DAG.getConstant(1, DL, VT));
   SDValue ShiftLeftHi = DAG.getNode(ISD::SHL, DL, VT, ShiftLeft1Hi, Not);
   SDValue ShiftRightLo = DAG.getNode(ISD::SRL, DL, VT, Lo, Shamt);
   SDValue Or = DAG.getNode(ISD::OR, DL, VT, ShiftLeftHi, ShiftRightLo);
+  SDValue ShamtMasked =
+      DAG.getNode(ISD::AND, DL, MVT::i32, Shamt,
+                  DAG.getConstant(VT.getSizeInBits() - 1, DL, MVT::i32));
+  SDValue LoTrue =
+      DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL, DL, VT, Hi, ShamtMasked);
   SDValue ShiftRightHi = DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL,
                                      DL, VT, Hi, Shamt);
   SDValue Cond = DAG.getNode(ISD::AND, DL, MVT::i32, Shamt,
@@ -2658,7 +2669,7 @@ SDValue MipsTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG,
                        ShiftRightHi);
   }

-  Lo = DAG.getNode(ISD::SELECT, DL, VT, Cond, ShiftRightHi, Or);
+  Lo = DAG.getNode(ISD::SELECT, DL, VT, Cond, LoTrue, Or);
   Hi = DAG.getNode(ISD::SELECT, DL, VT, Cond,
                    IsSRA ? Ext : DAG.getConstant(0, DL, VT), ShiftRightHi);
jacobly0 commented 12 months ago

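Note that when shamt < 32, shamt is necessarily equal to shamt[4:0], which is why the same masked shift can be reused for both cases. The select does stop the propagation of this poison in the other case, so this diff is certainly valid; it just generates more instructions. [Editor's note: the lowerShiftLeftParts half of the diff above is not quite right. For shamt >= 32 the high word must be lo << shamt[4:0], as that function's own comment states, but HiTrue shifts Hi instead of Lo. The later diff in this thread avoids this by masking the shift amount used for ShiftLeftLo.]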

It would be nice to later lower these masked shifts into a single MIPS shift instruction, since the CPU already masks the shift amount implicitly; you can see this being done in other targets' InstrInfo.td files. This would allow the backend to generate the original optimized instruction sequence again without risking misoptimizations by shared optimization passes, but I don't think that needs to make it into a release.

yingopq commented 12 months ago

> Note that when shamt < 32 then shamt is necessarily equal to shamt[4:0], which is why the same masked shift can be reused for both cases. The select does stop the propagation of this poison in the other case, so this diff is certainly valid, it just generates more instructions.

@jacobly0 In function MipsTargetLowering::lowerShiftLeftParts, should we modify ShiftLeftLo rather than ShiftLeftHi, following your idea, which would reduce the number of instructions?

> It would be nice to lower these masked shifts later into just a mips shift instruction, since the cpu already implicitly masks shifts, which you can see being done in other target InstrInfo.td files. This would allow the backend to generate the original optimized instruction sequence again without risking misoptimizations by shared optimization passes, but I don't think that needs to make it into a release.

I did not understand this clearly: do you mean we do not need to commit this code to release 17.x?

jacobly0 commented 11 months ago

> In function MipsTargetLowering::lowerShiftLeftParts, should we modify ShiftLeftLo rather than ShiftLeftHi, following your idea, which would reduce the number of instructions?

Sorry, yeah, I meant to make the same change to both functions:

diff --git a/llvm/lib/Target/Mips/MipsISelLowering.cpp b/llvm/lib/Target/Mips/MipsISelLowering.cpp
index 18d7773067f1..480861156eb6 100644
--- a/llvm/lib/Target/Mips/MipsISelLowering.cpp
+++ b/llvm/lib/Target/Mips/MipsISelLowering.cpp
@@ -2593,18 +2593,22 @@ SDValue MipsTargetLowering::lowerShiftLeftParts(SDValue Op,
   SDValue Shamt = Op.getOperand(2);
   // if shamt < (VT.bits):
   //  lo = (shl lo, shamt)
-  //  hi = (or (shl hi, shamt) (srl (srl lo, 1), ~shamt))
+  //  hi = (or (shl hi, shamt) (srl (srl lo, 1), (xor shamt, VT.bits-1)))
   // else:
   //  lo = 0
   //  hi = (shl lo, shamt[4:0])
-  SDValue Not = DAG.getNode(ISD::XOR, DL, MVT::i32, Shamt,
-                            DAG.getConstant(-1, DL, MVT::i32));
+  SDValue Not =
+      DAG.getNode(ISD::XOR, DL, MVT::i32, Shamt,
+                  DAG.getConstant(VT.getSizeInBits() - 1, DL, MVT::i32));
   SDValue ShiftRight1Lo = DAG.getNode(ISD::SRL, DL, VT, Lo,
                                       DAG.getConstant(1, DL, VT));
   SDValue ShiftRightLo = DAG.getNode(ISD::SRL, DL, VT, ShiftRight1Lo, Not);
   SDValue ShiftLeftHi = DAG.getNode(ISD::SHL, DL, VT, Hi, Shamt);
   SDValue Or = DAG.getNode(ISD::OR, DL, VT, ShiftLeftHi, ShiftRightLo);
-  SDValue ShiftLeftLo = DAG.getNode(ISD::SHL, DL, VT, Lo, Shamt);
+  SDValue ShamtMasked =
+      DAG.getNode(ISD::AND, DL, MVT::i32, Shamt,
+                  DAG.getConstant(VT.getSizeInBits() - 1, DL, MVT::i32));
+  SDValue ShiftLeftLo = DAG.getNode(ISD::SHL, DL, VT, Lo, ShamtMasked);
   SDValue Cond = DAG.getNode(ISD::AND, DL, MVT::i32, Shamt,
                              DAG.getConstant(VT.getSizeInBits(), DL, MVT::i32));
   Lo = DAG.getNode(ISD::SELECT, DL, VT, Cond,
@@ -2623,7 +2627,7 @@ SDValue MipsTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG,
   MVT VT = Subtarget.isGP64bit() ? MVT::i64 : MVT::i32;

   // if shamt < (VT.bits):
-  //  lo = (or (shl (shl hi, 1), ~shamt) (srl lo, shamt))
+  //  lo = (or (shl (shl hi, 1), (xor shamt, VT.bits-1)) (srl lo, shamt))
   //  if isSRA:
   //    hi = (sra hi, shamt)
   //  else:
@@ -2635,15 +2639,19 @@ SDValue MipsTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG,
   //  else:
   //   lo = (srl hi, shamt[4:0])
   //   hi = 0
-  SDValue Not = DAG.getNode(ISD::XOR, DL, MVT::i32, Shamt,
-                            DAG.getConstant(-1, DL, MVT::i32));
+  SDValue Not =
+      DAG.getNode(ISD::XOR, DL, MVT::i32, Shamt,
+                  DAG.getConstant(VT.getSizeInBits() - 1, DL, MVT::i32));
   SDValue ShiftLeft1Hi = DAG.getNode(ISD::SHL, DL, VT, Hi,
                                      DAG.getConstant(1, DL, VT));
   SDValue ShiftLeftHi = DAG.getNode(ISD::SHL, DL, VT, ShiftLeft1Hi, Not);
   SDValue ShiftRightLo = DAG.getNode(ISD::SRL, DL, VT, Lo, Shamt);
   SDValue Or = DAG.getNode(ISD::OR, DL, VT, ShiftLeftHi, ShiftRightLo);
-  SDValue ShiftRightHi = DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL,
-                                     DL, VT, Hi, Shamt);
+  SDValue ShamtMasked =
+      DAG.getNode(ISD::AND, DL, MVT::i32, Shamt,
+                  DAG.getConstant(VT.getSizeInBits() - 1, DL, MVT::i32));
+  SDValue ShiftRightHi =
+      DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL, DL, VT, Hi, ShamtMasked);
   SDValue Cond = DAG.getNode(ISD::AND, DL, MVT::i32, Shamt,
                              DAG.getConstant(VT.getSizeInBits(), DL, MVT::i32));
   SDValue Ext = DAG.getNode(ISD::SRA, DL, VT, Hi,

> I did not understand this clearly: do you mean we do not need to commit this code to release 17.x?

I'm suggesting a potential future optimization to turn x >> (shift & 31) into a single srlv instead of andi + srlv, since the mips variable shift instruction already implicitly masks the shift, like what is already done with x86 and wasm. Since this would be strictly an optimization, I'm saying it's not needed in the release for correct behavior.

yingopq commented 11 months ago

@jacobly0 Thanks for your latest diff. How about I submit this diff first and then look into the optimization?

jacobly0 commented 11 months ago

Yes, that sounds good to me.

yingopq commented 10 months ago

@jacobly0 Why do we add ShamtMasked? Doesn't the MIPS CPU already mask the shift amount when shift > bits?

> I'm suggesting a potential future optimization to turn x >> (shift & 31) into a single srlv instead of andi + srlv

Does this mean modifying all related .ll test files?

jacobly0 commented 10 months ago

> Why do we add ShamtMasked? Doesn't the MIPS CPU already mask the shift amount when shift > bits?

Because ISD::SRL is an llvm shift, not a mips shift, so other llvm passes assume that a shift larger than the bit size is UB and will happily delete such an instruction. Instead, you have to represent the operation that is actually happening, which is ISD::SRL of ISD::AND, and then during instruction selection you know you have a single mips instruction that performs both operations.

nikic commented 10 months ago

/cherry-pick 8d24d3900ec3f28902b2fad4a2c2c2b789257424

llvmbot commented 10 months ago

/branch llvm/llvm-project-release-prs/issue64794

llvmbot commented 10 months ago

/pull-request llvm/llvm-project-release-prs#768