## Summary
While implementing the API surface for [Expose VectorMask to support generic masking for Vector](https://github.com/dotnet/runtime/issues/74613), various considerations were found that necessitated taking a step back and reconsidering how it works.
Most of these issues stemmed from the additional complexity and throughput cost that would have been required for the JIT to integrate the type. However, it also impacted the way users interacted with the types and the public API surface we were to expose: namely, existing user code would not benefit, and it would nearly double the API surface we're currently exposing for the XArch and cross-platform intrinsics.
These considerations were raised with @dotnet/avx512-contrib and an alternative design was proposed where the JIT would do pattern recognition in lowering instead to limit the throughput hit and provide light-up to existing user code. This does not preclude the ability to expose `VectorMask` in the future and we can revisit the type and its design as appropriate.
## Conceptual Differences
Previously, we would have defined the following, and this would have expanded to effectively all existing exposed intrinsics. This would nearly double or triple our API surface, taking us from the `~1900` APIs we have today up to at least `~3800` APIs. Arm64, for comparison, currently has `~2100` APIs.
```csharp
namespace System.Runtime.Intrinsics.X86;

public static partial class Avx512F
{
    // Existing API
    public static Vector512<float> Add(Vector512<float> left, Vector512<float> right);

    // New mask API
    public static Vector512<float> Add(Vector512<float> mergeValues, Vector512Mask<float> mergeMask, Vector512<float> left, Vector512<float> right);

    // Potentially handled by just the above overload where `mergeValues: Vector512<float>.Zero`
    public static Vector512<float> Add(Vector512Mask<float> zeroMask, Vector512<float> left, Vector512<float> right);

    public static partial class VL
    {
        // New mask API
        public static Vector128<float> Add(Vector128<float> mergeValues, Vector128Mask<float> mergeMask, Vector128<float> left, Vector128<float> right);
        public static Vector256<float> Add(Vector256<float> mergeValues, Vector256Mask<float> mergeMask, Vector256<float> left, Vector256<float> right);

        // Potentially handled by just the above overloads where `mergeValues` is zero
        public static Vector128<float> Add(Vector128Mask<float> zeroMask, Vector128<float> left, Vector128<float> right);
        public static Vector256<float> Add(Vector256Mask<float> zeroMask, Vector256<float> left, Vector256<float> right);
    }
}
```
## Pattern Recognition
Rather than exposing these overloads that take `VectorMask<T>` and allowing users to explicitly utilize masking, we will instead recognize a few key patterns and transform them in the JIT.
We would also have had some intrinsics, such as `public static Vector512Mask<float> CompareEqual(Vector512<float> left, Vector512<float> right)`, which produce a mask, as well as various other ways to produce one. Developers would then have been able to consume this by passing the mask down to the API. For example, in the following we find all additions involving `NaN` and ensure those elements become `0` in the result.
```csharp
Vector512Mask<float> nanMask = Avx512F.CompareNotEqual(left, left) | Avx512F.CompareNotEqual(right, right);
return Avx512F.Add(Vector512<float>.Zero, ~nanMask, left, right);
```
If a user wanted to do that today where masking doesn't exist, they'd actually do a functionally similar thing:
```csharp
Vector256<float> nanMask = Avx.CompareNotEqual(left, left) | Avx.CompareNotEqual(right, right);
Vector256<float> result = Avx.Add(left, right);
return Vector256.ConditionalSelect(~nanMask, result, Vector256<float>.Zero);
```
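Branching on comparison results is a similar case that already exists in user code today. The sketch below is not from the proposal; it assumes .NET 7's cross-platform vector helpers, and the `AnyNaN` helper name is hypothetical. The whole-vector `!= Zero` test over OR'ed comparison results is the shape that a `kortest`-style lowering could light up on AVX-512 hardware:

```csharp
using System.Runtime.Intrinsics;

static class MaskBranchExample
{
    // Hypothetical helper: true if any element of `left` or `right` is NaN.
    // `Vector256.Equals(x, x)` is false only in NaN lanes, so OR'ing the two
    // inverted results is non-zero exactly when some lane held a NaN. On
    // AVX-512 the `!= Vector256<float>.Zero` test over the combined mask is
    // the kind of pattern that could lower to `kortest; jnz`.
    public static bool AnyNaN(Vector256<float> left, Vector256<float> right)
    {
        Vector256<float> nanMask = ~Vector256.Equals(left, left) | ~Vector256.Equals(right, right);
        return nanMask != Vector256<float>.Zero;
    }
}
```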
Thus, by recognizing these patterns we can light up existing code and avoid exploding the API surface, while also ensuring that the code users write is consistent regardless of whether the hardware supports native masking.
A sampling of the patterns we want to recognize includes, but is not limited to:
* `{k1} - ConditionalSelect(mask1, resultVector, mergeVector)`
* `{k1}{z} - ConditionalSelect(mask1, resultVector, Vector.Zero)`
* `kadd k1, k2 - mask1.ExtractMostSignificantBits() + mask2.ExtractMostSignificantBits()`
* `kand k1, k2 - mask1 & mask2`
* `kandn k1, k2 - ~mask1 & mask2`
* `kmov k1, k2 - mask1 = mask2`
* `kmov r32, k1 - mask1.ExtractMostSignificantBits()`
* `kmov k1, r32 - Vector.Create(...).ExtractMostSignificantBits()`
* `knot k1, k2 - ~mask1`
* `kor k1, k2 - mask1 | mask2`
* `kortest k1, k2; jz - (mask1 | mask2) == Vector.Zero`
* `kortest k1, k2; jnz - (mask1 | mask2) != Vector.Zero`
* `kortest k1, k2; jc - (mask1 | mask2) == Vector.AllBitsSet`
* `kortest k1, k2; jnc - (mask1 | mask2) != Vector.AllBitsSet`
* `kshiftl k1, k2, imm8 - mask1.ExtractMostSignificantBits() << amount`
* `kshiftr k1, k2, imm8 - mask1.ExtractMostSignificantBits() >> amount`
* `ktest k1, k2; jz - (mask1 & mask2) == Vector.Zero`
* `ktest k1, k2; jnz - (mask1 & mask2) != Vector.Zero`
* `ktest k1, k2; jc - (~mask1 & mask2) == Vector.Zero`
* `ktest k1, k2; jnc - (~mask1 & mask2) != Vector.Zero`
* `kunpck k1, k2, k3 - UnpackLow(mask1, mask2)`
* `kxnor k1, k2 - ~(mask1 ^ mask2)`
* `kxor k1, k2 - (mask1 ^ mask2)`
* `vpbroadcastm - Vector.Create(mask1)`
* `vpmovm2* - mask1.ExtractMostSignificantBits()`
* `vpmov*2m - vector1.ExtractMostSignificantBits()`
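As a concrete illustration of the first table entry, the merge-masking shape (`{k1}`) is just a `ConditionalSelect` whose fallback operand is the original value. This is a sketch assuming .NET 8's cross-platform `Vector512` helpers; the `AddWhereNegative` helper name is illustrative:

```csharp
using System.Runtime.Intrinsics;

static class MergeMaskExample
{
    // Illustrative helper: adds `addend` only into the lanes of `values`
    // that are negative; all other lanes keep their original value.
    public static Vector512<int> AddWhereNegative(Vector512<int> values, Vector512<int> addend)
    {
        // The comparison result is what would map to a k-register.
        Vector512<int> mask = Vector512.LessThan(values, Vector512<int>.Zero);
        Vector512<int> sum = values + addend;
        // ConditionalSelect(mask, result, mergeVector) is the merge-masking
        // shape from the table; on AVX-512 this could lower to a single
        // merge-masked add rather than an add followed by a blend.
        return Vector512.ConditionalSelect(mask, sum, values);
    }
}
```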
## API Proposal
```csharp
namespace System.Runtime.Intrinsics.X86;
public enum IntComparisonMode : byte
{
Equals = 0,
LessThan = 1,
LessThanOrEqual = 2,
False = 3,
NotEquals = 4,
GreaterThanOrEqual = 5,
GreaterThan = 6,
True = 7,
// Additional names for parity
//
    // FloatComparisonMode has similar names, but they are necessary there since
    // `!(x > y)` is not the same as `(x <= y)` due to the existence of NaN
//
// The architecture manual formally uses NotLessThan and NotLessThanOrEqual
NotGreaterThanOrEqual = 1,
NotGreaterThan = 2,
NotLessThan = 5,
NotLessThanOrEqual = 6,
}
public static partial class Avx512F
{
    public static Vector512<double> BlendVariable(Vector512<double> left, Vector512<double> right, Vector512<double> mask);
    public static Vector512<int>    BlendVariable(Vector512<int> left, Vector512<int> right, Vector512<int> mask);
    public static Vector512<long>   BlendVariable(Vector512<long> left, Vector512<long> right, Vector512<long> mask);
    public static Vector512<float>  BlendVariable(Vector512<float> left, Vector512<float> right, Vector512<float> mask);
    public static Vector512<uint>   BlendVariable(Vector512<uint> left, Vector512<uint> right, Vector512<uint> mask);
    public static Vector512<ulong>  BlendVariable(Vector512<ulong> left, Vector512<ulong> right, Vector512<ulong> mask);

    public static Vector512<double> Compare                     (Vector512<double> left, Vector512<double> right, [ConstantExpected(Max = FloatComparisonMode.UnorderedTrueSignaling)] FloatComparisonMode mode);
    public static Vector512<double> CompareEqual                (Vector512<double> left, Vector512<double> right);
    public static Vector512<double> CompareGreaterThan          (Vector512<double> left, Vector512<double> right);
    public static Vector512<double> CompareGreaterThanOrEqual   (Vector512<double> left, Vector512<double> right);
    public static Vector512<double> CompareLessThan             (Vector512<double> left, Vector512<double> right);
    public static Vector512<double> CompareLessThanOrEqual      (Vector512<double> left, Vector512<double> right);
    public static Vector512<double> CompareNotEqual             (Vector512<double> left, Vector512<double> right);
    public static Vector512<double> CompareNotGreaterThan       (Vector512<double> left, Vector512<double> right);
    public static Vector512<double> CompareNotGreaterThanOrEqual(Vector512<double> left, Vector512<double> right);
    public static Vector512<double> CompareNotLessThan          (Vector512<double> left, Vector512<double> right);
    public static Vector512<double> CompareNotLessThanOrEqual   (Vector512<double> left, Vector512<double> right);
    public static Vector512<double> CompareOrdered              (Vector512<double> left, Vector512<double> right);
    public static Vector512<double> CompareUnordered            (Vector512<double> left, Vector512<double> right);

    public static Vector512<float> Compare                     (Vector512<float> left, Vector512<float> right, [ConstantExpected(Max = FloatComparisonMode.UnorderedTrueSignaling)] FloatComparisonMode mode);
    public static Vector512<float> CompareEqual                (Vector512<float> left, Vector512<float> right);
    public static Vector512<float> CompareGreaterThan          (Vector512<float> left, Vector512<float> right);
    public static Vector512<float> CompareGreaterThanOrEqual   (Vector512<float> left, Vector512<float> right);
    public static Vector512<float> CompareLessThan             (Vector512<float> left, Vector512<float> right);
    public static Vector512<float> CompareLessThanOrEqual      (Vector512<float> left, Vector512<float> right);
    public static Vector512<float> CompareNotEqual             (Vector512<float> left, Vector512<float> right);
    public static Vector512<float> CompareNotGreaterThan       (Vector512<float> left, Vector512<float> right);
    public static Vector512<float> CompareNotGreaterThanOrEqual(Vector512<float> left, Vector512<float> right);
    public static Vector512<float> CompareNotLessThan          (Vector512<float> left, Vector512<float> right);
    public static Vector512<float> CompareNotLessThanOrEqual   (Vector512<float> left, Vector512<float> right);
    public static Vector512<float> CompareOrdered              (Vector512<float> left, Vector512<float> right);
    public static Vector512<float> CompareUnordered            (Vector512<float> left, Vector512<float> right);

    public static Vector512<int> Compare                  (Vector512<int> left, Vector512<int> right, [ConstantExpected(Max = IntComparisonMode.True)] IntComparisonMode mode);
    public static Vector512<int> CompareEqual             (Vector512<int> left, Vector512<int> right);
    public static Vector512<int> CompareGreaterThan       (Vector512<int> left, Vector512<int> right);
    public static Vector512<int> CompareGreaterThanOrEqual(Vector512<int> left, Vector512<int> right);
    public static Vector512<int> CompareLessThan          (Vector512<int> left, Vector512<int> right);
    public static Vector512<int> CompareLessThanOrEqual   (Vector512<int> left, Vector512<int> right);
    public static Vector512<int> CompareNotEqual          (Vector512<int> left, Vector512<int> right);

    public static Vector512<long> Compare                  (Vector512<long> left, Vector512<long> right, [ConstantExpected(Max = IntComparisonMode.True)] IntComparisonMode mode);
    public static Vector512<long> CompareEqual             (Vector512<long> left, Vector512<long> right);
    public static Vector512<long> CompareGreaterThan       (Vector512<long> left, Vector512<long> right);
    public static Vector512<long> CompareGreaterThanOrEqual(Vector512<long> left, Vector512<long> right);
    public static Vector512<long> CompareLessThan          (Vector512<long> left, Vector512<long> right);
    public static Vector512<long> CompareLessThanOrEqual   (Vector512<long> left, Vector512<long> right);
    public static Vector512<long> CompareNotEqual          (Vector512<long> left, Vector512<long> right);

    public static Vector512<uint> Compare                  (Vector512<uint> left, Vector512<uint> right, [ConstantExpected(Max = IntComparisonMode.True)] IntComparisonMode mode);
    public static Vector512<uint> CompareEqual             (Vector512<uint> left, Vector512<uint> right);
    public static Vector512<uint> CompareGreaterThan       (Vector512<uint> left, Vector512<uint> right);
    public static Vector512<uint> CompareGreaterThanOrEqual(Vector512<uint> left, Vector512<uint> right);
    public static Vector512<uint> CompareLessThan          (Vector512<uint> left, Vector512<uint> right);
    public static Vector512<uint> CompareLessThanOrEqual   (Vector512<uint> left, Vector512<uint> right);
    public static Vector512<uint> CompareNotEqual          (Vector512<uint> left, Vector512<uint> right);

    public static Vector512<ulong> Compare                  (Vector512<ulong> left, Vector512<ulong> right, [ConstantExpected(Max = IntComparisonMode.True)] IntComparisonMode mode);
    public static Vector512<ulong> CompareEqual             (Vector512<ulong> left, Vector512<ulong> right);
    public static Vector512<ulong> CompareGreaterThan       (Vector512<ulong> left, Vector512<ulong> right);
    public static Vector512<ulong> CompareGreaterThanOrEqual(Vector512<ulong> left, Vector512<ulong> right);
    public static Vector512<ulong> CompareLessThan          (Vector512<ulong> left, Vector512<ulong> right);
    public static Vector512<ulong> CompareLessThanOrEqual   (Vector512<ulong> left, Vector512<ulong> right);
    public static Vector512<ulong> CompareNotEqual          (Vector512<ulong> left, Vector512<ulong> right);

    public static Vector512<double> Compress(Vector512<double> value, Vector512<double> mask);
    public static Vector512<int>    Compress(Vector512<int> value, Vector512<int> mask);
    public static Vector512<long>   Compress(Vector512<long> value, Vector512<long> mask);
    public static Vector512<float>  Compress(Vector512<float> value, Vector512<float> mask);
    public static Vector512<uint>   Compress(Vector512<uint> value, Vector512<uint> mask);
    public static Vector512<ulong>  Compress(Vector512<ulong> value, Vector512<ulong> mask);

    public static Vector512<double> Expand(Vector512<double> value, Vector512<double> mask);
    public static Vector512<int>    Expand(Vector512<int> value, Vector512<int> mask);
    // ...
}
```