dotnet / runtime

.NET is a cross-platform runtime for cloud, mobile, desktop, and IoT apps.
https://docs.microsoft.com/dotnet/core/
MIT License

JIT: Add support for strength reduction #100913

Open jakobbotsch opened 3 months ago

jakobbotsch commented 3 months ago

Now that we have an SSA-based induction variable (IV) analysis (added in #97865), we should implement strength reduction on top of it. Example loop:

[MethodImpl(MethodImplOptions.NoInlining)]
private static int StrengthReduce(Span<int> s)
{
    int sum = 0;
    foreach (int val in s)
        sum += val;

    return sum;
}

Codegen x64:

       xor      r8d, r8d
       test     ecx, ecx
       jle      SHORT G_M11380_IG04
       align    [0 bytes for IG03]
                        ;; size=15 bbWeight=1 PerfScore 5.75

G_M11380_IG03:  ;; offset=0x0013
       add      eax, dword ptr [rdx+4*r8]
       inc      r8d
       cmp      r8d, ecx
       jl       SHORT G_M11380_IG03
                        ;; size=12 bbWeight=4 PerfScore 18.00

Codegen arm64:

            mov     w3, wzr
            cmp     w2, #0
            ble     G_M1017_IG04
            align   [0 bytes for IG03]
                        ;; size=24 bbWeight=1 PerfScore 6.50

G_M1017_IG03:  ;; offset=0x0024
            ldr     w4, [x1, w3, UXTW #2]
            add     w0, w4, w0
            add     w3, w3, #1
            cmp     w3, w2
            blt     G_M1017_IG03
                        ;; size=20 bbWeight=4 PerfScore 22.00

The point of strength reduction is to optimize the loop codegen as if it had been written as follows:

[MethodImpl(MethodImplOptions.NoInlining)]
private static int StrengthReduce(Span<int> s)
{
    int sum = 0;
    ref int p = ref MemoryMarshal.GetReference(s);
    ref int end = ref Unsafe.Add(ref p, s.Length);
    while (Unsafe.IsAddressLessThan(ref p, ref end))
    {
        sum += p;
        p = ref Unsafe.Add(ref p, 1);
    }

    return sum;
}
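
In IV terms, the rewrite replaces the address `base + 4 * i`, which is recomputed from the primary IV `i` on every iteration, with a new IV that starts at `base` and steps by 4 bytes. The sketch below models this with a hypothetical `AddRec` type (a start value plus a per-iteration step) and plain `long` values standing in for addresses. `AddRec` and `StrengthReductionModel` are illustrative names only, not the JIT's actual IV analysis data structures.

```csharp
// Hypothetical model of an add recurrence {Start, +, Step}: on iteration k the value is Start + k * Step.
// Illustration only; this is not how the JIT represents IVs internally.
public readonly record struct AddRec(long Start, long Step)
{
    // Value of the recurrence on iteration k.
    public long ValueAt(long k) => Start + k * Step;

    // Multiplying an add recurrence by a constant yields another add recurrence.
    public AddRec Scale(long factor) => new AddRec(Start * factor, Step * factor);

    // Adding a loop-invariant value only shifts the start.
    public AddRec Offset(long delta) => new AddRec(Start + delta, Step);
}

public static class StrengthReductionModel
{
    // Before: every iteration computes base + 4 * i from the primary IV.
    public static long[] IndexedAddresses(long baseAddr, int count)
    {
        var addrs = new long[count];
        for (int i = 0; i < count; i++)
            addrs[i] = baseAddr + 4L * i;
        return addrs;
    }

    // After: a new IV starts at rec.Start and is bumped by rec.Step; no multiply inside the loop.
    public static long[] ReducedAddresses(AddRec rec, int count)
    {
        var addrs = new long[count];
        long p = rec.Start;
        for (int i = 0; i < count; i++)
        {
            addrs[i] = p;
            p += rec.Step;
        }
        return addrs;
    }
}
```

Scaling the counter `{0, +, 1}` by 4 and offsetting it by `base` gives `{base, +, 4}`, which plays the role of `p` in the rewritten loop above.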

With that rewrite, the codegen would look like this on x64:

       xor      eax, eax
       mov      rdx, bword ptr [rcx]
       mov      ecx, dword ptr [rcx+0x08]
       lea      rcx, bword ptr [rdx+4*rcx]
       cmp      rdx, rcx
       jae      SHORT G_M11380_IG04
       align    [0 bytes for IG03]
                        ;; size=17 bbWeight=1 PerfScore 6.00

G_M11380_IG03:  ;; offset=0x0011
       add      eax, dword ptr [rdx]
       add      rdx, 4
       cmp      rdx, rcx
       jb       SHORT G_M11380_IG03
                        ;; size=11 bbWeight=4 PerfScore 18.00

arm64:

            mov     w0, wzr
            ldr     x1, [fp, #0x10] // [V00 arg0]
            ldr     w2, [fp, #0x18] // [V00 arg0+0x08]
            ubfiz   x2, x2, #2, #32
            add     x2, x1, x2
            cmp     x1, x2
            bhs     G_M11380_IG04
            align   [0 bytes for IG03]
                        ;; size=28 bbWeight=1 PerfScore 7.50

G_M11380_IG03:  ;; offset=0x0028
            ldr     w3, [x1]
            add     w0, w0, w3
            add     x1, x1, #4
            cmp     x1, x2
            blo     G_M11380_IG03
                        ;; size=20 bbWeight=4 PerfScore 22.00

On arm64 there is the additional possibility of using the post-increment addressing mode, by optimizing the placement of the IV increment after strength reduction has happened. The loop body then reduces to:

G_M11380_IG03:  ;; offset=0x0028
            ldr     w3, [x1], #4
            add     w0, w0, w3
            cmp     x1, x2
            blo     G_M11380_IG03

jakobbotsch commented 2 months ago

It is an open question whether we can optimize Span<T> as well as T[] without giving Span<T>/ReadOnlySpan<T> (more) special status. The transformation shown above is actually illegal for the JIT to perform unless we make it undefined behavior for a Span<T> to exist with an "invalid" range of managed byrefs.

Consider the following example:

static void Main()
{
    int[] values = [1, 2, 3, 4, 0];
    Span<int> exampleSpan = MemoryMarshal.CreateSpan(ref values[0], int.MaxValue);
    Sum(exampleSpan); // No problem today
    Sum2(exampleSpan); // Forms illegal byref
}

private static int Sum(Span<int> s)
{
    int sum = 0;
    foreach (int x in s)
    {
        if (x == 0)
            break;

        sum += x;
    }

    return sum;
}

private static int Sum2(Span<int> s)
{
    int sum = 0;
    ref int p = ref MemoryMarshal.GetReference(s);
    ref int end = ref Unsafe.Add(ref p, s.Length);
    while (Unsafe.IsAddressLessThan(ref p, ref end))
    {
        int x = p;
        if (x == 0)
            break;

        sum += x;
        p = ref Unsafe.Add(ref p, 1);
    }

    return sum;
}

exampleSpan is created with a valid byref but with a length that makes _reference + length an invalid byref. Today Sum has no problem: it never eagerly forms the _reference + length byref, and it reaches the 0 and breaks before reading past the end of values. Sum2, on the other hand, eagerly forms this illegal byref as end. Strength reduction would have the JIT transform Sum into Sum2.

@jkotas @davidwrighton any thoughts on this? Can we document somewhere that Span<T>/ReadOnlySpan<T> have "special status", making them amenable to optimizations on par with T[]? I think we would document two things:

  1. Non-negative length field. The JIT already makes use of this assumption today.
  2. Requirements on the range of byrefs represented by the Span<T>, i.e. _reference + length must point inside (or at the end of) the same object as _reference when it is a managed byref.
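
To make the two requirements concrete, here is a model in which the owning object is explicit: a span is described by a backing array, a start offset, and a length. This is only an illustration under that assumption; a real Span<T> stores just a byref and a length, and `SpanInvariants.IsWellFormed` is a hypothetical helper, not a runtime API (the JIT cannot check this at run time). Under this model, exampleSpan from the example above violates requirement 2.

```csharp
// Hypothetical model: a span described as (owning array, start offset, length).
// A real Span<T> only stores a byref and a length; this model makes the owning object explicit.
public static class SpanInvariants
{
    // Requirement 1: the length is non-negative.
    // Requirement 2: _reference + length stays inside, or exactly at the end of, the owning object.
    public static bool IsWellFormed<T>(T[] owner, int offset, int length)
    {
        if (length < 0 || offset < 0 || offset > owner.Length)
            return false;

        // Widen to long so that offset + length cannot overflow int.
        return (long)offset + length <= owner.Length;
    }
}
```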
jkotas commented 2 months ago

Existing Span usages do not always follow this restriction. For example:

https://github.com/dotnet/runtime/blob/81ca1c4b1e1eea9c94bdeb38c050d5c4063bab57/src/libraries/System.Runtime/tests/System.Runtime.Extensions.Tests/System/Convert.ToHexString.cs#L96-L97

I guess we can document it retroactively as a breaking change and try to fix all instances of the bad patterns that we can find.

jakobbotsch commented 2 months ago

Hmm, I'll have to see whether that is worth it once I get further. For now I can start with arrays to do the measurements.

jakobbotsch commented 2 months ago

I think that instead of forming end = span._reference + span.length * size, we can use a reverse counted (down-counting) loop and break even on x64/arm64. For example, Sum2 will usually end up as

private static int Sum2(Span<int> s)
{
    int sum = 0;
    ref int p = ref MemoryMarshal.GetReference(s);
    if (s.Length > 0)
    {
        int length = s.Length;
        do
        {
            int x = p;
            if (x == 0)
                break;

            sum += x;
            p = ref Unsafe.Add(ref p, 1);
        } while (--length > 0);
    }

    return sum;
}

when loop inversion kicks in. The --length > 0 check takes 2 instructions and 1 live variable on arm64/x64, exactly the same as comparing against a precomputed end.
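
A self-contained version of the original Sum and the down-counted form, assuming .NET 5 or later, can be used to check that the two agree. The class and method names are illustrative. Because p only advances once per executed iteration, running the down-counted form on exampleSpan stops at the 0 without ever forming _reference + length.

```csharp
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

public static class DownCountedLoop
{
    // Original form: foreach over the span, stopping at the first zero.
    public static int Sum(Span<int> s)
    {
        int sum = 0;
        foreach (int x in s)
        {
            if (x == 0)
                break;
            sum += x;
        }
        return sum;
    }

    // Strength-reduced form that counts the remaining trip count down instead of comparing against an end byref.
    // p is advanced by one element per executed iteration only.
    public static int SumDownCounted(Span<int> s)
    {
        int sum = 0;
        ref int p = ref MemoryMarshal.GetReference(s);
        int length = s.Length;
        if (length > 0)
        {
            do
            {
                int x = p;
                if (x == 0)
                    break;
                sum += x;
                p = ref Unsafe.Add(ref p, 1);
            } while (--length > 0);
        }
        return sum;
    }
}
```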

jakobbotsch commented 1 week ago

Unfortunately, we still have the problem described above for Span<T>. Without the assumption that a Span<T> points within a single managed object, it is illegal to transform

public static int Sum(Span<int> span, Func<int, bool> sumIndex)
{
    int sum = 0;
    for (int i = 0; i < span.Length; i++)
        sum += sumIndex(i) ? span[i] : 0;
    return sum;
}

into

public static int Sum(Span<int> span, Func<int, bool> sumIndex)
{
    int sum = 0;
    ref int val = ref MemoryMarshal.GetReference(span);
    for (int i = 0; i < span.Length; i++)
    {
        sum += sumIndex(i) ? val : 0;
        val = ref Unsafe.Add(ref val, 1);
    }
    return sum;
}

The same transformation seems fine for arrays, since an array's length always matches the object its elements live in.
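
For comparison, here is a sketch of the same rewrite on int[], assuming .NET 5 or later for MemoryMarshal.GetArrayDataReference. ArraySum is an illustrative name. Because the loop bound is the array's own length, val never moves more than one element past the last element of the array it points into.

```csharp
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

public static class ArraySum
{
    // Original form: indexed access with a bounds-checked array[i].
    public static int Sum(int[] array, Func<int, bool> sumIndex)
    {
        int sum = 0;
        for (int i = 0; i < array.Length; i++)
            sum += sumIndex(i) ? array[i] : 0;
        return sum;
    }

    // Strength-reduced form: val walks the array data in lockstep with i.
    // The loop bound is array.Length, so val ends at most one element past the last element.
    public static int SumReduced(int[] array, Func<int, bool> sumIndex)
    {
        int sum = 0;
        ref int val = ref MemoryMarshal.GetArrayDataReference(array);
        for (int i = 0; i < array.Length; i++)
        {
            sum += sumIndex(i) ? val : 0;
            val = ref Unsafe.Add(ref val, 1);
        }
        return sum;
    }
}
```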

(Of course whether or not this transformation is profitable is another question entirely.)