Open Quuxplusone opened 3 years ago
Bugzilla Link | PR52189 |
Status | NEW |
Importance | P enhancement |
Reported by | Roman Lebedev (lebedev.ri@gmail.com) |
Reported on | 2021-10-15 14:48:38 -0700 |
Last modified on | 2021-10-19 11:36:58 -0700 |
Version | trunk |
Hardware | PC Linux |
CC | andrew.b.adams@gmail.com, craig.topper@gmail.com, llvm-bugs@lists.llvm.org, llvm-dev@redking.me.uk, pengfei.wang@intel.com, spatel+llvm@rotateright.com |
Fixed by commit(s) | |
Attachments | |
Blocks | |
Blocked by | |
See also |
It's going to be painful to verify, so I'm not sure I will want to deal with this, but here's the basic brute-force proofing framework:
$ cat Makefile
# Brute-force driver: rebuilds main.cpp once per divisor in 1..255 and runs
# the resulting binary.
all:
	@# The whole loop runs in ONE shell so that `set -e` actually stops the
	@# run on the first failing compile or aborting ./a.out (make executes
	@# each recipe line in its own shell). `rm -f` tolerates a missing a.out.
	@set -ex; \
	for D in $$(seq 1 255); do \
	  echo "D: $$D"; \
	  rm -f a.out; \
	  /builddirs/llvm-project/build-Clang13/bin/clang++ main.cpp -DTYPE=uint8_t -DDIVISOR=$$D -g0 -O3; \
	  ./a.out; \
	done
$ cat main.cpp
// Reference side of the comparison: divides x by the compile-time constant
// DIVISOR in type TYPE (both supplied via -D on the command line).
// optnone asks Clang not to optimize this function and noinline keeps the
// call from being folded into main(); presumably this yields the
// "known-good" quotient that tgt() is checked against.
__attribute__((optnone, noinline)) TYPE src(TYPE x) {
  return static_cast<TYPE>(x) / static_cast<TYPE>(DIVISOR);
}
// Side under test: the same division of x by the compile-time constant
// DIVISOR in type TYPE, but compiled with normal optimization so the
// compiler's divide-by-constant lowering is what gets exercised.
// noinline keeps the call from being folded into main().
__attribute__((noinline)) TYPE tgt(TYPE x) {
  return static_cast<TYPE>(x) / static_cast<TYPE>(DIVISOR);
}
// Exhaustively compares src() and tgt() for every x from 0 up to the all-ones
// value of TYPE, printing the first mismatch and aborting.
// NOTE(review): assumes TYPE is an unsigned type narrower than 64 bits;
// otherwise the loop bound equals UINT64_MAX and the loop never terminates.
int main() {
  // x is 64-bit so that ++x cannot wrap before the bound check fails.
  for (uint64_t x = 0; x <= (TYPE)~0ULL; ++x) {
    const TYPE origres = src(x);
    const TYPE newres = tgt(x);
    if (origres == newres)
      continue;
    // %llu with explicit unsigned long long casts: uint64_t is not
    // `unsigned long` on every platform, so %lu is a format mismatch there.
    printf("mismatch at %llu / %llu: %llu vs %llu\n",
           static_cast<unsigned long long>(x),
           static_cast<unsigned long long>(DIVISOR),
           static_cast<unsigned long long>(origres),
           static_cast<unsigned long long>(newres));
    abort();
  }
  return 0;
}
llvm-project$ git diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index fa1c3cf32dee..dfd0c772ae1f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -5410,21 +5410,38 @@ SDValue TargetLowering::BuildUDIV(SDNode *N, SelectionDAG &DAG,
Created.push_back(Q.getNode());
- if (UseNPQ) {
+ if (UseNPQ && VT.isVector()) {
SDValue NPQ = DAG.getNode(ISD::SUB, dl, VT, N0, Q);
Created.push_back(NPQ.getNode());
// For vectors we might have a mix of non-NPQ/NPQ paths, so use
// MULHU to act as a SRL-by-1 for NPQ, else multiply by zero.
- if (VT.isVector())
- NPQ = GetMULHU(NPQ, NPQFactor);
- else
- NPQ = DAG.getNode(ISD::SRL, dl, VT, NPQ, DAG.getConstant(1, dl, ShVT));
+ NPQ = GetMULHU(NPQ, NPQFactor);
Created.push_back(NPQ.getNode());
Q = DAG.getNode(ISD::ADD, dl, VT, NPQ, Q);
Created.push_back(Q.getNode());
+ } else if (UseNPQ && !VT.isVector()) {
+ EVT AvgVT = EVT::getIntegerVT(*DAG.getContext(), 2 * EltBits);
+
+ SDValue WideN0 = DAG.getNode(ISD::ZERO_EXTEND, dl, AvgVT, N0);
+ Created.push_back(WideN0.getNode());
+
+ SDValue WideQ = DAG.getNode(ISD::ZERO_EXTEND, dl, AvgVT, Q);
+ Created.push_back(WideQ.getNode());
+
+ SDValue W = DAG.getNode(ISD::ADD, dl, AvgVT, WideN0, WideQ);
+ Created.push_back(W.getNode());
+
+ W = DAG.getNode(ISD::ADD, dl, AvgVT, W, DAG.getConstant(0 , dl, AvgVT)); // <- and no longer an average
+ Created.push_back(W.getNode());
+
+ W = DAG.getNode(ISD::SRL, dl, AvgVT, W, DAG.getConstant(1, dl, AvgVT));
+ Created.push_back(W.getNode());
+
+ Q = DAG.getNode(ISD::TRUNCATE, dl, VT, W);
+ Created.push_back(Q.getNode());
}
Q = DAG.getNode(ISD::SRL, dl, VT, Q, PostShift);
Here's how I brute-force verified it for numerators up to 2^32 and denominators up to 256 (using the commented-out versions of the for loops):
https://github.com/halide/Halide/blob/master/tools/find_inverse.cpp
IIRC it took about an hour to run. An Alive proof would be nicer.
LLVM currently uses something that seems equivalent to u_method_2. This new alternative that uses two fewer x86 instructions is u_method_3. Note that the multiplier for it is one less than for u_method_2, so it's not just a matter of swapping the average-down for an average-up.