chriso closed this 3 years ago
I've tried exploring branchless versions, but it's always slower because there's no CMOV for an xmm register; you need to CMOV the pointer and load again:
diff --git i/build/sortedset/union16_asm.go w/build/sortedset/union16_asm.go
index 3085cea..ab43a17 100644
--- i/build/sortedset/union16_asm.go
+++ w/build/sortedset/union16_asm.go
@@ -33,6 +33,9 @@ func main() {
Label("loop")
+ src := GP64()
+ MOVQ(a, src)
+
// Compare bytes and extract two masks.
// ne = mask of bytes where a!=b
// lt = mask of bytes where a<b
@@ -54,26 +57,34 @@ func main() {
unequalByteIndex := GP32()
BSFL(unequalMask, unequalByteIndex)
BTSL(unequalByteIndex, lessMask)
- JCS(LabelRef("less"))
- // If b>a, copy and advance a.
- Label("greater")
- VMOVUPS(bItem, Mem{Base: dst})
- ADDQ(Imm(16), dst)
- ADDQ(Imm(16), b)
- CMPQ(b, bEnd)
- JE(LabelRef("done"))
- VMOVUPS(Mem{Base: b}, bItem)
- JMP(LabelRef("loop"))
+ // Write from either a or b depending on which is less.
+ CMOVQCC(b, src)
+ item := XMM()
+ VMOVUPS(Mem{Base: src}, item)
+ VMOVUPS(item, Mem{Base: dst})
- // If a<b, copy and advance a.
- Label("less")
- VMOVUPS(aItem, Mem{Base: dst})
+ // Conditionally advance either A or B.
+ advanceMask := GP64()
+ SBBQ(advanceMask, advanceMask) // advanceMask has CF in all bits
+ advanceA := GP64()
+ advanceB := GP64()
+ MOVQ(U32(16), advanceA)
+ MOVQ(U32(16), advanceB)
+ ANDQ(advanceMask, advanceA)
+ NOTQ(advanceMask)
+ ANDQ(advanceMask, advanceB)
+
+ // Advance pointers and loop.
+ ADDQ(advanceA, a)
+ ADDQ(advanceB, b)
ADDQ(Imm(16), dst)
- ADDQ(Imm(16), a)
CMPQ(a, aEnd)
JE(LabelRef("done"))
+ CMPQ(b, bEnd)
+ JE(LabelRef("done"))
VMOVUPS(Mem{Base: a}, aItem)
+ VMOVUPS(Mem{Base: b}, bItem)
JMP(LabelRef("loop"))
// If a==b, copy either and advance both.
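To make the branchless step above easier to follow, here's a scalar sketch of what it computes (the `step` helper is hypothetical, not part of the patch): `CMOVQCC` selects the source pointer based on the carry flag left by `BTSL`, and the `SBBQ`/`ANDQ` sequence turns that same carry into per-pointer advance amounts.

```go
package main

import "fmt"

// step models one loop iteration. cf is the carry flag after BTSL:
// set when the first unequal byte of a is less than that of b.
// It returns which pointer to copy from and how far to advance each side.
func step(aPtr, bPtr uintptr, cf bool) (src, advA, advB uintptr) {
	src = aPtr
	if !cf { // CMOVQCC: carry clear means b < a, so copy from b instead
		src = bPtr
	}
	// SBBQ mask, mask: mask is all ones when CF is set, else zero.
	var mask uintptr
	if cf {
		mask = ^uintptr(0)
	}
	advA = 16 & mask  // ANDQ: advance a only when a < b
	advB = 16 &^ mask // NOTQ then ANDQ: advance b only when b < a
	return src, advA, advB
}

func main() {
	src, advA, advB := step(0x1000, 0x2000, true)
	fmt.Printf("src=%#x advA=%d advB=%d\n", src, advA, advB)
	// → src=0x1000 advA=16 advB=0
}
```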
diff --git i/sortedset/union16_amd64.s w/sortedset/union16_amd64.s
index 185f2b7..cb17691 100644
--- i/sortedset/union16_amd64.s
+++ w/sortedset/union16_amd64.s
@@ -3,7 +3,7 @@
#include "textflag.h"
// func union16(dst []byte, a []byte, b []byte) (i int, j int, k int)
-// Requires: AVX
+// Requires: AVX, CMOV
TEXT ·union16(SB), NOSPLIT, $0-96
MOVQ dst_base+0(FP), AX
MOVQ a_base+24(FP), CX
@@ -17,35 +17,38 @@ TEXT ·union16(SB), NOSPLIT, $0-96
VMOVUPS (DX), X2
loop:
+ MOVQ CX, DI
VPCMPEQB X1, X2, X3
VPXOR X3, X0, X3
- VPMINUB X1, X2, X4
- VPCMPEQB X1, X4, X4
- VPAND X4, X3, X4
- VPMOVMSKB X3, DI
- VPMOVMSKB X4, R8
- CMPL DI, $0x00000000
+ VPMINUB X1, X2, X2
+ VPCMPEQB X1, X2, X2
+ VPAND X2, X3, X2
+ VPMOVMSKB X3, R8
+ VPMOVMSKB X2, R9
+ CMPL R8, $0x00000000
JE equal
- BSFL DI, R9
- BTSL R9, R8
- JCS less
- VMOVUPS X2, (AX)
+ BSFL R8, R10
+ BTSL R10, R9
+ CMOVQCC DX, DI
+ VMOVUPS (DI), X1
+ VMOVUPS X1, (AX)
+ SBBQ DI, DI
+ MOVQ $0x00000010, R8
+ MOVQ $0x00000010, R9
+ ANDQ DI, R8
+ NOTQ DI
+ ANDQ DI, R9
+ ADDQ R8, CX
+ ADDQ R9, DX
ADDQ $0x10, AX
- ADDQ $0x10, DX
+ CMPQ CX, BX
+ JE done
CMPQ DX, SI
JE done
+ VMOVUPS (CX), X1
VMOVUPS (DX), X2
JMP loop
-less:
- VMOVUPS X1, (AX)
- ADDQ $0x10, AX
- ADDQ $0x10, CX
- CMPQ CX, BX
- JE done
- VMOVUPS (CX), X1
- JMP loop
-
equal:
VMOVUPS X1, (AX)
ADDQ $0x10, AX
name old time/op new time/op delta
Union/size_16,_with_0%_chance_of_overlap-4 39.5µs ± 0% 47.6µs ± 0% +20.48% (p=0.016 n=4+5)
Union/size_16,_with_10%_chance_of_overlap-4 34.1µs ± 0% 41.6µs ± 1% +22.22% (p=0.008 n=5+5)
Union/size_16,_with_50%_chance_of_overlap-4 14.4µs ± 2% 22.3µs ± 1% +54.87% (p=0.008 n=5+5)
Union/size_16,_with_100%_chance_of_overlap-4 5.42µs ± 7% 7.10µs ± 0% +31.12% (p=0.008 n=5+5)
name old speed new speed delta
Union/size_16,_with_0%_chance_of_overlap-4 3.32GB/s ± 0% 2.75GB/s ± 0% -17.00% (p=0.016 n=4+5)
Union/size_16,_with_10%_chance_of_overlap-4 3.85GB/s ± 0% 3.15GB/s ± 1% -18.18% (p=0.008 n=5+5)
Union/size_16,_with_50%_chance_of_overlap-4 9.11GB/s ± 2% 5.88GB/s ± 1% -35.43% (p=0.008 n=5+5)
Union/size_16,_with_100%_chance_of_overlap-4 24.2GB/s ± 7% 18.5GB/s ± 0% -23.88% (p=0.008 n=5+5)
I've been searching for a "gadget" we can use to do bytewise comparisons of two vector registers. It's applicable in the sorted set routines (union/less) and will be applicable in the sorting routines too.
For union/intersect, we need to know whether `a==b`, `a<b`, or `a>b`, so we know which side to copy and which side to advance. In future, we might also like to branch based on whether `a>=b` and `a<=b`. The ideal gadget would therefore set both ZF and CF in one operation like `CMP` does. We'd also like it to be as succinct and readable as possible, since it'll be inlined in a few of the sorting routines. I'm not quite there yet, but eventually I'd like to find a way to lean on `VPTEST` for this task.

Here's what we had previously — we extract an equality mask, and use it to determine the index of the first unequal byte, which we then load and compare.

Here's what we have now — we extract two masks (`a!=b` and `a<b`), then find the first set bit in the first mask (which marks the first unequal byte) and test that bit in the second mask.

It's not better in the case where the inputs are the same, but it's better in almost all other cases.
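The two-mask gadget can be sketched in plain Go (this is an illustrative model, not the vectorized code): the loop stands in for `VPCMPEQB`/`VPMINUB`/`VPMOVMSKB`, `bits.TrailingZeros16` plays the role of `BSFL`, and the bit test mirrors `BTSL`.

```go
package main

import (
	"fmt"
	"math/bits"
)

// compare returns -1, 0, or +1 for the bytewise order of two 16-byte
// items, using the ne/lt two-mask technique. Bit i of each mask
// corresponds to byte i, matching what VPMOVMSKB would produce.
func compare(a, b *[16]byte) int {
	var ne, lt uint16
	for i := 0; i < 16; i++ {
		if a[i] != b[i] {
			ne |= 1 << i // a!=b mask
		}
		if a[i] < b[i] {
			lt |= 1 << i // a<b mask
		}
	}
	if ne == 0 {
		return 0 // a == b: ZF-like outcome
	}
	// BSFL: index of the first (lowest-address) unequal byte.
	i := bits.TrailingZeros16(ne)
	// BTSL: test that bit in the lt mask; the result is CF-like.
	if lt&(1<<i) != 0 {
		return -1 // a < b
	}
	return 1 // a > b
}

func main() {
	a := &[16]byte{0: 1, 5: 9}
	b := &[16]byte{0: 1, 5: 7}
	fmt.Println(compare(a, a), compare(b, a), compare(a, b))
	// → 0 -1 1
}
```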