segmentio / asm

Go library providing algorithms optimized to leverage the characteristics of modern CPUs
MIT No Attribution

Extract a less mask and bit test rather than load/cmp bytes #21

Closed chriso closed 3 years ago

chriso commented 3 years ago

I've been searching for a "gadget" we can use to do bytewise comparisons of two vector registers. It's applicable in the sorted set routines (union/less) and will be applicable in the sorting routines too.

For union/intersect, we need to know whether a==b, a<b and a>b so we know which side to copy and which side to advance. In future, we might also like to branch based on whether a>=b and a<=b. The ideal gadget would therefore set both ZF and CF in one operation like CMP does. We'd also like it to be as succinct and readable as possible, since it'll be inlined in a few of the sorting routines. I'm not quite there yet but eventually I'd like to find a way to lean on VPTEST for this task.

Here's what we had previously — we extract an equality mask, and use it to determine the index of the first unequal byte, which we then load and compare:

result := XMM()
VPCMPEQB(Mem{Base: a}, b, result)

mask := GP32()
VPMOVMSKB(result, mask)

CMPL(mask, U32(0xFFFF))
JE(LabelRef("equal"))
NOTL(mask)
unequalByteIndex := GP32()
BSFL(mask, unequalByteIndex)
aByte := GP8()
bByte := GP8()
MOVB(Mem{Base: a, Index: unequalByteIndex, Scale: 1}, aByte)
MOVB(Mem{Base: b, Index: unequalByteIndex, Scale: 1}, bByte)
CMPB(aByte, bByte)
JB(LabelRef("less"))

Label("greater")
Label("less")
Label("equal")
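
For readers less familiar with the instructions involved, here is a scalar Go model of the load/compare approach above. It is only a sketch of the logic, not the generated assembly: `equalMask` and `compareLoadCmp` are hypothetical names, and the loop stands in for what VPCMPEQB and VPMOVMSKB do in one step each.

```go
package main

import (
	"fmt"
	"math/bits"
)

// equalMask models VPCMPEQB followed by VPMOVMSKB:
// bit i of the result is set when a[i] == b[i].
// (Hypothetical helper, for illustration only.)
func equalMask(a, b [16]byte) uint32 {
	var m uint32
	for i := range a {
		if a[i] == b[i] {
			m |= 1 << uint(i)
		}
	}
	return m
}

// compareLoadCmp models the old approach: find the first unequal
// byte via the equality mask, then load and compare that byte.
// It returns -1 if a<b, 0 if a==b, and +1 if a>b.
func compareLoadCmp(a, b [16]byte) int {
	mask := equalMask(a, b)
	if mask == 0xFFFF { // CMPL mask, 0xFFFF; JE equal
		return 0
	}
	// NOTL + BSFL: index of the lowest clear bit in the equality mask.
	i := bits.TrailingZeros32(^mask)
	if a[i] < b[i] { // MOVB, MOVB, CMPB, JB less (unsigned compare)
		return -1
	}
	return 1
}

func main() {
	a := [16]byte{1, 2, 3}
	b := [16]byte{1, 2, 4}
	fmt.Println(compareLoadCmp(a, b), compareLoadCmp(b, a), compareLoadCmp(a, a))
}
```

The two extra byte loads at the end are what the newer version avoids.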

Here's what we have now — we extract two masks (a!=b and a<b) — then find the first set bit in the first mask (which represents an unequal byte) followed by testing that bit in the second mask:

ne := XMM()
lt := XMM()
VPCMPEQB(a, b, ne)
VPXOR(ne, ones, ne)
VPMINUB(a, b, lt)
VPCMPEQB(a, lt, lt)
VPAND(lt, ne, lt)

unequalMask := GP32()
lessMask := GP32()
VPMOVMSKB(ne, unequalMask)
VPMOVMSKB(lt, lessMask)

TESTL(unequalMask, unequalMask)
JZ(LabelRef("equal"))
unequalByteIndex := GP32()
BSFL(unequalMask, unequalByteIndex)
BTSL(unequalByteIndex, lessMask)
JCS(LabelRef("less"))

Label("greater")
Label("less")
Label("equal")
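
The same idea can be modelled in scalar Go to show why testing a single bit is enough. This is a sketch under assumptions, not the real routine: `masks` and `compareBitTest` are hypothetical names, and the per-byte loop stands in for the vector instructions.

```go
package main

import (
	"fmt"
	"math/bits"
)

// masks models the vector stage. For each byte lane i:
//   ne bit i is set when a[i] != b[i]  (VPCMPEQB + VPXOR with all ones)
//   lt bit i is set when a[i] <  b[i]  (VPMINUB + VPCMPEQB gives a<=b,
//                                       VPAND with ne removes a==b)
// (Hypothetical helper, for illustration only.)
func masks(a, b [16]byte) (ne, lt uint32) {
	for i := range a {
		minAB := a[i]
		if b[i] < minAB {
			minAB = b[i]
		}
		if a[i] != b[i] {
			ne |= 1 << uint(i)
			if a[i] == minAB {
				lt |= 1 << uint(i)
			}
		}
	}
	return ne, lt
}

// compareBitTest models the scalar stage: find the first unequal
// lane, then test that lane's bit in the less-than mask.
// It returns -1 if a<b, 0 if a==b, and +1 if a>b.
func compareBitTest(a, b [16]byte) int {
	ne, lt := masks(a, b)
	if ne == 0 { // TESTL + JZ equal
		return 0
	}
	i := bits.TrailingZeros32(ne) // BSFL
	if lt&(1<<uint(i)) != 0 {     // bit test, then branch on carry
		return -1
	}
	return 1
}

func main() {
	a := [16]byte{1, 2, 3, 9}
	b := [16]byte{1, 2, 4, 0}
	fmt.Println(compareBitTest(a, b), compareBitTest(b, a), compareBitTest(a, a))
}
```

Only the first unequal lane decides the result, so later lanes where the ordering flips (the fourth byte in the example) are ignored.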

It's slower when the inputs are identical (100% overlap: about 9% slower for Intersect and 12% for Union), but faster in almost all other cases (Union at 50% overlap shows no significant change):

name                                              old time/op    new time/op    delta
Intersect/size_16,_with_0%_chance_of_overlap-4      47.0µs ± 0%    40.7µs ± 0%  -13.30%  (p=0.008 n=5+5)
Intersect/size_16,_with_10%_chance_of_overlap-4     39.4µs ± 0%    34.4µs ± 0%  -12.52%  (p=0.008 n=5+5)
Intersect/size_16,_with_50%_chance_of_overlap-4     14.1µs ± 0%    13.3µs ± 1%   -5.73%  (p=0.008 n=5+5)
Intersect/size_16,_with_100%_chance_of_overlap-4    4.74µs ± 0%    5.18µs ± 3%   +9.27%  (p=0.008 n=5+5)
Union/size_16,_with_0%_chance_of_overlap-4          44.3µs ± 0%    39.5µs ± 0%  -10.73%  (p=0.016 n=5+4)
Union/size_16,_with_10%_chance_of_overlap-4         36.8µs ± 0%    34.1µs ± 0%   -7.50%  (p=0.008 n=5+5)
Union/size_16,_with_50%_chance_of_overlap-4         13.9µs ± 4%    14.4µs ± 2%     ~     (p=0.095 n=5+5)
Union/size_16,_with_100%_chance_of_overlap-4        4.84µs ± 3%    5.42µs ± 7%  +11.92%  (p=0.008 n=5+5)

name                                              old speed      new speed      delta
Intersect/size_16,_with_0%_chance_of_overlap-4    2.79GB/s ± 0%  3.22GB/s ± 0%  +15.34%  (p=0.008 n=5+5)
Intersect/size_16,_with_10%_chance_of_overlap-4   3.33GB/s ± 0%  3.81GB/s ± 0%  +14.31%  (p=0.008 n=5+5)
Intersect/size_16,_with_50%_chance_of_overlap-4   9.27GB/s ± 0%  9.83GB/s ± 1%   +6.08%  (p=0.008 n=5+5)
Intersect/size_16,_with_100%_chance_of_overlap-4  27.7GB/s ± 0%  25.3GB/s ± 3%   -8.46%  (p=0.008 n=5+5)
Union/size_16,_with_0%_chance_of_overlap-4        2.96GB/s ± 0%  3.32GB/s ± 0%  +12.02%  (p=0.016 n=5+4)
Union/size_16,_with_10%_chance_of_overlap-4       3.56GB/s ± 0%  3.85GB/s ± 0%   +8.11%  (p=0.008 n=5+5)
Union/size_16,_with_50%_chance_of_overlap-4       9.40GB/s ± 4%  9.11GB/s ± 2%     ~     (p=0.095 n=5+5)
Union/size_16,_with_100%_chance_of_overlap-4      27.1GB/s ± 3%  24.2GB/s ± 7%  -10.51%  (p=0.008 n=5+5)

chriso commented 3 years ago

I've tried exploring branchless versions, but they're always slower because there's no CMOV for an XMM register; you need to CMOV the pointer and load again:

diff --git i/build/sortedset/union16_asm.go w/build/sortedset/union16_asm.go
index 3085cea..ab43a17 100644
--- i/build/sortedset/union16_asm.go
+++ w/build/sortedset/union16_asm.go
@@ -33,6 +33,9 @@ func main() {

    Label("loop")

+   src := GP64()
+   MOVQ(a, src)
+
    // Compare bytes and extract two masks.
    // ne = mask of bytes where a!=b
    // lt = mask of bytes where a<b
@@ -54,26 +57,34 @@ func main() {
    unequalByteIndex := GP32()
    BSFL(unequalMask, unequalByteIndex)
    BTSL(unequalByteIndex, lessMask)
-   JCS(LabelRef("less"))

-   // If b>a, copy and advance a.
-   Label("greater")
-   VMOVUPS(bItem, Mem{Base: dst})
-   ADDQ(Imm(16), dst)
-   ADDQ(Imm(16), b)
-   CMPQ(b, bEnd)
-   JE(LabelRef("done"))
-   VMOVUPS(Mem{Base: b}, bItem)
-   JMP(LabelRef("loop"))
+   // Write from either a or b depending on which is less.
+   CMOVQCC(b, src)
+   item := XMM()
+   VMOVUPS(Mem{Base: src}, item)
+   VMOVUPS(item, Mem{Base: dst})

-   // If a<b, copy and advance a.
-   Label("less")
-   VMOVUPS(aItem, Mem{Base: dst})
+   // Conditionally advance either A or B.
+   advanceMask := GP64()
+   SBBQ(advanceMask, advanceMask) // advanceMask has CF in all bits
+   advanceA := GP64()
+   advanceB := GP64()
+   MOVQ(U32(16), advanceA)
+   MOVQ(U32(16), advanceB)
+   ANDQ(advanceMask, advanceA)
+   NOTQ(advanceMask)
+   ANDQ(advanceMask, advanceB)
+
+   // Advance pointers and loop.
+   ADDQ(advanceA, a)
+   ADDQ(advanceB, b)
    ADDQ(Imm(16), dst)
-   ADDQ(Imm(16), a)
    CMPQ(a, aEnd)
    JE(LabelRef("done"))
+   CMPQ(b, bEnd)
+   JE(LabelRef("done"))
    VMOVUPS(Mem{Base: a}, aItem)
+   VMOVUPS(Mem{Base: b}, bItem)
    JMP(LabelRef("loop"))

    // If a==b, copy either and advance both.
diff --git i/sortedset/union16_amd64.s w/sortedset/union16_amd64.s
index 185f2b7..cb17691 100644
--- i/sortedset/union16_amd64.s
+++ w/sortedset/union16_amd64.s
@@ -3,7 +3,7 @@
 #include "textflag.h"

 // func union16(dst []byte, a []byte, b []byte) (i int, j int, k int)
-// Requires: AVX
+// Requires: AVX, CMOV
 TEXT ·union16(SB), NOSPLIT, $0-96
    MOVQ     dst_base+0(FP), AX
    MOVQ     a_base+24(FP), CX
@@ -17,35 +17,38 @@ TEXT ·union16(SB), NOSPLIT, $0-96
    VMOVUPS  (DX), X2

 loop:
+   MOVQ      CX, DI
    VPCMPEQB  X1, X2, X3
    VPXOR     X3, X0, X3
-   VPMINUB   X1, X2, X4
-   VPCMPEQB  X1, X4, X4
-   VPAND     X4, X3, X4
-   VPMOVMSKB X3, DI
-   VPMOVMSKB X4, R8
-   CMPL      DI, $0x00000000
+   VPMINUB   X1, X2, X2
+   VPCMPEQB  X1, X2, X2
+   VPAND     X2, X3, X2
+   VPMOVMSKB X3, R8
+   VPMOVMSKB X2, R9
+   CMPL      R8, $0x00000000
    JE        equal
-   BSFL      DI, R9
-   BTSL      R9, R8
-   JCS       less
-   VMOVUPS   X2, (AX)
+   BSFL      R8, R10
+   BTSL      R10, R9
+   CMOVQCC   DX, DI
+   VMOVUPS   (DI), X1
+   VMOVUPS   X1, (AX)
+   SBBQ      DI, DI
+   MOVQ      $0x00000010, R8
+   MOVQ      $0x00000010, R9
+   ANDQ      DI, R8
+   NOTQ      DI
+   ANDQ      DI, R9
+   ADDQ      R8, CX
+   ADDQ      R9, DX
    ADDQ      $0x10, AX
-   ADDQ      $0x10, DX
+   CMPQ      CX, BX
+   JE        done
    CMPQ      DX, SI
    JE        done
+   VMOVUPS   (CX), X1
    VMOVUPS   (DX), X2
    JMP       loop

-less:
-   VMOVUPS X1, (AX)
-   ADDQ    $0x10, AX
-   ADDQ    $0x10, CX
-   CMPQ    CX, BX
-   JE      done
-   VMOVUPS (CX), X1
-   JMP     loop
-
 equal:
    VMOVUPS X1, (AX)
    ADDQ    $0x10, AX
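
Compared with the branching bit-test version, the branchless version is 20% to 55% slower across the board:

name                                          old time/op    new time/op    delta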
Union/size_16,_with_0%_chance_of_overlap-4      39.5µs ± 0%    47.6µs ± 0%  +20.48%  (p=0.016 n=4+5)
Union/size_16,_with_10%_chance_of_overlap-4     34.1µs ± 0%    41.6µs ± 1%  +22.22%  (p=0.008 n=5+5)
Union/size_16,_with_50%_chance_of_overlap-4     14.4µs ± 2%    22.3µs ± 1%  +54.87%  (p=0.008 n=5+5)
Union/size_16,_with_100%_chance_of_overlap-4    5.42µs ± 7%    7.10µs ± 0%  +31.12%  (p=0.008 n=5+5)

name                                          old speed      new speed      delta
Union/size_16,_with_0%_chance_of_overlap-4    3.32GB/s ± 0%  2.75GB/s ± 0%  -17.00%  (p=0.016 n=4+5)
Union/size_16,_with_10%_chance_of_overlap-4   3.85GB/s ± 0%  3.15GB/s ± 1%  -18.18%  (p=0.008 n=5+5)
Union/size_16,_with_50%_chance_of_overlap-4   9.11GB/s ± 2%  5.88GB/s ± 1%  -35.43%  (p=0.008 n=5+5)
Union/size_16,_with_100%_chance_of_overlap-4  24.2GB/s ± 7%  18.5GB/s ± 0%  -23.88%  (p=0.008 n=5+5)
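
The conditional-advance part of the diff may be easier to follow as scalar Go. This sketch models only the SBBQ/ANDQ/NOTQ arithmetic; `advances` is a hypothetical name, and the `carry` parameter stands in for CF after the bit test (set when a<b).

```go
package main

import "fmt"

// advances models the branchless pointer update from the diff.
// carry plays the role of CF after the bit test: true when a<b.
// It returns how many bytes to add to the a and b pointers.
// (Hypothetical helper, for illustration only.)
func advances(carry bool) (advA, advB uint64) {
	// SBBQ r, r computes r - r - CF, which is all ones when CF=1
	// and zero when CF=0.
	var mask uint64
	if carry {
		mask = ^uint64(0)
	}
	advA = 16 & mask  // MOVQ $16 + ANDQ mask: advance a only when a<b
	advB = 16 & ^mask // NOTQ mask + ANDQ: advance b only when a>b
	return advA, advB
}

func main() {
	fmt.Println(advances(true))  // a<b: advance a by 16, b by 0
	fmt.Println(advances(false)) // a>b: advance a by 0, b by 16
}
```

Exactly one of the two pointers moves per iteration, but the loop now reloads both items and checks both end conditions every time, which is consistent with the slowdown measured above.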