dotnet / runtime

.NET is a cross-platform runtime for cloud, mobile, desktop, and IoT apps.
https://docs.microsoft.com/dotnet/core/
MIT License

[API Proposal]: ARM64 SVE: Add additional types for the mask APIs #108233

Open a74nh opened 1 month ago

a74nh commented 1 month ago

Background and motivation

Consider this function, which multiplies two float arrays element-wise and then sums the result.

The three casts to Vector<float> exist because there are no float versions of CreateWhileLessThanMaskX() and TestFirstTrue().

    public static unsafe float fmla(ref float* a, ref float* b, int length)
    {
      Vector<float> res = Vector<float>.Zero;
      Vector<uint> ptrue = Sve.CreateTrueMaskUInt32();
      Vector<uint> ploop;

      for (int i = 0; Sve.TestFirstTrue(ptrue, ploop = Sve.CreateWhileLessThanMask32Bit(i, length)); i+= (int)Sve.Count32BitElements())
      {
        Vector<float> a_vec = Sve.LoadVector((Vector<float>)ploop, a+i);
        Vector<float> b_vec = Sve.LoadVector((Vector<float>)ploop, b+i);
        res = Sve.ConditionalSelect((Vector<float>)ploop, Sve.FusedMultiplyAdd(res, a_vec, b_vec), res);
      }

      return Sve.AddAcross(res).ToScalar();
    }

The following code looks much nicer and is easier to logically parse:

    public static unsafe float fmla(ref float* a, ref float* b, int length)
    {
      Vector<float> res = Vector<float>.Zero;
      Vector<float> ptrue = Sve.CreateTrueMaskSingle();
      Vector<float> ploop;

      for (int i = 0; Sve.TestFirstTrue(ptrue, ploop = Sve.CreateWhileLessThanMaskSingle(i, length)); i+= (int)Sve.Count32BitElements())
      {
        Vector<float> a_vec = Sve.LoadVector(ploop, a+i);
        Vector<float> b_vec = Sve.LoadVector(ploop, b+i);
        res = Sve.ConditionalSelect(ploop, Sve.FusedMultiplyAdd(res, a_vec, b_vec), res);
      }

      return Sve.AddAcross(res).ToScalar();
    }

The same would then apply for other SVE APIs operating on a mask.
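
For context on what the casts in the first version do: the explicit conversion between Vector<T> types reinterprets the underlying bits rather than converting values, which is why a Vector<uint> mask can be passed where a Vector<float> mask is expected. Below is a minimal sketch using only System.Numerics, so no SVE hardware is needed. MaskCastDemo and its methods are hypothetical names, and the lane value used with it is an arbitrary illustration, not a statement about how SVE encodes active lanes.

```csharp
using System;
using System.Numerics;

public static class MaskCastDemo
{
    // Reinterprets a Vector<uint> as Vector<float> and back again.
    // Both casts only change the static type; the bits are untouched.
    public static Vector<uint> RoundTrip(Vector<uint> mask)
    {
        Vector<float> asSingle = (Vector<float>)mask;
        return (Vector<uint>)asSingle;
    }

    // Returns the raw bits of lane 0 after reinterpreting the mask as float.
    public static uint Lane0Bits(Vector<uint> mask)
    {
        Vector<float> asSingle = (Vector<float>)mask;
        return BitConverter.SingleToUInt32Bits(asSingle[0]);
    }
}
```

Since the casts carry no information beyond the type, typed mask overloads would let callers drop them without changing behaviour.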

Also consider:

public static unsafe long multiplyAdd(ref short* a, ref short* b, int length)
{
  Vector<short> res = Vector<short>.Zero;
  Vector<short> ploop;

  for (int i = 0;
       Sve.TestFirstTrue(Sve.CreateTrueMaskInt16(), ploop = (Vector<short>)Sve.CreateWhileLessThanMask16Bit(i, length));
       i+= (int)Sve.Count16BitElements())
  {
    Vector<short> a_vec = Sve.LoadVector((Vector<short>)ploop, a+i);
    Vector<short> b_vec = Sve.LoadVector((Vector<short>)ploop, b+i);
    res = Sve.ConditionalSelect((Vector<short>)ploop, Sve.MultiplyAdd(res, a_vec, b_vec), res);
  }

  return Sve.AddAcross(res).ToScalar();
}

The for loop needs a 16-bit WHILELT mask. The only way to create one is via CreateWhileLessThanMask16Bit(), but this returns a Vector<ushort>. It needs to be a Vector<short> so that it can be used in ConditionalSelect().

The casting to Vector<short> is a little confusing.

I suggest we add signed versions of CreateWhileLessThanMask().
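
As a rough scalar model of what a WHILELT-style mask contains: lane k is active while left + k < right. The sketch below is plain C# with no SVE dependency. WhileLessThanModel and its bool[] return type are hypothetical and used for illustration only. The lane count is passed in explicitly because the real vector length is hardware-dependent, and the model ignores the overflow handling of the real instruction.

```csharp
using System;

public static class WhileLessThanModel
{
    // Scalar model of a WHILELT mask: lane k is active iff (left + k) < right.
    // Assumes left + laneCount does not overflow long.
    public static bool[] CreateMask(long left, long right, int laneCount)
    {
        var mask = new bool[laneCount];
        for (int k = 0; k < laneCount; k++)
        {
            mask[k] = left + k < right;
        }
        return mask;
    }
}
```

With 8 lanes, CreateMask(6, 10, 8) marks only the first four lanes active, which is how the loops above handle a tail shorter than one vector. The element type of the returned Vector<T> does not change which lanes are active. It only determines which overloads the mask can be passed to without a cast.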

API Proposal

This uses the same T syntax as other SVE proposals. All of these extend existing API methods.

APIs to add:

namespace System.Runtime.Intrinsics.Arm;

public partial class Sve
{

  /// T: float, double
  public static unsafe Vector<T> CreateBreakAfterMask(Vector<T> totalMask, Vector<T> fromMask); // BRKA // predicated

  /// T: float, double
  public static unsafe Vector<T> CreateBreakAfterPropagateMask(Vector<T> mask, Vector<T> left, Vector<T> right); // BRKPA

  /// T: float, double
  public static unsafe Vector<T> CreateBreakBeforeMask(Vector<T> totalMask, Vector<T> fromMask); // BRKB // predicated

  /// T: float, double
  public static unsafe Vector<T> CreateBreakBeforePropagateMask(Vector<T> mask, Vector<T> left, Vector<T> right); // BRKPB

  /// T: float, double
  public static unsafe Vector<T> CreateBreakPropagateMask(Vector<T> totalMask, Vector<T> fromMask); // BRKN // predicated

  /// T: float, double
  public static unsafe Vector<T> CreateMaskForFirstActiveElement(Vector<T> totalMask, Vector<T> fromMask); // PFIRST

  /// T: float, double
  public static unsafe bool TestAnyTrue(Vector<T> leftMask, Vector<T> rightMask); // PTEST

  /// T: float, double
  public static unsafe bool TestFirstTrue(Vector<T> leftMask, Vector<T> rightMask); // PTEST

  /// T: float, double
  public static unsafe bool TestLastTrue(Vector<T> leftMask, Vector<T> rightMask); // PTEST

  public static Vector<float> CreateWhileLessThanMaskSingle(int left, int right);
  public static Vector<float> CreateWhileLessThanMaskSingle(long left, long right);
  public static Vector<float> CreateWhileLessThanMaskSingle(uint left, uint right);
  public static Vector<float> CreateWhileLessThanMaskSingle(ulong left, ulong right);

  public static Vector<double> CreateWhileLessThanMaskDouble(int left, int right);
  public static Vector<double> CreateWhileLessThanMaskDouble(long left, long right);
  public static Vector<double> CreateWhileLessThanMaskDouble(uint left, uint right);
  public static Vector<double> CreateWhileLessThanMaskDouble(ulong left, ulong right);

  public static unsafe Vector<short> CreateWhileLessThanMaskInt16(int left, int right); // WHILELT
  public static unsafe Vector<short> CreateWhileLessThanMaskInt16(long left, long right); // WHILELT
  public static unsafe Vector<short> CreateWhileLessThanMaskInt16(uint left, uint right); // WHILELO
  public static unsafe Vector<short> CreateWhileLessThanMaskInt16(ulong left, ulong right); // WHILELO

  public static unsafe Vector<int> CreateWhileLessThanMaskInt32(int left, int right); // WHILELT
  public static unsafe Vector<int> CreateWhileLessThanMaskInt32(long left, long right); // WHILELT
  public static unsafe Vector<int> CreateWhileLessThanMaskInt32(uint left, uint right); // WHILELO
  public static unsafe Vector<int> CreateWhileLessThanMaskInt32(ulong left, ulong right); // WHILELO

  public static unsafe Vector<long> CreateWhileLessThanMaskInt64(int left, int right); // WHILELT
  public static unsafe Vector<long> CreateWhileLessThanMaskInt64(long left, long right); // WHILELT
  public static unsafe Vector<long> CreateWhileLessThanMaskInt64(uint left, uint right); // WHILELO
  public static unsafe Vector<long> CreateWhileLessThanMaskInt64(ulong left, ulong right); // WHILELO

  public static unsafe Vector<sbyte> CreateWhileLessThanMaskInt8(int left, int right); // WHILELT
  public static unsafe Vector<sbyte> CreateWhileLessThanMaskInt8(long left, long right); // WHILELT
  public static unsafe Vector<sbyte> CreateWhileLessThanMaskInt8(uint left, uint right); // WHILELO
  public static unsafe Vector<sbyte> CreateWhileLessThanMaskInt8(ulong left, ulong right); // WHILELO

  public static unsafe Vector<ushort> CreateWhileLessThanMaskUInt16(int left, int right); // WHILELT
  public static unsafe Vector<ushort> CreateWhileLessThanMaskUInt16(long left, long right); // WHILELT
  public static unsafe Vector<ushort> CreateWhileLessThanMaskUInt16(uint left, uint right); // WHILELO
  public static unsafe Vector<ushort> CreateWhileLessThanMaskUInt16(ulong left, ulong right); // WHILELO

  public static unsafe Vector<uint> CreateWhileLessThanMaskUInt32(int left, int right); // WHILELT
  public static unsafe Vector<uint> CreateWhileLessThanMaskUInt32(long left, long right); // WHILELT
  public static unsafe Vector<uint> CreateWhileLessThanMaskUInt32(uint left, uint right); // WHILELO
  public static unsafe Vector<uint> CreateWhileLessThanMaskUInt32(ulong left, ulong right); // WHILELO

  public static unsafe Vector<ulong> CreateWhileLessThanMaskUInt64(int left, int right); // WHILELT
  public static unsafe Vector<ulong> CreateWhileLessThanMaskUInt64(long left, long right); // WHILELT
  public static unsafe Vector<ulong> CreateWhileLessThanMaskUInt64(uint left, uint right); // WHILELO
  public static unsafe Vector<ulong> CreateWhileLessThanMaskUInt64(ulong left, ulong right); // WHILELO

  public static unsafe Vector<byte> CreateWhileLessThanMaskUInt8(int left, int right); // WHILELT
  public static unsafe Vector<byte> CreateWhileLessThanMaskUInt8(long left, long right); // WHILELT
  public static unsafe Vector<byte> CreateWhileLessThanMaskUInt8(uint left, uint right); // WHILELO
  public static unsafe Vector<byte> CreateWhileLessThanMaskUInt8(ulong left, ulong right); // WHILELO

  public static Vector<float> CreateWhileLessThanOrEqualMaskSingle(int left, int right);
  public static Vector<float> CreateWhileLessThanOrEqualMaskSingle(long left, long right);
  public static Vector<float> CreateWhileLessThanOrEqualMaskSingle(uint left, uint right);
  public static Vector<float> CreateWhileLessThanOrEqualMaskSingle(ulong left, ulong right);

  public static Vector<double> CreateWhileLessThanOrEqualMaskDouble(int left, int right);
  public static Vector<double> CreateWhileLessThanOrEqualMaskDouble(long left, long right);
  public static Vector<double> CreateWhileLessThanOrEqualMaskDouble(uint left, uint right);
  public static Vector<double> CreateWhileLessThanOrEqualMaskDouble(ulong left, ulong right);

  public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMaskUInt16(int left, int right); // WHILELE
  public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMaskUInt16(long left, long right); // WHILELE
  public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMaskUInt16(uint left, uint right); // WHILELS
  public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMaskUInt16(ulong left, ulong right); // WHILELS

  public static unsafe Vector<short> CreateWhileLessThanOrEqualMaskInt16(int left, int right); // WHILELE
  public static unsafe Vector<short> CreateWhileLessThanOrEqualMaskInt16(long left, long right); // WHILELE
  public static unsafe Vector<short> CreateWhileLessThanOrEqualMaskInt16(uint left, uint right); // WHILELS
  public static unsafe Vector<short> CreateWhileLessThanOrEqualMaskInt16(ulong left, ulong right); // WHILELS

  public static unsafe Vector<int> CreateWhileLessThanOrEqualMaskInt32(int left, int right); // WHILELE
  public static unsafe Vector<int> CreateWhileLessThanOrEqualMaskInt32(long left, long right); // WHILELE
  public static unsafe Vector<int> CreateWhileLessThanOrEqualMaskInt32(uint left, uint right); // WHILELS
  public static unsafe Vector<int> CreateWhileLessThanOrEqualMaskInt32(ulong left, ulong right); // WHILELS

  public static unsafe Vector<long> CreateWhileLessThanOrEqualMaskInt64(int left, int right); // WHILELE
  public static unsafe Vector<long> CreateWhileLessThanOrEqualMaskInt64(long left, long right); // WHILELE
  public static unsafe Vector<long> CreateWhileLessThanOrEqualMaskInt64(uint left, uint right); // WHILELS
  public static unsafe Vector<long> CreateWhileLessThanOrEqualMaskInt64(ulong left, ulong right); // WHILELS

  public static unsafe Vector<sbyte> CreateWhileLessThanOrEqualMaskInt8(int left, int right); // WHILELE
  public static unsafe Vector<sbyte> CreateWhileLessThanOrEqualMaskInt8(long left, long right); // WHILELE
  public static unsafe Vector<sbyte> CreateWhileLessThanOrEqualMaskInt8(uint left, uint right); // WHILELS
  public static unsafe Vector<sbyte> CreateWhileLessThanOrEqualMaskInt8(ulong left, ulong right); // WHILELS

  public static unsafe Vector<uint> CreateWhileLessThanOrEqualMaskUInt32(int left, int right); // WHILELE
  public static unsafe Vector<uint> CreateWhileLessThanOrEqualMaskUInt32(long left, long right); // WHILELE
  public static unsafe Vector<uint> CreateWhileLessThanOrEqualMaskUInt32(uint left, uint right); // WHILELS
  public static unsafe Vector<uint> CreateWhileLessThanOrEqualMaskUInt32(ulong left, ulong right); // WHILELS

  public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMaskUInt64(int left, int right); // WHILELE
  public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMaskUInt64(long left, long right); // WHILELE
  public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMaskUInt64(uint left, uint right); // WHILELS
  public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMaskUInt64(ulong left, ulong right); // WHILELS

  public static unsafe Vector<byte> CreateWhileLessThanOrEqualMaskUInt8(int left, int right); // WHILELE
  public static unsafe Vector<byte> CreateWhileLessThanOrEqualMaskUInt8(long left, long right); // WHILELE
  public static unsafe Vector<byte> CreateWhileLessThanOrEqualMaskUInt8(uint left, uint right); // WHILELS
  public static unsafe Vector<byte> CreateWhileLessThanOrEqualMaskUInt8(ulong left, ulong right); // WHILELS

  public static unsafe Vector<float> GetFfrSingle(); // RDFFR // predicated
  public static unsafe Vector<double> GetFfrDouble(); // RDFFR // predicated

  /// T: float, double
  public static unsafe void SetFfr(Vector<T> value); // WRFFR

}

APIs to remove:


namespace System.Runtime.Intrinsics.Arm;

public partial class Sve
{
  // Existing bit-width-named APIs, superseded by the element-type-named versions listed under "APIs to add"

  public static unsafe Vector<ushort> CreateWhileLessThanMask16Bit(int left, int right); // WHILELT
  public static unsafe Vector<ushort> CreateWhileLessThanMask16Bit(long left, long right); // WHILELT
  public static unsafe Vector<ushort> CreateWhileLessThanMask16Bit(uint left, uint right); // WHILELO
  public static unsafe Vector<ushort> CreateWhileLessThanMask16Bit(ulong left, ulong right); // WHILELO

  public static unsafe Vector<uint> CreateWhileLessThanMask32Bit(int left, int right); // WHILELT
  public static unsafe Vector<uint> CreateWhileLessThanMask32Bit(long left, long right); // WHILELT
  public static unsafe Vector<uint> CreateWhileLessThanMask32Bit(uint left, uint right); // WHILELO
  public static unsafe Vector<uint> CreateWhileLessThanMask32Bit(ulong left, ulong right); // WHILELO

  public static unsafe Vector<ulong> CreateWhileLessThanMask64Bit(int left, int right); // WHILELT
  public static unsafe Vector<ulong> CreateWhileLessThanMask64Bit(long left, long right); // WHILELT
  public static unsafe Vector<ulong> CreateWhileLessThanMask64Bit(uint left, uint right); // WHILELO
  public static unsafe Vector<ulong> CreateWhileLessThanMask64Bit(ulong left, ulong right); // WHILELO

  public static unsafe Vector<byte> CreateWhileLessThanMask8Bit(int left, int right); // WHILELT
  public static unsafe Vector<byte> CreateWhileLessThanMask8Bit(long left, long right); // WHILELT
  public static unsafe Vector<byte> CreateWhileLessThanMask8Bit(uint left, uint right); // WHILELO
  public static unsafe Vector<byte> CreateWhileLessThanMask8Bit(ulong left, ulong right); // WHILELO

  public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMask16Bit(int left, int right); // WHILELE
  public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMask16Bit(long left, long right); // WHILELE
  public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMask16Bit(uint left, uint right); // WHILELS
  public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMask16Bit(ulong left, ulong right); // WHILELS

  public static unsafe Vector<uint> CreateWhileLessThanOrEqualMask32Bit(int left, int right); // WHILELE
  public static unsafe Vector<uint> CreateWhileLessThanOrEqualMask32Bit(long left, long right); // WHILELE
  public static unsafe Vector<uint> CreateWhileLessThanOrEqualMask32Bit(uint left, uint right); // WHILELS
  public static unsafe Vector<uint> CreateWhileLessThanOrEqualMask32Bit(ulong left, ulong right); // WHILELS

  public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMask64Bit(int left, int right); // WHILELE
  public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMask64Bit(long left, long right); // WHILELE
  public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMask64Bit(uint left, uint right); // WHILELS
  public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMask64Bit(ulong left, ulong right); // WHILELS

  public static unsafe Vector<byte> CreateWhileLessThanOrEqualMask8Bit(int left, int right); // WHILELE
  public static unsafe Vector<byte> CreateWhileLessThanOrEqualMask8Bit(long left, long right); // WHILELE
  public static unsafe Vector<byte> CreateWhileLessThanOrEqualMask8Bit(uint left, uint right); // WHILELS
  public static unsafe Vector<byte> CreateWhileLessThanOrEqualMask8Bit(ulong left, ulong right); // WHILELS
}

dotnet-policy-service[bot] commented 1 month ago

Tagging subscribers to this area: @dotnet/area-system-numerics
See info in area-owners.md if you want to be subscribed.

tannergooding commented 1 month ago

API proposals need to follow the recommended template: https://github.com/dotnet/runtime/blob/main/docs/project/api-review-process.md#steps

That includes explicitly listing the APIs that are desired to be exposed in a code block, such as:

namespace System.Runtime.Intrinsics.Arm;

public partial class Sve
{
    public static Vector<float> CreateWhileLessThanMaskSingle(int left, int right);
    public static Vector<float> CreateWhileLessThanMaskSingle(long left, long right);
    public static Vector<float> CreateWhileLessThanMaskSingle(uint left, uint right);
    public static Vector<float> CreateWhileLessThanMaskSingle(ulong left, ulong right);

    // Rest of the APIs to be exposed listed here
}

Note that we use Single and Double (matching System.Single and System.Double), not Float, when exposing floating-point APIs that need to be disambiguated by name.

a74nh commented 1 month ago

API proposals need to follow the recommended template: https://github.com/dotnet/runtime/blob/main/docs/project/api-review-process.md#steps

Updated as suggested in the top comment.

a74nh commented 1 month ago

Added signed/unsigned versions of CreateWhileLessThanMask as suggested in #108384

tannergooding commented 1 month ago

nit: The Int8 ones should be SByte and the UInt8 ones should be Byte.

In general, the rule is that we use the official type names rather than language-specific keywords (such as byte or long) in API signatures. Byte and SByte are two cases where the official type name matches the C# keyword name and where the official type name doesn't follow the convention used for the other integer types (i.e. it isn't Int8 and UInt8).
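
The official names referred to here can be checked directly from C#, since typeof(T).Name reports the framework type name rather than the keyword. A small sketch follows. TypeNameDemo is a hypothetical helper, not part of any proposed API.

```csharp
using System;

public static class TypeNameDemo
{
    // Returns the official framework type name for T,
    // i.e. the name used as a suffix in API names.
    public static string OfficialName<T>() => typeof(T).Name;
}
```

For example, OfficialName<sbyte>() yields "SByte", OfficialName<byte>() yields "Byte", OfficialName<short>() yields "Int16", and OfficialName<float>() yields "Single".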