Open LebedevRI opened 3 years ago
Here's how I brute-force verified it for numerators up to 2^32 and denominators up to 256 (using the commented-out versions of the for loops):
https://github.com/halide/Halide/blob/master/tools/find_inverse.cpp
IIRC it took about an hour to run. An Alive proof would be nicer.
LLVM currently uses something that seems equivalent to u_method_2. This new alternative that uses two fewer x86 instructions is u_method_3. Note that the multiplier for it is one less than for u_method_2, so it's not just a matter of swapping the average-down for an average-up.
It's going to be painful to verify, so I'm not sure I will want to deal with this, but here's the basic brute-force proof framework:
$ cat Makefile
# Brute-force driver: for every divisor D in 1..255, rebuild main.cpp with
# TYPE=uint8_t and DIVISOR=D, then run the resulting binary.
#
# The compiler under test can be overridden: `make CLANGXX=/path/to/clang++`.
# (A separate variable is used because GNU make predefines CXX, so `CXX ?=`
# would never take effect.)
CLANGXX ?= /builddirs/llvm-project/build-Clang13/bin/clang++

.PHONY: all
all:
	@# Each recipe line runs in its own shell, so `set -ex` has to live in the
	@# same shell invocation as the loop; otherwise a failed compile or an
	@# aborting ./a.out would not stop the run.
	@set -ex; \
	for D in $$(seq 1 255); do \
		echo "D: $$D"; \
		rm -f a.out; \
		$(CLANGXX) main.cpp -DTYPE=uint8_t -DDIVISOR=$$D -g0 -O3; \
		./a.out; \
	done
$ cat main.cpp
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
// Reference implementation: casts both x and DIVISOR to TYPE and divides.
// TYPE and DIVISOR are macros supplied on the compiler command line (-D...).
// `optnone` excludes this function from optimization, so its result is the
// baseline that main() compares tgt() against; `noinline` keeps the call from
// being folded into the caller.
__attribute__((optnone, noinline)) TYPE src(TYPE x) {
return (TYPE)x / (TYPE)(DIVISOR);
}
// Version under test: the same expression as src(), but without `optnone`,
// so the compiler is free to optimize the division by the constant DIVISOR.
// `noinline` keeps it as a separate function rather than letting it be
// inlined into main().
__attribute__((noinline)) TYPE tgt(TYPE x) { return (TYPE)x / (TYPE)(DIVISOR); }
int main() {
for (uint64_t x = 0; x <= (TYPE)~0ULL; ++x) {
TYPE origres = src(x);
TYPE newres = tgt(x);
if (origres == newres)
continue;
printf("mismatch at %lu / %lu: %lu vs %lu\n", (uint64_t)x,
(uint64_t)DIVISOR, (uint64_t)origres, (uint64_t)newres);
abort();
}
return 0;
}
llvm-project$ git diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index fa1c3cf32dee..dfd0c772ae1f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -5410,21 +5410,38 @@ SDValue TargetLowering::BuildUDIV(SDNode *N, SelectionDAG &DAG,
Created.push_back(Q.getNode());
- if (UseNPQ) {
+ if (UseNPQ && VT.isVector()) {
SDValue NPQ = DAG.getNode(ISD::SUB, dl, VT, N0, Q);
Created.push_back(NPQ.getNode());
// For vectors we might have a mix of non-NPQ/NPQ paths, so use
// MULHU to act as a SRL-by-1 for NPQ, else multiply by zero.
- if (VT.isVector())
- NPQ = GetMULHU(NPQ, NPQFactor);
- else
- NPQ = DAG.getNode(ISD::SRL, dl, VT, NPQ, DAG.getConstant(1, dl, ShVT));
+ NPQ = GetMULHU(NPQ, NPQFactor);
Created.push_back(NPQ.getNode());
Q = DAG.getNode(ISD::ADD, dl, VT, NPQ, Q);
Created.push_back(Q.getNode());
+ } else if (UseNPQ && !VT.isVector()) {
+ EVT AvgVT = EVT::getIntegerVT(*DAG.getContext(), 2 * EltBits);
+
+ SDValue WideN0 = DAG.getNode(ISD::ZERO_EXTEND, dl, AvgVT, N0);
+ Created.push_back(WideN0.getNode());
+
+ SDValue WideQ = DAG.getNode(ISD::ZERO_EXTEND, dl, AvgVT, Q);
+ Created.push_back(WideQ.getNode());
+
+ SDValue W = DAG.getNode(ISD::ADD, dl, AvgVT, WideN0, WideQ);
+ Created.push_back(W.getNode());
+
+ W = DAG.getNode(ISD::ADD, dl, AvgVT, W, DAG.getConstant(0 , dl, AvgVT)); // <- and no longer an average
+ Created.push_back(W.getNode());
+
+ W = DAG.getNode(ISD::SRL, dl, AvgVT, W, DAG.getConstant(1, dl, AvgVT));
+ Created.push_back(W.getNode());
+
+ Q = DAG.getNode(ISD::TRUNCATE, dl, VT, W);
+ Created.push_back(Q.getNode());
}
Q = DAG.getNode(ISD::SRL, dl, VT, Q, PostShift);
Extended Description
As noted in https://github.com/halide/Halide/pull/6322#issuecomment-944664318. I didn't look into an Alive proof.