dotnet / runtime

.NET is a cross-platform runtime for cloud, mobile, desktop, and IoT apps.
https://docs.microsoft.com/dotnet/core/
MIT License

Improve Vector128.ExtractMostSignificantBits for arm64 #76047

Open EgorBo opened 1 year ago

EgorBo commented 1 year ago

Per discussion with @tannergooding on Discord

Vector128.ExtractMostSignificantBits is an important API: it is commonly combined with comparisons and TrailingZeroCount/LeadingZeroCount to find the position of an element in a vector, as in various IndexOf-like algorithms. Example:

using System;
using System.Numerics;
using System.Runtime.Intrinsics;

static void PrintPosition()
{
    Vector128<byte> src = Vector128.Create((byte)0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
    Vector128<byte> val = Vector128.Create((byte)42);

    // prints 3, the index of 42 in src
    Console.WriteLine(FirstMatch(src, val));
}

static int FirstMatch(Vector128<byte> src, Vector128<byte> val)
{
    Vector128<byte> eq = Vector128.Equals(src, val);
    return BitOperations.TrailingZeroCount(eq.ExtractMostSignificantBits());
}
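To make the semantics concrete, here is a scalar model of what ExtractMostSignificantBits returns for Vector128<byte>: one bit per lane, taken from that lane's most significant bit. MoveMaskScalar is a hypothetical helper for illustration only (not how the JIT implements the API); the sketch assumes .NET 7 or later and is written as top-level statements.

```csharp
using System;
using System.Numerics;
using System.Runtime.Intrinsics;

// Hypothetical scalar model of Vector128<byte>.ExtractMostSignificantBits:
// bit i of the result is the most significant bit of byte lane i.
static uint MoveMaskScalar(Vector128<byte> v)
{
    uint result = 0;
    for (int i = 0; i < Vector128<byte>.Count; i++)
    {
        result |= (uint)(v.GetElement(i) >> 7) << i;
    }
    return result;
}

Vector128<byte> src = Vector128.Create((byte)0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
Vector128<byte> eq = Vector128.Equals(src, Vector128.Create((byte)42));

uint mask = MoveMaskScalar(eq);
// Only lane 3 compared equal, so only bit 3 is set.
Console.WriteLine(mask);                                  // 8
Console.WriteLine(BitOperations.TrailingZeroCount(mask)); // 3
```

On x64 this whole loop is a single vpmovmskb, which is why the x64 codegen below is so short.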

Codegen for FirstMatch on x64:

       vpcmpeqb  xmm0, xmm0, xmm1
       vpmovmskb eax, xmm0
       tzcnt     eax, eax

Codegen for FirstMatch on arm64:

            cmeq    v16.16b, v0.16b, v1.16b
            ldr     q17, [@RWD00]
            and     v16.16b, v16.16b, v17.16b
            ldr     q17, [@RWD16]
            ushl    v16.16b, v16.16b, v17.16b
            movi    v17.4s, #0
            ext     v17.16b, v16.16b, v17.16b, #8
            addv    b17, v17.8b
            umov    w0, v17.b[0]
            lsl     w0, w0, #8
            addv    b16, v16.8b
            umov    w1, v16.b[0]
            orr     w0, w0, w1
            rbit    w0, w0
            clz     w0, w0
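The arm64 sequence above emulates movmsk in four steps: AND with 0x80, a per-lane USHL that moves each sign bit to position i % 8, an ADDV over each 8-byte half, and a final shift/OR. The scalar sketch below mirrors those steps; the shift counts (-7 through 0) are read from the CNS_VEC constants in the importer IR shown further down in this thread. MoveMaskViaAddAcross is a hypothetical name, and this is a model for reasoning about the instruction sequence, not JIT code.

```csharp
using System;
using System.Runtime.Intrinsics;

// Hypothetical scalar model of the current arm64 expansion of ExtractMostSignificantBits.
static uint MoveMaskViaAddAcross(Vector128<byte> v)
{
    uint low = 0, high = 0;
    for (int i = 0; i < Vector128<byte>.Count; i++)
    {
        // Steps 1 and 2 (and + ushl): isolate the sign bit, then shift it right
        // by 7 - (i % 8) so it lands at bit position i % 8.
        uint bit = (uint)((v.GetElement(i) & 0x80) >> (7 - (i % 8)));
        // Step 3 (addv per half): the bits never overlap, so the sum equals an OR.
        if (i < 8) low += bit; else high += bit;
    }
    // Step 4 (lsl #8 + orr): the upper half supplies bits 8..15.
    return low | (high << 8);
}

Console.WriteLine(MoveMaskViaAddAcross(Vector128<byte>.AllBitsSet)); // 65535
```

Two constant loads, two horizontal adds and two cross-register moves are the main cost compared with the single x64 instruction.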

This is because arm64 has no direct equivalent of movmsk. However, this particular case can be optimized: we know that the input of ExtractMostSignificantBits is the result of a comparison, so every element is either zero or all-bits-set. In the best case we can use this trick from the Arm blog post by @danlark1 (https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon):

            cmeq    v16.16b, v0.16b, v1.16b
            shrn    v16.8b, v16.8h, #4
            umov    x0, v16.d[0]
            rbit    x0, x0
            clz     x0, x0
            asr     w0, w0, #2
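Why does this work? Because every byte of a comparison result is 0x00 or 0xFF, shrn #4 (shift each 16-bit lane right by 4 and keep the low 8 bits) turns byte lane j into the 4-bit nibble at bits 4j..4j+3 of a 64-bit value, so the trailing-zero count divided by 4 (asr #2) is the lane index. The sketch below models shrn in scalar code (NibbleMask, ComparisonResult and CountMismatches are hypothetical helpers) and checks the claim exhaustively over all 65,535 non-empty sets of matching lanes.

```csharp
using System;
using System.Numerics;
using System.Runtime.Intrinsics;

// Scalar model of "shrn v.8b, v.8h, #4" + "umov x0, v.d[0]": view the 16 bytes as eight
// little-endian ushort lanes, shift each right by 4, keep the low 8 bits, pack into 64 bits.
static ulong NibbleMask(Vector128<byte> eq)
{
    Vector128<ushort> lanes = eq.AsUInt16();
    ulong result = 0;
    for (int i = 0; i < Vector128<ushort>.Count; i++)
        result |= (ulong)((lanes.GetElement(i) >> 4) & 0xFF) << (8 * i);
    return result;
}

// Builds a comparison-shaped vector: lane i is 0xFF if bit i of laneMask is set, else 0x00.
static Vector128<byte> ComparisonResult(int laneMask)
{
    var bytes = new byte[16];
    for (int i = 0; i < bytes.Length; i++)
        bytes[i] = ((laneMask >> i) & 1) != 0 ? (byte)0xFF : (byte)0x00;
    return Vector128.Create(bytes);
}

// Compares tzcnt(movmsk) with tzcnt(nibble mask) >> 2 for every non-empty
// set of matching lanes and returns the number of disagreements.
static int CountMismatches()
{
    int mismatches = 0;
    for (int laneMask = 1; laneMask < (1 << 16); laneMask++)
    {
        Vector128<byte> eq = ComparisonResult(laneMask);
        int viaMoveMask = BitOperations.TrailingZeroCount(eq.ExtractMostSignificantBits());
        int viaNibbles = BitOperations.TrailingZeroCount(NibbleMask(eq)) >> 2;
        if (viaMoveMask != viaNibbles) mismatches++;
    }
    return mismatches;
}

Console.WriteLine(NibbleMask(ComparisonResult(1 << 3)).ToString("X")); // F000 (lane 3 -> bits 12..15)
Console.WriteLine(CountMismatches());                                  // 0
```

The loop deliberately skips laneMask == 0. When no lane matches, the two forms differ: TrailingZeroCount of a zero uint is 32, while TrailingZeroCount of a zero ulong shifted right by 2 is 16, so callers that rely on the "not found" value need to handle that case explicitly.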

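Its C# equivalent:

using System.Numerics;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.Arm;

// arm64 only: requires AdvSimd.IsSupported
static int FirstMatch(Vector128<byte> src, Vector128<byte> val)
{
    Vector128<byte> eq = Vector128.Equals(src, val);
    // shrn #4: each byte of eq becomes one 4-bit nibble of a 64-bit mask
    ulong matches = AdvSimd.ShiftRightLogicalNarrowingLower(eq.AsUInt16(), 4).AsUInt64().ToScalar();
    // 4 bits per lane, so divide the bit index by 4
    return BitOperations.TrailingZeroCount(matches) >> 2;
}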
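The AdvSimd form only runs where AdvSimd.IsSupported is true. Below is a hedged sketch of a portable wrapper that uses the nibble trick on arm64 and falls back to ExtractMostSignificantBits elsewhere. FirstMatchIndex and its -1 "not found" convention are assumptions of this sketch, chosen because the two raw forms return different values (32 vs 16) when nothing matches.

```csharp
using System;
using System.Numerics;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.Arm;

// Hypothetical helper: index of the first lane of src equal to val, or -1 if none.
static int FirstMatchIndex(Vector128<byte> src, Vector128<byte> val)
{
    Vector128<byte> eq = Vector128.Equals(src, val);
    if (eq == Vector128<byte>.Zero)
        return -1; // normalize the "no match" case, where the two forms below disagree

    if (AdvSimd.IsSupported)
    {
        // arm64: one nibble per byte lane.
        ulong nibbles = AdvSimd.ShiftRightLogicalNarrowingLower(eq.AsUInt16(), 4).AsUInt64().ToScalar();
        return BitOperations.TrailingZeroCount(nibbles) >> 2;
    }

    // Other platforms: one bit per byte lane (vpmovmskb on x64).
    return BitOperations.TrailingZeroCount(eq.ExtractMostSignificantBits());
}

Console.WriteLine(FirstMatchIndex(
    Vector128.Create((byte)0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
    Vector128.Create((byte)42))); // 3
```

The JIT treats AdvSimd.IsSupported as a constant, so only one branch survives in the generated code, and the same source runs on x64 and arm64.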

Performance impact

We expect a noticeable improvement for small inputs, such as those typically seen in HTTP parsing, where we have to find the positions of symbols like :, \n, etc. For large inputs, most of our algorithms already use fast compare == Vector128<>.Zero checks to skip chunks without matches, so the mask is only extracted for the chunk that contains one.
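A simplified sketch of that pattern, assuming .NET 7+ APIs: IndexOfByte is a hypothetical helper, not the actual BCL implementation. It skips match-free 16-byte chunks with the cheap == Zero check and only reaches TrailingZeroCount(ExtractMostSignificantBits()) once, for the chunk holding the first match.

```csharp
using System;
using System.Numerics;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;

// Hypothetical, simplified IndexOf(byte) illustrating the "== Zero" skip.
static int IndexOfByte(ReadOnlySpan<byte> span, byte value)
{
    ref byte start = ref MemoryMarshal.GetReference(span);
    Vector128<byte> target = Vector128.Create(value);
    int i = 0;

    for (; i + Vector128<byte>.Count <= span.Length; i += Vector128<byte>.Count)
    {
        Vector128<byte> eq = Vector128.Equals(Vector128.LoadUnsafe(ref start, (nuint)i), target);
        if (eq == Vector128<byte>.Zero)
            continue; // no match in this chunk: no mask extraction at all

        // Reached at most once: the pattern this issue wants to make cheaper on arm64.
        return i + BitOperations.TrailingZeroCount(eq.ExtractMostSignificantBits());
    }

    // Scalar tail for the remaining (fewer than 16) bytes.
    for (; i < span.Length; i++)
    {
        if (span[i] == value)
            return i;
    }
    return -1;
}

Console.WriteLine(IndexOfByte("Host: 127.0.0.1:5001"u8, (byte)':')); // 4
```

For short inputs such as individual header lines, the first chunk often contains the match, so the cost of mask extraction is paid on nearly every call; that is why the small-input case is the one expected to benefit.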

category:cq theme:vector-codegen skill-level:intermediate cost:small impact:small

ghost commented 1 year ago

Tagging subscribers to this area: @JulieLeeMSFT, @jakobbotsch. See info in area-owners.md if you want to be subscribed.

Author: EgorBo
Assignees: -
Labels: `area-CodeGen-coreclr`, `untriaged`
Milestone: -
EgorBo commented 1 year ago

Benchmark:

using System;
using System.Text;
using BenchmarkDotNet.Attributes;

// Class name is illustrative; the original snippet showed only the members.
public class HttpHeaderBenchmarks
{
    private static readonly byte[] httpHeader = Encoding.UTF8.GetBytes(
"""
Host: 127.0.0.1:5001
Connection: keep-alive
Cache-Control: max-age=0
sec-ch-ua-mobile: ?0
sec-ch-ua-platform: "Windows"
Upgrade-Insecure-Requests: 1
Sec-Fetch-Site: none
Sec-Fetch-Mode: navigate
Sec-Fetch-User: ?1
Sec-Fetch-Dest: document
Accept-Encoding: gzip, deflate, br
Accept-Language: en-US,en;q=0.9
""");

    [Benchmark]
    public int CountHeaders()
    {
        ReadOnlySpan<byte> span = httpHeader.AsSpan();
        int newline = 0;
        int count = 0;
        while (newline != -1 && span.Length > newline)
        {
            span = span.Slice(newline + 1);
            newline = span.IndexOfAny((byte)'\n', (byte)':'); // or just IndexOf((byte)'\n')
            count++;
        }
        return count;
    }
}
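| Method       | Job        | Toolchain               | Mean     | Error   | StdDev  | Ratio |
|------------- |----------- |------------------------ |---------:|--------:|--------:|------:|
| CountHeaders | Job-AWWEJK | /Core_Root/corerun      | 193.7 ns | 2.58 ns | 2.41 ns |  0.80 |
| CountHeaders | Job-DZXKTO | /Core_Root_base/corerun | 242.2 ns | 1.11 ns | 1.04 ns |  1.00 |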
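For reference, the Ratio column works out as follows, taking /Core_Root_base/corerun as the baseline (its Ratio is 1.00):

```latex
\text{Ratio} = \frac{193.7\,\text{ns}}{242.2\,\text{ns}} \approx 0.80,
\qquad
\text{speedup} = \frac{242.2}{193.7} \approx 1.25\times,
\qquad
\frac{242.2 - 193.7}{242.2} = \frac{48.5}{242.2} \approx 20\%\ \text{less time per call}
```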
EgorBo commented 1 year ago

So the task here is to optimize TrailingZeroCount(comparison.ExtractMostSignificantBits()) in:

static int FirstMatch_old(Vector128<byte> src, Vector128<byte> val)
{
    Vector128<byte> eq = Vector128.Equals(src, val);
    return BitOperations.TrailingZeroCount(eq.ExtractMostSignificantBits());
}

to emit the same codegen as this function does:

static int FirstMatch_new(Vector128<byte> src, Vector128<byte> val)
{
    Vector128<byte> eq = Vector128.Equals(src, val);
    ulong matches = AdvSimd.ShiftRightLogicalNarrowingLower(eq.AsUInt16(), 4).AsUInt64().ToScalar();
    return BitOperations.TrailingZeroCount(matches) >> 2;
}

The very first step is to move the ExtractMostSignificantBits expansion from the importer to Lower. Currently, ExtractMostSignificantBits on ARM64 is imported as the following IR:

STMT00000 ( 0x000[E-] ... ??? )
               [000006] -A---------                         *  ASG       simd16 (copy)
               [000005] D------N---                         +--*  LCL_VAR   simd16<System.Runtime.Intrinsics.Vector128`1> V03 tmp1         
               [000004] -----------                         \--*  HWINTRINSIC simd16 ubyte ShiftLogical
               [000003] -----------                            +--*  HWINTRINSIC simd16 ubyte And
               [000000] -----------                            |  +--*  LCL_VAR   simd16<System.Runtime.Intrinsics.Vector128`1> V00 arg0         
               [000001] -----------                            |  \--*  CNS_VEC   simd16<0x80808080, 0x80808080, 0x80808080, 0x80808080>
               [000002] -----------                            \--*  CNS_VEC   simd16<0xfcfbfaf9, 0x00fffefd, 0xfcfbfaf9, 0x00fffefd>

    [ 1]   6 (0x006) ret

STMT00001 ( ??? ... ??? )
               [000023] -----------                         *  RETURN    int   
               [000022] -----------                         \--*  OR        int   
               [000012] ---------U-                            +--*  CAST      int <- uint
               [000011] -----------                            |  \--*  HWINTRINSIC ubyte  ubyte ToScalar
               [000010] -----------                            |     \--*  HWINTRINSIC simd8  ubyte AddAcross
               [000009] -----------                            |        \--*  HWINTRINSIC simd8  ubyte GetLower
               [000008] -----------                            |           \--*  LCL_VAR   simd16<System.Runtime.Intrinsics.Vector128`1> V03 tmp1         
               [000021] -----------                            \--*  LSH       int   
               [000019] ---------U-                               +--*  CAST      int <- uint
               [000018] -----------                               |  \--*  HWINTRINSIC ubyte  ubyte ToScalar
               [000017] -----------                               |     \--*  HWINTRINSIC simd8  ubyte AddAcross
               [000016] -----------                               |        \--*  HWINTRINSIC simd8  ubyte GetLower
               [000015] -----------                               |           \--*  HWINTRINSIC simd16 ubyte ExtractVector128
               [000007] -----------                               |              +--*  LCL_VAR   simd16<System.Runtime.Intrinsics.Vector128`1> V03 tmp1         
               [000013] -----------                               |              +--*  CNS_VEC   simd16<0x00000000, 0x00000000, 0x00000000, 0x00000000>
               [000014] -----------                               |              \--*  CNS_INT   int    8
               [000020] -----------                               \--*  CNS_INT   int    8

Instead, it should be imported as-is, as a single HWINTRINSIC simd16 int ExtractMostSignificantBits node, and expanded into the tree above only in Lower. That would let us perform optimizations in VN on a single node instead of this huge tree. Here is where it's currently expanded: https://github.com/dotnet/runtime/blob/91df18483b2946bed8fed09ed886f888d165a005/src/coreclr/jit/hwintrinsicarm64.cpp#L943-L1122; that code should be moved to Lower. The next step is then simply to recognize, in VN, the final shape we want to optimize.

cc @JulieLeeMSFT