dotnet / runtime

.NET is a cross-platform runtime for cloud, mobile, desktop, and IoT apps.
https://docs.microsoft.com/dotnet/core/
MIT License
15.47k stars 4.76k forks source link

[API Proposal]: ARM64 SVE: Add signed versions of the CreateWhileLessThanMask APIs #108384

Closed a74nh closed 2 months ago

a74nh commented 2 months ago

Background and motivation

Consider:

// Multiply-accumulates a[i] * b[i] over the first `length` elements using SVE,
// processing one vector's worth of 16-bit lanes per loop iteration.
// NOTE(review): products are accumulated in 16-bit lanes (Vector<short> res), so
// per-lane sums wrap on overflow before AddAcross combines them; the long return
// type does not protect against this. Confirm whether a widening accumulation
// (e.g. a dot product into Vector<long>) was intended.
// NOTE(review): a and b are passed by ref but never reassigned; plain short*
// parameters would be sufficient.
public static unsafe long multiplyAdd(ref short* a, ref short* b, int length)
{
  Vector<short> res = Vector<short>.Zero;
  // Loop predicate for the current iteration: lanes whose index (i + lane) < length.
  Vector<short> ploop;

  // Loop while the first lane of the WHILELT predicate is active (i.e. i < length).
  // CreateWhileLessThanMask16Bit returns Vector<ushort>, so it is cast to
  // Vector<short> to match the element type of the data vectors; these casts are
  // the readability problem this proposal is about. The step is the number of
  // 16-bit lanes per vector.
  for (int i = 0;
       Sve.TestFirstTrue(Sve.CreateTrueMaskInt16(), ploop = (Vector<short>)Sve.CreateWhileLessThanMask16Bit(i, length));
       i+= (int)Sve.Count16BitElements())
  {
    // Predicated loads: lanes past `length` are not loaded.
    Vector<short> a_vec = Sve.LoadVector((Vector<short>)ploop, a+i);
    Vector<short> b_vec = Sve.LoadVector((Vector<short>)ploop, b+i);
    // Update the accumulator only in active lanes; inactive lanes keep their previous value.
    res = Sve.ConditionalSelect((Vector<short>)ploop, Sve.MultiplyAdd(res, a_vec, b_vec), res);
  }

  // Horizontal sum of the accumulator lanes, returned as a scalar.
  return Sve.AddAcross(res).ToScalar();
}

For the for loop we need a 16-bit WHILELT mask. The only way to create this is via CreateWhileLessThanMask16Bit(), but it returns a Vector<ushort>. The mask needs to be a Vector<short> so that it can be passed to LoadVector() and ConditionalSelect() alongside the Vector<short> data.

The resulting casts to Vector<short> make the code harder to follow.

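I suggest making the existing CreateWhileLessThanMask*Bit() and CreateWhileLessThanOrEqualMask*Bit() APIs return signed vectors, and adding new CreateWhileLessThanMaskU*Bit() and CreateWhileLessThanOrEqualMaskU*Bit() APIs that return unsigned vectors.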

API Proposal

namespace System.Runtime.Intrinsics.Arm;

public partial class Sve
{
  // Change the existing APIs to return a signed vector instead of an unsigned one.
  //
  // As the trailing comments show, the instruction is chosen by the operand types:
  // signed (int/long) operands map to WHILELT/WHILELE, and unsigned (uint/ulong)
  // operands map to WHILELO/WHILELS. The return element type does not change the
  // instruction. It only fixes the vector type the mask is produced as, so that it
  // can be used with signed data vectors without a cast.

  public static unsafe Vector<short> CreateWhileLessThanMask16Bit(int left, int right); // WHILELT

  public static unsafe Vector<short> CreateWhileLessThanMask16Bit(long left, long right); // WHILELT

  public static unsafe Vector<short> CreateWhileLessThanMask16Bit(uint left, uint right); // WHILELO

  public static unsafe Vector<short> CreateWhileLessThanMask16Bit(ulong left, ulong right); // WHILELO

  public static unsafe Vector<int> CreateWhileLessThanMask32Bit(int left, int right); // WHILELT

  public static unsafe Vector<int> CreateWhileLessThanMask32Bit(long left, long right); // WHILELT

  public static unsafe Vector<int> CreateWhileLessThanMask32Bit(uint left, uint right); // WHILELO

  public static unsafe Vector<int> CreateWhileLessThanMask32Bit(ulong left, ulong right); // WHILELO

  public static unsafe Vector<long> CreateWhileLessThanMask64Bit(int left, int right); // WHILELT

  public static unsafe Vector<long> CreateWhileLessThanMask64Bit(long left, long right); // WHILELT

  public static unsafe Vector<long> CreateWhileLessThanMask64Bit(uint left, uint right); // WHILELO

  public static unsafe Vector<long> CreateWhileLessThanMask64Bit(ulong left, ulong right); // WHILELO

  public static unsafe Vector<sbyte> CreateWhileLessThanMask8Bit(int left, int right); // WHILELT

  public static unsafe Vector<sbyte> CreateWhileLessThanMask8Bit(long left, long right); // WHILELT

  public static unsafe Vector<sbyte> CreateWhileLessThanMask8Bit(uint left, uint right); // WHILELO

  public static unsafe Vector<sbyte> CreateWhileLessThanMask8Bit(ulong left, ulong right); // WHILELO

  // Inclusive-bound variants (left <= right): WHILELE for signed operands, WHILELS for unsigned.
  public static unsafe Vector<short> CreateWhileLessThanOrEqualMask16Bit(int left, int right); // WHILELE

  public static unsafe Vector<short> CreateWhileLessThanOrEqualMask16Bit(long left, long right); // WHILELE

  public static unsafe Vector<short> CreateWhileLessThanOrEqualMask16Bit(uint left, uint right); // WHILELS

  public static unsafe Vector<short> CreateWhileLessThanOrEqualMask16Bit(ulong left, ulong right); // WHILELS

  public static unsafe Vector<int> CreateWhileLessThanOrEqualMask32Bit(int left, int right); // WHILELE

  public static unsafe Vector<int> CreateWhileLessThanOrEqualMask32Bit(long left, long right); // WHILELE

  public static unsafe Vector<int> CreateWhileLessThanOrEqualMask32Bit(uint left, uint right); // WHILELS

  public static unsafe Vector<int> CreateWhileLessThanOrEqualMask32Bit(ulong left, ulong right); // WHILELS

  public static unsafe Vector<long> CreateWhileLessThanOrEqualMask64Bit(int left, int right); // WHILELE

  public static unsafe Vector<long> CreateWhileLessThanOrEqualMask64Bit(long left, long right); // WHILELE

  public static unsafe Vector<long> CreateWhileLessThanOrEqualMask64Bit(uint left, uint right); // WHILELS

  public static unsafe Vector<long> CreateWhileLessThanOrEqualMask64Bit(ulong left, ulong right); // WHILELS

  public static unsafe Vector<sbyte> CreateWhileLessThanOrEqualMask8Bit(int left, int right); // WHILELE

  public static unsafe Vector<sbyte> CreateWhileLessThanOrEqualMask8Bit(long left, long right); // WHILELE

  public static unsafe Vector<sbyte> CreateWhileLessThanOrEqualMask8Bit(uint left, uint right); // WHILELS

  public static unsafe Vector<sbyte> CreateWhileLessThanOrEqualMask8Bit(ulong left, ulong right); // WHILELS
}

public partial class Sve
{
  // Add new APIs ("U" in the name) that return an unsigned vector.
  //
  // These take the same operand overloads as the signed-returning APIs, and
  // instruction selection is identical: signed (int/long) operands map to
  // WHILELT/WHILELE, unsigned (uint/ulong) operands map to WHILELO/WHILELS, as the
  // trailing comments show. Only the element type of the returned mask vector differs.

  public static unsafe Vector<ushort> CreateWhileLessThanMaskU16Bit(int left, int right); // WHILELT

  public static unsafe Vector<ushort> CreateWhileLessThanMaskU16Bit(long left, long right); // WHILELT

  public static unsafe Vector<ushort> CreateWhileLessThanMaskU16Bit(uint left, uint right); // WHILELO

  public static unsafe Vector<ushort> CreateWhileLessThanMaskU16Bit(ulong left, ulong right); // WHILELO

  public static unsafe Vector<uint> CreateWhileLessThanMaskU32Bit(int left, int right); // WHILELT

  public static unsafe Vector<uint> CreateWhileLessThanMaskU32Bit(long left, long right); // WHILELT

  public static unsafe Vector<uint> CreateWhileLessThanMaskU32Bit(uint left, uint right); // WHILELO

  public static unsafe Vector<uint> CreateWhileLessThanMaskU32Bit(ulong left, ulong right); // WHILELO

  public static unsafe Vector<ulong> CreateWhileLessThanMaskU64Bit(int left, int right); // WHILELT

  public static unsafe Vector<ulong> CreateWhileLessThanMaskU64Bit(long left, long right); // WHILELT

  public static unsafe Vector<ulong> CreateWhileLessThanMaskU64Bit(uint left, uint right); // WHILELO

  public static unsafe Vector<ulong> CreateWhileLessThanMaskU64Bit(ulong left, ulong right); // WHILELO

  public static unsafe Vector<byte> CreateWhileLessThanMaskU8Bit(int left, int right); // WHILELT

  public static unsafe Vector<byte> CreateWhileLessThanMaskU8Bit(long left, long right); // WHILELT

  public static unsafe Vector<byte> CreateWhileLessThanMaskU8Bit(uint left, uint right); // WHILELO

  public static unsafe Vector<byte> CreateWhileLessThanMaskU8Bit(ulong left, ulong right); // WHILELO

  // Inclusive-bound variants (left <= right): WHILELE for signed operands, WHILELS for unsigned.
  public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMaskU16Bit(int left, int right); // WHILELE

  public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMaskU16Bit(long left, long right); // WHILELE

  public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMaskU16Bit(uint left, uint right); // WHILELS

  public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMaskU16Bit(ulong left, ulong right); // WHILELS

  public static unsafe Vector<uint> CreateWhileLessThanOrEqualMaskU32Bit(int left, int right); // WHILELE

  public static unsafe Vector<uint> CreateWhileLessThanOrEqualMaskU32Bit(long left, long right); // WHILELE

  public static unsafe Vector<uint> CreateWhileLessThanOrEqualMaskU32Bit(uint left, uint right); // WHILELS

  public static unsafe Vector<uint> CreateWhileLessThanOrEqualMaskU32Bit(ulong left, ulong right); // WHILELS

  public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMaskU64Bit(int left, int right); // WHILELE

  public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMaskU64Bit(long left, long right); // WHILELE

  public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMaskU64Bit(uint left, uint right); // WHILELS

  public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMaskU64Bit(ulong left, ulong right); // WHILELS

  public static unsafe Vector<byte> CreateWhileLessThanOrEqualMaskU8Bit(int left, int right); // WHILELE

  public static unsafe Vector<byte> CreateWhileLessThanOrEqualMaskU8Bit(long left, long right); // WHILELE

  public static unsafe Vector<byte> CreateWhileLessThanOrEqualMaskU8Bit(uint left, uint right); // WHILELS

  public static unsafe Vector<byte> CreateWhileLessThanOrEqualMaskU8Bit(ulong left, ulong right); // WHILELS
}

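Risks

Changing the return type of the existing CreateWhileLessThanMask*Bit() and CreateWhileLessThanOrEqualMask*Bit() APIs from unsigned to signed vectors is a breaking change for any code that already calls them.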

dotnet-policy-service[bot] commented 2 months ago

Tagging subscribers to this area: @dotnet/area-system-runtime-intrinsics See info in area-owners.md if you want to be subscribed.

tannergooding commented 2 months ago

If we're needing signs here as well, then we should remove *16Bit and instead do *Int16, *UInt16, etc. Matching the existing conventions for disambiguating.

This should likely be merged with the proposal asking for Single/Double and cover that removal.

a74nh commented 2 months ago

If we're needing signs here as well, then we should remove *16Bit and instead do *Int16, *UInt16, etc. Matching the existing conventions for disambiguating.

Agreed.

This should likely be merged with the proposal asking for Single/Double and cover that removal.

I didn't want to merge initially because I was concerned this would just be rejected outright for adding too many APIs.

I'll close this and merge with the other one.

a74nh commented 2 months ago

APIs added to #108233. Closing this.