dotnet / runtime

.NET is a cross-platform runtime for cloud, mobile, desktop, and IoT apps.
https://docs.microsoft.com/dotnet/core/
MIT License

[API Proposal]: AVX-512 Masking helper functions #96986

Closed: MineCake147E closed this issue 8 months ago

MineCake147E commented 8 months ago

Background and motivation

#87097 aims to implement some AVX-512 mask-related intrinsics.

It was intended as a way to handle mask-related operations while retaining code compatibility with older hardware that supports neither AVX-512 nor SVE2.

It makes sense to use the existing Vector*<T> types to define a register (or variable) that behaves like a mask, as we usually do with AVX2, which has no physical mask registers at all: AVX2 treats a vector register as a mask by looking at the most significant bit of each element. However, the ideas in #87097 rely heavily on RyuJIT pattern matching, which I don't consider reliable, predictable, or optimal at this point.

Even worse, Vector*.ConditionalSelect<T>(condition, left, right) always emits something equivalent to vpternlogd condition, left, right, 0xca, no matter how mask-like the condition is, because its documentation defines it as bitwise:

Conditionally selects a value from two vectors on a bitwise basis.

When the condition is in a mask register, it emits vpmovm2* just to copy the value into a vector register, which costs 3 precious clock cycles on Intel CPUs. If the condition is already in a mask register, the whole operation is therefore slower than simply using a better-suited instruction such as vpblendm*, which is available on AVX-512 hardware.

Conversely, when the condition ends up in a vector register instead (for example, when it emits multiple mask-writing instructions such as vpcmpeqb), bringing that vector register back into a mask register requires vpmov*2m, which also takes 3 clock cycles on Intel CPUs. In that case, vpternlogd condition, left, right, 0xca is actually cheaper than reloading the mask register. Because Vector*.ConditionalSelect<T>(condition, left, right) is documented as bitwise, it always takes the vpternlogd route, so a dedicated masked blending function is needed regardless.
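The difference between the two semantics can be seen in a small scalar C# model. It is only an illustration, not what RyuJIT emits: the hypothetical `BitwiseSelect` follows the documented bitwise behavior of `ConditionalSelect` (the same per-bit function as vpternlogd with imm8 0xca), while the hypothetical `BlendByMsb` only looks at the most significant bit of each mask element, as AVX2 vpblendvb does. With a condition whose elements are all-ones or all-zeros the two agree; with a condition that only has the MSB set they diverge, which is why a bitwise select cannot simply be lowered to vpblendm* unless the JIT can prove the condition is mask-like.

```csharp
using System;

// Bitwise select: each result bit comes from `left` where the same bit of `condition`
// is 1 and from `right` where it is 0. This is the documented semantics of
// Vector*.ConditionalSelect and the per-bit function of vpternlogd with imm8 = 0xca.
static byte[] BitwiseSelect(byte[] condition, byte[] left, byte[] right)
{
    var result = new byte[condition.Length];
    for (int i = 0; i < result.Length; i++)
        result[i] = (byte)((left[i] & condition[i]) | (right[i] & ~condition[i]));
    return result;
}

// Element-wise blend: only the most significant bit of each mask element matters,
// as with AVX2 vpblendvb; the result is right[i] if that bit is set, else left[i].
static byte[] BlendByMsb(byte[] left, byte[] right, byte[] mask)
{
    var result = new byte[mask.Length];
    for (int i = 0; i < result.Length; i++)
        result[i] = (mask[i] & 0x80) != 0 ? right[i] : left[i];
    return result;
}

byte[] leftValues = { 0x11, 0x22, 0x33, 0x44 };
byte[] rightValues = { 0xAA, 0xBB, 0xCC, 0xDD };
byte[] fullMask = { 0xFF, 0x00, 0xFF, 0x00 }; // every bit equals the MSB
byte[] msbMask = { 0x80, 0x00, 0x80, 0x00 };  // only the MSB is set

// Argument order: ConditionalSelect takes the "true" value first, BlendVariable takes it last.
// Same result for a mask-like condition...
Console.WriteLine(BitConverter.ToString(BitwiseSelect(fullMask, rightValues, leftValues))); // AA-22-CC-44
Console.WriteLine(BitConverter.ToString(BlendByMsb(leftValues, rightValues, fullMask)));    // AA-22-CC-44
// ...but different results once only the MSB is set.
Console.WriteLine(BitConverter.ToString(BitwiseSelect(msbMask, rightValues, leftValues)));  // 91-22-B3-44
Console.WriteLine(BitConverter.ToString(BlendByMsb(leftValues, rightValues, msbMask)));     // AA-22-CC-44
```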

Also, as I noted in #92261, pattern matching that actually changes the behavior of functions drastically hurts the readability of the code, no matter how much effort is invested to overcome it.

Hardware intrinsics used to be 'specific' in the .NET Core 3.1 days: every single function specified, at least implicitly, the instructions to be executed. .NET 5 didn't change that; at most it made the code more readable and portable, and .NET 6 and .NET 7 likewise only made it more readable and portable. .NET 8 almost broke this consistency in RC 1, as I wrote in #92261, but the decision was reverted in RC 2.

Hardware intrinsics should stay consistent forever, even as they adopt many new instructions and new optimization techniques. That is why I need these new APIs to be added in .NET 9.

API Proposal

Cross-platform APIs

BlendVariable is designed so that its parameter order matches Avx512BW.BlendVariable.
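To pin down the parameter orders of the three selection helpers, here is a scalar C# model. It is a sketch under two assumptions the proposal does not state: plain arrays stand in for `Vector128/256/512<T>`, and the mask is a `ulong` with one bit per element (bit i controls element i, as in an AVX-512 k register). In this model `BlendVariable(left, right, mask)` is just `MergeWith` with its first two arguments swapped.

```csharp
using System;

// Hypothetical scalar model: arrays stand in for Vector*<T>, and the mask is a ulong
// with one bit per element (bit i controls element i), like an AVX-512 k register.
static bool IsSet(ulong mask, int i) => ((mask >> i) & 1UL) != 0;

// MergeWith(newValue, destinationOperand, mask): mask ? newValue : destinationOperand
static T[] MergeWith<T>(T[] newValue, T[] destinationOperand, ulong mask)
{
    var result = new T[newValue.Length];
    for (int i = 0; i < result.Length; i++)
        result[i] = IsSet(mask, i) ? newValue[i] : destinationOperand[i];
    return result;
}

// ZeroIfNot(newValue, mask): mask ? newValue : 0 (default(T) for numeric types)
static T[] ZeroIfNot<T>(T[] newValue, ulong mask)
{
    var result = new T[newValue.Length]; // every element starts as default(T)
    for (int i = 0; i < result.Length; i++)
        if (IsSet(mask, i)) result[i] = newValue[i];
    return result;
}

// BlendVariable(left, right, mask): mask ? right : left, the same parameter order
// as Avx512BW.BlendVariable.
static T[] BlendVariable<T>(T[] left, T[] right, ulong mask) => MergeWith(right, left, mask);

int[] a = { 1, 2, 3, 4 };
int[] b = { 10, 20, 30, 40 };
ulong m = 0b0101; // elements 0 and 2 are selected

Console.WriteLine(string.Join(",", MergeWith(a, b, m)));     // 1,20,3,40
Console.WriteLine(string.Join(",", ZeroIfNot(a, m)));        // 1,0,3,0
Console.WriteLine(string.Join(",", BlendVariable(a, b, m))); // 10,2,30,4
```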

namespace System.Runtime.Intrinsics
{
    // Lets an analyzer verify that the argument is already in a mask register.
    [AttributeUsage(AttributeTargets.Parameter, Inherited = false, AllowMultiple = false)]
    public sealed class MaskExpectedAttribute : Attribute
    {
    }

    // To be applied to the `condition` parameter of the ConditionalSelect methods.
    [AttributeUsage(AttributeTargets.Parameter, Inherited = false, AllowMultiple = false)]
    public sealed class MaskNotExpectedAttribute : Attribute
    {
    }

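    // Lets an analyzer assume the return value is stored in a mask register.
    // Method is also a valid target so that [Mask] can be written directly on a method declaration.
    [AttributeUsage(AttributeTargets.Method | AttributeTargets.ReturnValue, Inherited = false, AllowMultiple = false)]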
    public sealed class MaskAttribute : Attribute
    {
    }

    public static class Vector512
    {
        public static bool IsSupported { get; }

        /// <summary>
        /// Performs a masked merge operation.
        /// If the corresponding element of <paramref name="mask"/> is 1, the element is taken from <paramref name="newValue"/>; otherwise, it is taken from <paramref name="destinationOperand"/>.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="newValue">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="destinationOperand">The values to be selected when the corresponding element in <paramref name="mask"/> is false.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="newValue"/> : <paramref name="destinationOperand"/></returns>
        public static Vector512<T> MergeWith<T>(this Vector512<T> newValue, Vector512<T> destinationOperand, [MaskExpected] Vector512<T> mask);

        /// <summary>
        /// Performs a masked zeroing operation.
        /// If the corresponding element of <paramref name="mask"/> is 1, the element is taken from <paramref name="newValue"/>; otherwise, it is 0.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="newValue">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="newValue"/> : 0</returns>
        public static Vector512<T> ZeroIfNot<T>(this Vector512<T> newValue, [MaskExpected] Vector512<T> mask);

        /// <summary>
        /// Conditionally selects a value from two vectors on an element-wise basis.
        /// If the corresponding element of <paramref name="mask"/> is 1, the element is taken from <paramref name="right"/>; otherwise, it is taken from <paramref name="left"/>.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="left">The values to be selected when the corresponding element in <paramref name="mask"/> is false.</param>
        /// <param name="right">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="right"/> : <paramref name="left"/></returns>
        public static Vector512<T> BlendVariable<T>(Vector512<T> left, Vector512<T> right, [MaskExpected] Vector512<T> mask);

        /// <summary>
        /// Creates a mask consisting of the most significant bit of each element in a vector.
        /// This function acts as a hint that lets the compiler assume the returned value is stored in a mask register.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the <paramref name="vector"/>.</typeparam>
        /// <param name="vector">The vector whose elements should have their most significant bit extracted and stored in a mask register.</param>
        /// <returns>The packed most significant bits extracted from the elements in <paramref name="vector"/>.</returns>
        [Mask]
        public static Vector512<T> MaskifyMostSignificantBits<T>(this Vector512<T> vector);

        /// <summary>
        /// Creates a vector which represents a mask.
        /// This function acts as a hint that lets the compiler assume the returned value is stored in a vector register.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="mask">The mask whose bits should be expanded into whole elements and stored in a vector register.</param>
        /// <returns>The expanded <paramref name="mask"/> as a vector.</returns>
        public static Vector512<T> VectorifyMask<T>([MaskExpected] this Vector512<T> mask);

        [Mask]
        public static Vector512<float> CreateMaskSingle(short value);
        [Mask]
        public static Vector512<float> CreateMaskSingle(ushort value);
        [Mask]
        public static Vector512<double> CreateMaskDouble(sbyte value);
        [Mask]
        public static Vector512<double> CreateMaskDouble(byte value);
        [Mask]
        public static Vector512<Half> CreateMaskHalf(int value);
        [Mask]
        public static Vector512<Half> CreateMaskHalf(uint value);
        [Mask]
        public static Vector512<byte> CreateMaskByte(long value);
        [Mask]
        public static Vector512<byte> CreateMaskByte(ulong value);
        [Mask]
        public static Vector512<sbyte> CreateMaskSByte(long value);
        [Mask]
        public static Vector512<sbyte> CreateMaskSByte(ulong value);
        [Mask]
        public static Vector512<ushort> CreateMaskUInt16(int value);
        [Mask]
        public static Vector512<ushort> CreateMaskUInt16(uint value);
        [Mask]
        public static Vector512<short> CreateMaskInt16(int value);
        [Mask]
        public static Vector512<short> CreateMaskInt16(uint value);
        [Mask]
        public static Vector512<uint> CreateMaskUInt32(short value);
        [Mask]
        public static Vector512<uint> CreateMaskUInt32(ushort value);
        [Mask]
        public static Vector512<int> CreateMaskInt32(short value);
        [Mask]
        public static Vector512<int> CreateMaskInt32(ushort value);
        [Mask]
        public static Vector512<ulong> CreateMaskUInt64(sbyte value);
        [Mask]
        public static Vector512<ulong> CreateMaskUInt64(byte value);
        [Mask]
        public static Vector512<long> CreateMaskInt64(sbyte value);
        [Mask]
        public static Vector512<long> CreateMaskInt64(byte value);

        /// <summary>
        /// Shifts the value of a mask left by the specified amount.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="left">The mask that represents the bits to be shifted.</param>
        /// <param name="right">The number of bits to shift.</param>
        /// <returns>A mask that represents the shifted value.</returns>
        [Mask]
        public static Vector512<T> MaskShiftLeft<T>([MaskExpected] Vector512<T> left, [ConstantExpected] byte right);

        /// <summary>
        /// Shifts the value of a mask right by the specified amount.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="left">The mask that represents the bits to be shifted.</param>
        /// <param name="right">The number of bits to shift.</param>
        /// <returns>A mask that represents the shifted value.</returns>
        [Mask]
        public static Vector512<T> MaskShiftRightLogical<T>([MaskExpected] Vector512<T> left, [ConstantExpected] byte right);
    }

    public static class Vector256
    {
        public static bool IsSupported { get; }

        /// <summary>
        /// Performs a masked merge operation.
        /// If the corresponding element of <paramref name="mask"/> is 1, the element is taken from <paramref name="newValue"/>; otherwise, it is taken from <paramref name="destinationOperand"/>.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="newValue">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="destinationOperand">The values to be selected when the corresponding element in <paramref name="mask"/> is false.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="newValue"/> : <paramref name="destinationOperand"/></returns>
        public static Vector256<T> MergeWith<T>(this Vector256<T> newValue, Vector256<T> destinationOperand, [MaskExpected] Vector256<T> mask);

        /// <summary>
        /// Performs a masked zeroing operation.
        /// If the corresponding element of <paramref name="mask"/> is 1, the element is taken from <paramref name="newValue"/>; otherwise, it is 0.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="newValue">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="newValue"/> : 0</returns>
        public static Vector256<T> ZeroIfNot<T>(this Vector256<T> newValue, [MaskExpected] Vector256<T> mask);

        /// <summary>
        /// Conditionally selects a value from two vectors on an element-wise basis.
        /// If the corresponding element of <paramref name="mask"/> is 1, the element is taken from <paramref name="right"/>; otherwise, it is taken from <paramref name="left"/>.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="left">The values to be selected when the corresponding element in <paramref name="mask"/> is false.</param>
        /// <param name="right">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="right"/> : <paramref name="left"/></returns>
        public static Vector256<T> BlendVariable<T>(Vector256<T> left, Vector256<T> right, [MaskExpected] Vector256<T> mask);

        /// <summary>
        /// Creates a mask consisting of the most significant bit of each element in a vector.
        /// This function acts as a hint that lets the compiler assume the returned value is stored in a mask register.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the <paramref name="vector"/>.</typeparam>
        /// <param name="vector">The vector whose elements should have their most significant bit extracted and stored in a mask register.</param>
        /// <returns>The packed most significant bits extracted from the elements in <paramref name="vector"/>.</returns>
        [Mask]
        public static Vector256<T> MaskifyMostSignificantBits<T>(this Vector256<T> vector);

        /// <summary>
        /// Creates a vector which represents a mask.
        /// This function acts as a hint that lets the compiler assume the returned value is stored in a vector register.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="mask">The mask whose bits should be expanded into whole elements and stored in a vector register.</param>
        /// <returns>The expanded <paramref name="mask"/> as a vector.</returns>
        public static Vector256<T> VectorifyMask<T>([MaskExpected] this Vector256<T> mask);

        [Mask]
        public static Vector256<float> CreateMaskSingle(sbyte value);
        [Mask]
        public static Vector256<float> CreateMaskSingle(byte value);
        [Mask]
        public static Vector256<double> CreateMaskDouble(sbyte value);
        [Mask]
        public static Vector256<double> CreateMaskDouble(byte value);
        [Mask]
        public static Vector256<Half> CreateMaskHalf(short value);
        [Mask]
        public static Vector256<Half> CreateMaskHalf(ushort value);
        [Mask]
        public static Vector256<byte> CreateMaskByte(int value);
        [Mask]
        public static Vector256<byte> CreateMaskByte(uint value);
        [Mask]
        public static Vector256<sbyte> CreateMaskSByte(int value);
        [Mask]
        public static Vector256<sbyte> CreateMaskSByte(uint value);
        [Mask]
        public static Vector256<ushort> CreateMaskUInt16(short value);
        [Mask]
        public static Vector256<ushort> CreateMaskUInt16(ushort value);
        [Mask]
        public static Vector256<short> CreateMaskInt16(short value);
        [Mask]
        public static Vector256<short> CreateMaskInt16(ushort value);
        [Mask]
        public static Vector256<uint> CreateMaskUInt32(sbyte value);
        [Mask]
        public static Vector256<uint> CreateMaskUInt32(byte value);
        [Mask]
        public static Vector256<int> CreateMaskInt32(sbyte value);
        [Mask]
        public static Vector256<int> CreateMaskInt32(byte value);
        [Mask]
        public static Vector256<ulong> CreateMaskUInt64(sbyte value);
        [Mask]
        public static Vector256<ulong> CreateMaskUInt64(byte value);
        [Mask]
        public static Vector256<long> CreateMaskInt64(sbyte value);
        [Mask]
        public static Vector256<long> CreateMaskInt64(byte value);

        /// <summary>
        /// Shifts the value of a mask left by the specified amount.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="left">The mask that represents the bits to be shifted.</param>
        /// <param name="right">The number of bits to shift.</param>
        /// <returns>A mask that represents the shifted value.</returns>
        [Mask]
        public static Vector256<T> MaskShiftLeft<T>([MaskExpected] Vector256<T> left, [ConstantExpected] byte right);

        /// <summary>
        /// Shifts the value of a mask right by the specified amount.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="left">The mask that represents the bits to be shifted.</param>
        /// <param name="right">The number of bits to shift.</param>
        /// <returns>A mask that represents the shifted value.</returns>
        [Mask]
        public static Vector256<T> MaskShiftRightLogical<T>([MaskExpected] Vector256<T> left, [ConstantExpected] byte right);
    }

    public static class Vector128
    {
        public static bool IsSupported { get; }

        /// <summary>
        /// Performs a masked merge operation.
        /// If the corresponding element of <paramref name="mask"/> is 1, the element is taken from <paramref name="newValue"/>; otherwise, it is taken from <paramref name="destinationOperand"/>.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="newValue">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="destinationOperand">The values to be selected when the corresponding element in <paramref name="mask"/> is false.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="newValue"/> : <paramref name="destinationOperand"/></returns>
        public static Vector128<T> MergeWith<T>(this Vector128<T> newValue, Vector128<T> destinationOperand, [MaskExpected] Vector128<T> mask);

        /// <summary>
        /// Performs a masked zeroing operation.
        /// If the corresponding element of <paramref name="mask"/> is 1, the element is taken from <paramref name="newValue"/>; otherwise, it is 0.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="newValue">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="newValue"/> : 0</returns>
        public static Vector128<T> ZeroIfNot<T>(this Vector128<T> newValue, [MaskExpected] Vector128<T> mask);

        /// <summary>
        /// Conditionally selects a value from two vectors on an element-wise basis.
        /// If the corresponding element of <paramref name="mask"/> is 1, the element is taken from <paramref name="right"/>; otherwise, it is taken from <paramref name="left"/>.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="left">The values to be selected when the corresponding element in <paramref name="mask"/> is false.</param>
        /// <param name="right">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="right"/> : <paramref name="left"/></returns>
        public static Vector128<T> BlendVariable<T>(Vector128<T> left, Vector128<T> right, [MaskExpected] Vector128<T> mask);

        /// <summary>
        /// Creates a mask consisting of the most significant bit of each element in a vector.
        /// This function acts as a hint that lets the compiler assume the returned value is stored in a mask register.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the <paramref name="vector"/>.</typeparam>
        /// <param name="vector">The vector whose elements should have their most significant bit extracted and stored in a mask register.</param>
        /// <returns>The packed most significant bits extracted from the elements in <paramref name="vector"/>.</returns>
        [Mask]
        public static Vector128<T> MaskifyMostSignificantBits<T>(this Vector128<T> vector);

        /// <summary>
        /// Creates a vector which represents a mask.
        /// This function acts as a hint that lets the compiler assume the returned value is stored in a vector register.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="mask">The mask whose bits should be expanded into whole elements and stored in a vector register.</param>
        /// <returns>The expanded <paramref name="mask"/> as a vector.</returns>
        public static Vector128<T> VectorifyMask<T>([MaskExpected] this Vector128<T> mask);

        [Mask]
        public static Vector128<float> CreateMaskSingle(sbyte value);
        [Mask]
        public static Vector128<float> CreateMaskSingle(byte value);
        [Mask]
        public static Vector128<double> CreateMaskDouble(sbyte value);
        [Mask]
        public static Vector128<double> CreateMaskDouble(byte value);
        [Mask]
        public static Vector128<Half> CreateMaskHalf(sbyte value);
        [Mask]
        public static Vector128<Half> CreateMaskHalf(byte value);
        [Mask]
        public static Vector128<byte> CreateMaskByte(short value);
        [Mask]
        public static Vector128<byte> CreateMaskByte(ushort value);
        [Mask]
        public static Vector128<sbyte> CreateMaskSByte(short value);
        [Mask]
        public static Vector128<sbyte> CreateMaskSByte(ushort value);
        [Mask]
        public static Vector128<ushort> CreateMaskUInt16(sbyte value);
        [Mask]
        public static Vector128<ushort> CreateMaskUInt16(byte value);
        [Mask]
        public static Vector128<short> CreateMaskInt16(sbyte value);
        [Mask]
        public static Vector128<short> CreateMaskInt16(byte value);
        [Mask]
        public static Vector128<uint> CreateMaskUInt32(sbyte value);
        [Mask]
        public static Vector128<uint> CreateMaskUInt32(byte value);
        [Mask]
        public static Vector128<int> CreateMaskInt32(sbyte value);
        [Mask]
        public static Vector128<int> CreateMaskInt32(byte value);
        [Mask]
        public static Vector128<ulong> CreateMaskUInt64(sbyte value);
        [Mask]
        public static Vector128<ulong> CreateMaskUInt64(byte value);
        [Mask]
        public static Vector128<long> CreateMaskInt64(sbyte value);
        [Mask]
        public static Vector128<long> CreateMaskInt64(byte value);

        /// <summary>
        /// Shifts the value of a mask left by the specified amount.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="left">The mask that represents the bits to be shifted.</param>
        /// <param name="right">The number of bits to shift.</param>
        /// <returns>A mask that represents the shifted value.</returns>
        [Mask]
        public static Vector128<T> MaskShiftLeft<T>([MaskExpected] Vector128<T> left, [ConstantExpected] byte right);

        /// <summary>
        /// Shifts the value of a mask right by the specified amount.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="left">The mask that represents the bits to be shifted.</param>
        /// <param name="right">The number of bits to shift.</param>
        /// <returns>A mask that represents the shifted value.</returns>
        [Mask]
        public static Vector128<T> MaskShiftRightLogical<T>([MaskExpected] Vector128<T> left, [ConstantExpected] byte right);
    }
}
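The mask-construction and mask-shift helpers can be modeled in the same scalar way. This is a hypothetical sketch, not the proposed implementation: the mask of a vector with `count` elements is assumed to be a `ulong` whose bit i controls element i, so `CreateMaskInt32(ushort)` for `Vector512<int>` just keeps the low 16 bits of its argument. Expanding each bit of a mask into an all-ones or all-zeros element in `VectorifyMask` is an assumption modeled on vpmovm2d, and the shift helpers discard bits that move past the element count, as kshiftlw/kshiftrw do for 16-bit masks.

```csharp
using System;

// Hypothetical scalar model: the mask of a vector with `count` elements is a ulong
// whose bit i controls element i, like an AVX-512 k register.
static ulong Truncate(ulong bits, int count) =>
    count >= 64 ? bits : bits & ((1UL << count) - 1);

// MaskifyMostSignificantBits for int elements: pack the sign bits (cf. vpmovd2m).
static ulong MaskifyMostSignificantBits(int[] vector)
{
    ulong mask = 0;
    for (int i = 0; i < vector.Length; i++)
        if (vector[i] < 0) mask |= 1UL << i;
    return mask;
}

// VectorifyMask for int elements: set bit -> -1 (all bits set), clear bit -> 0 (cf. vpmovm2d).
static int[] VectorifyMask(ulong mask, int count)
{
    var vector = new int[count];
    for (int i = 0; i < count; i++)
        vector[i] = ((mask >> i) & 1UL) != 0 ? -1 : 0;
    return vector;
}

// Mask shifts: bits shifted past the element range are discarded, and a shift amount
// >= count yields an empty mask (cf. kshiftlw / kshiftrw for 16-bit masks).
static ulong MaskShiftLeft(ulong mask, byte amount, int count) =>
    amount >= count ? 0 : Truncate(mask << amount, count);

static ulong MaskShiftRightLogical(ulong mask, byte amount, int count) =>
    amount >= count ? 0 : Truncate(mask, count) >> amount;

// CreateMaskInt32(ushort) for Vector512<int>: 16 elements, one bit each.
ulong m = Truncate((ushort)0b1000_0000_0000_0101, 16);

Console.WriteLine(MaskifyMostSignificantBits(new[] { -5, 3, -1, 0 })); // 5
Console.WriteLine(string.Join(",", VectorifyMask(0b0101, 4)));        // -1,0,-1,0
Console.WriteLine(MaskShiftLeft(m, 1, 16));                            // 10
Console.WriteLine(MaskShiftRightLogical(m, 2, 16));                    // 8193
```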

Additional AVX-512 Intrinsics

Some bit operations are omitted in this list.
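The mask intrinsics in this listing map onto single k-register instructions. As a reading aid, here is a scalar C# sketch of those instructions, with plain integers standing in for mask registers. The helper names (including the width-suffixed `MaskXnor16`, `MaskAdd8`, and so on) are hypothetical, and treating `left` as the high half in `MaskUnpack16` follows the operand order of KUNPCKBW; the proposal itself does not specify which operand ends up in which half.

```csharp
using System;

// Hypothetical scalar model of the k-register instructions behind the intrinsics
// in this listing; plain integers stand in for mask registers.

// KUNPCKBW (AVX512F): dest[7:0] = src2[7:0], dest[15:8] = src1[7:0].
// Assumption: `left` is src1 (high half) and `right` is src2 (low half).
static ushort MaskUnpack16(byte left, byte right) => (ushort)((left << 8) | right);

// KUNPCKWD (AVX512BW): the same for two 16-bit masks and a 32-bit result.
static uint MaskUnpack32(ushort left, ushort right) => ((uint)left << 16) | right;

// KXNORW (AVX512F) / KXNORB (AVX512DQ): bitwise XNOR at the mask width.
static ushort MaskXnor16(ushort left, ushort right) => (ushort)~(left ^ right);
static byte MaskXnor8(byte left, byte right) => (byte)~(left ^ right);

// KADDW / KADDB (AVX512DQ): integer addition that wraps at the mask width.
static ushort MaskAdd16(ushort left, ushort right) => (ushort)(left + right);
static byte MaskAdd8(byte left, byte right) => (byte)(left + right);

Console.WriteLine($"{MaskUnpack16(0x0F, 0xF0):X4}");     // 0FF0
Console.WriteLine($"{MaskUnpack32(0x1234, 0xABCD):X8}"); // 1234ABCD
Console.WriteLine($"{MaskXnor16(0x00FF, 0x0F0F):X4}");   // F00F
Console.WriteLine(MaskAdd16(0xFFFF, 0x0001));            // 0
```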

namespace System.Runtime.Intrinsics.X86
{
    public abstract class Avx512F : Avx2
    {
        [Mask]
        public static Vector512<float> MaskUnpack16([MaskExpected] Vector256<float> left, [MaskExpected] Vector256<float> right);
        [Mask]
        public static Vector512<float> MaskUnpack16([MaskExpected] Vector512<double> left, [MaskExpected] Vector512<double> right);
        [Mask]
        public static Vector512<uint> MaskUnpack16([MaskExpected] Vector512<ulong> left, [MaskExpected] Vector512<ulong> right);
        [Mask]
        public static Vector512<int> MaskUnpack16([MaskExpected] Vector512<long> left, [MaskExpected] Vector512<long> right);
        [Mask]
        public static Vector512<uint> MaskUnpack16([MaskExpected] Vector256<uint> left, [MaskExpected] Vector256<uint> right);
        [Mask]
        public static Vector512<int> MaskUnpack16([MaskExpected] Vector256<int> left, [MaskExpected] Vector256<int> right);
        [Mask]
        public static Vector256<Half> MaskUnpack16([MaskExpected] Vector256<float> left, [MaskExpected] Vector256<float> right);
        [Mask]
        public static Vector256<ushort> MaskUnpack16([MaskExpected] Vector256<uint> left, [MaskExpected] Vector256<uint> right);
        [Mask]
        public static Vector256<short> MaskUnpack16([MaskExpected] Vector256<int> left, [MaskExpected] Vector256<int> right);
        [Mask]
        public static Vector256<Half> MaskUnpack16([MaskExpected] Vector128<Half> left, [MaskExpected] Vector128<Half> right);
        [Mask]
        public static Vector256<ushort> MaskUnpack16([MaskExpected] Vector128<ushort> left, [MaskExpected] Vector128<ushort> right);
        [Mask]
        public static Vector256<short> MaskUnpack16([MaskExpected] Vector128<short> left, [MaskExpected] Vector128<short> right);
        [Mask]
        public static Vector128<byte> MaskUnpack16([MaskExpected] Vector128<ushort> left, [MaskExpected] Vector128<ushort> right);
        [Mask]
        public static Vector128<sbyte> MaskUnpack16([MaskExpected] Vector128<short> left, [MaskExpected] Vector128<short> right);

        [Mask]
        public static Vector512<float> MaskXnor([MaskExpected] Vector512<float> left, [MaskExpected] Vector512<float> right);
        [Mask]
        public static Vector256<Half> MaskXnor([MaskExpected] Vector256<Half> left, [MaskExpected] Vector256<Half> right);
        [Mask]
        public static Vector128<byte> MaskXnor([MaskExpected] Vector128<byte> left, [MaskExpected] Vector128<byte> right);
        [Mask]
        public static Vector128<sbyte> MaskXnor([MaskExpected] Vector128<sbyte> left, [MaskExpected] Vector128<sbyte> right);
        [Mask]
        public static Vector256<ushort> MaskXnor([MaskExpected] Vector256<ushort> left, [MaskExpected] Vector256<ushort> right);
        [Mask]
        public static Vector256<short> MaskXnor([MaskExpected] Vector256<short> left, [MaskExpected] Vector256<short> right);
        [Mask]
        public static Vector512<uint> MaskXnor([MaskExpected] Vector512<uint> left, [MaskExpected] Vector512<uint> right);
        [Mask]
        public static Vector512<int> MaskXnor([MaskExpected] Vector512<int> left, [MaskExpected] Vector512<int> right);
    }

    public abstract class Avx512DQ : Avx512F
    {
        [Mask]
        public static Vector512<float> MaskAdd([MaskExpected] Vector512<float> left, [MaskExpected] Vector512<float> right);
        [Mask]
        public static Vector256<Half> MaskAdd([MaskExpected] Vector256<Half> left, [MaskExpected] Vector256<Half> right);
        [Mask]
        public static Vector128<byte> MaskAdd([MaskExpected] Vector128<byte> left, [MaskExpected] Vector128<byte> right);
        [Mask]
        public static Vector128<sbyte> MaskAdd([MaskExpected] Vector128<sbyte> left, [MaskExpected] Vector128<sbyte> right);
        [Mask]
        public static Vector256<ushort> MaskAdd([MaskExpected] Vector256<ushort> left, [MaskExpected] Vector256<ushort> right);
        [Mask]
        public static Vector256<short> MaskAdd([MaskExpected] Vector256<short> left, [MaskExpected] Vector256<short> right);
        [Mask]
        public static Vector512<uint> MaskAdd([MaskExpected] Vector512<uint> left, [MaskExpected] Vector512<uint> right);
        [Mask]
        public static Vector512<int> MaskAdd([MaskExpected] Vector512<int> left, [MaskExpected] Vector512<int> right);
        [Mask]
        public static Vector256<float> MaskAdd([MaskExpected] Vector256<float> left, [MaskExpected] Vector256<float> right);
        [Mask]
        public static Vector512<double> MaskAdd([MaskExpected] Vector512<double> left, [MaskExpected] Vector512<double> right);
        [Mask]
        public static Vector128<Half> MaskAdd([MaskExpected] Vector128<Half> left, [MaskExpected] Vector128<Half> right);
        [Mask]
        public static Vector128<ushort> MaskAdd([MaskExpected] Vector128<ushort> left, [MaskExpected] Vector128<ushort> right);
        [Mask]
        public static Vector128<short> MaskAdd([MaskExpected] Vector128<short> left, [MaskExpected] Vector128<short> right);
        [Mask]
        public static Vector256<uint> MaskAdd([MaskExpected] Vector256<uint> left, [MaskExpected] Vector256<uint> right);
        [Mask]
        public static Vector256<int> MaskAdd([MaskExpected] Vector256<int> left, [MaskExpected] Vector256<int> right);
        [Mask]
        public static Vector512<ulong> MaskAdd([MaskExpected] Vector512<ulong> left, [MaskExpected] Vector512<ulong> right);
        [Mask]
        public static Vector512<long> MaskAdd([MaskExpected] Vector512<long> left, [MaskExpected] Vector512<long> right);

        [Mask]
        public static Vector256<float> MaskXnor([MaskExpected] Vector256<float> left, [MaskExpected] Vector256<float> right);
        [Mask]
        public static Vector512<double> MaskXnor([MaskExpected] Vector512<double> left, [MaskExpected] Vector512<double> right);
        [Mask]
        public static Vector128<Half> MaskXnor([MaskExpected] Vector128<Half> left, [MaskExpected] Vector128<Half> right);
        [Mask]
        public static Vector128<ushort> MaskXnor([MaskExpected] Vector128<ushort> left, [MaskExpected] Vector128<ushort> right);
        [Mask]
        public static Vector128<short> MaskXnor([MaskExpected] Vector128<short> left, [MaskExpected] Vector128<short> right);
        [Mask]
        public static Vector256<uint> MaskXnor([MaskExpected] Vector256<uint> left, [MaskExpected] Vector256<uint> right);
        [Mask]
        public static Vector256<int> MaskXnor([MaskExpected] Vector256<int> left, [MaskExpected] Vector256<int> right);
        [Mask]
        public static Vector512<ulong> MaskXnor([MaskExpected] Vector512<ulong> left, [MaskExpected] Vector512<ulong> right);
        [Mask]
        public static Vector512<long> MaskXnor([MaskExpected] Vector512<long> left, [MaskExpected] Vector512<long> right);

    }

    public abstract class Avx512BW : Avx512F
    {
        [Mask]
        public static Vector512<Half> MaskUnpack32([MaskExpected] Vector512<float> left, [MaskExpected] Vector512<float> right);
        [Mask]
        public static Vector512<Half> MaskUnpack32([MaskExpected] Vector256<Half> left, [MaskExpected] Vector256<Half> right);
        [Mask]
        public static Vector512<ushort> MaskUnpack32([MaskExpected] Vector512<uint> left, [MaskExpected] Vector512<uint> right);
        [Mask]
        public static Vector512<short> MaskUnpack32([MaskExpected] Vector512<int> left, [MaskExpected] Vector512<int> right);
        [Mask]
        public static Vector512<ushort> MaskUnpack32([MaskExpected] Vector256<ushort> left, [MaskExpected] Vector256<ushort> right);
        [Mask]
        public static Vector512<short> MaskUnpack32([MaskExpected] Vector256<short> left, [MaskExpected] Vector256<short> right);
        [Mask]
        public static Vector256<byte> MaskUnpack32([MaskExpected] Vector256<ushort> left, [MaskExpected] Vector256<ushort> right);
        [Mask]
        public static Vector256<sbyte> MaskUnpack32([MaskExpected] Vector256<short> left, [MaskExpected] Vector256<short> right);
        [Mask]
        public static Vector256<byte> MaskUnpack32([MaskExpected] Vector128<byte> left, [MaskExpected] Vector128<byte> right);
        [Mask]
        public static Vector256<sbyte> MaskUnpack32([MaskExpected] Vector128<sbyte> left, [MaskExpected] Vector128<sbyte> right);
        [Mask]
        public static Vector512<byte> MaskUnpack64([MaskExpected] Vector512<ushort> left, [MaskExpected] Vector512<ushort> right);
        [Mask]
        public static Vector512<sbyte> MaskUnpack64([MaskExpected] Vector512<short> left, [MaskExpected] Vector512<short> right);
        [Mask]
        public static Vector512<byte> MaskUnpack64([MaskExpected] Vector256<byte> left, [MaskExpected] Vector256<byte> right);
        [Mask]
        public static Vector512<sbyte> MaskUnpack64([MaskExpected] Vector256<sbyte> left, [MaskExpected] Vector256<sbyte> right);

        [Mask]
        public static Vector512<Half> MaskAdd([MaskExpected] Vector512<Half> left, [MaskExpected] Vector512<Half> right);
        [Mask]
        public static Vector256<byte> MaskAdd([MaskExpected] Vector256<byte> left, [MaskExpected] Vector256<byte> right);
        [Mask]
        public static Vector256<sbyte> MaskAdd([MaskExpected] Vector256<sbyte> left, [MaskExpected] Vector256<sbyte> right);
        [Mask]
        public static Vector512<ushort> MaskAdd([MaskExpected] Vector512<ushort> left, [MaskExpected] Vector512<ushort> right);
        [Mask]
        public static Vector512<short> MaskAdd([MaskExpected] Vector512<short> left, [MaskExpected] Vector512<short> right);
        [Mask]
        public static Vector512<byte> MaskAdd([MaskExpected] Vector512<byte> left, [MaskExpected] Vector512<byte> right);
        [Mask]
        public static Vector512<sbyte> MaskAdd([MaskExpected] Vector512<sbyte> left, [MaskExpected] Vector512<sbyte> right);

        [Mask]
        public static Vector512<Half> MaskXnor([MaskExpected] Vector512<Half> left, [MaskExpected] Vector512<Half> right);
        [Mask]
        public static Vector256<byte> MaskXnor([MaskExpected] Vector256<byte> left, [MaskExpected] Vector256<byte> right);
        [Mask]
        public static Vector256<sbyte> MaskXnor([MaskExpected] Vector256<sbyte> left, [MaskExpected] Vector256<sbyte> right);
        [Mask]
        public static Vector512<ushort> MaskXnor([MaskExpected] Vector512<ushort> left, [MaskExpected] Vector512<ushort> right);
        [Mask]
        public static Vector512<short> MaskXnor([MaskExpected] Vector512<short> left, [MaskExpected] Vector512<short> right);
        [Mask]
        public static Vector512<byte> MaskXnor([MaskExpected] Vector512<byte> left, [MaskExpected] Vector512<byte> right);
        [Mask]
        public static Vector512<sbyte> MaskXnor([MaskExpected] Vector512<sbyte> left, [MaskExpected] Vector512<sbyte> right);
    }
}
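The k-register instructions behind these declarations (`KADDB/W/D/Q`, `KXNORB/W/D/Q`, `KUNPCKWD`, `KUNPCKDQ`) have simple integer semantics. The block below is a scalar C# model of them, not the proposed implementation. A mask is held in a `ulong` with one bit per element, and `width` is the element count of the vector type, for example 16 for `Vector512<uint>` and 32 for `Vector512<ushort>`. The `KMaskModel` name is hypothetical. The proposal does not say which parameter of `MaskUnpack32`/`MaskUnpack64` becomes the high half, so `Unpack` assumes Intel's `_mm512_kunpackw`/`_mm512_kunpackd` ordering, where the first operand supplies the high half.

```csharp
// Hypothetical scalar model of AVX-512 k-mask operations (illustration only).
// A mask is a ulong holding one bit per vector element.
public static class KMaskModel
{
    // All-ones value for a mask that is `width` bits wide (1..64).
    static ulong WidthMask(int width) =>
        width >= 64 ? ulong.MaxValue : (1UL << width) - 1;

    // KADD*: the masks are added as integers and truncated to the mask width,
    // so carries propagate from one element's bit into the next.
    public static ulong Add(ulong left, ulong right, int width) =>
        (left + right) & WidthMask(width);

    // KXNOR*: bitwise XNOR, truncated to the mask width.
    public static ulong Xnor(ulong left, ulong right, int width) =>
        ~(left ^ right) & WidthMask(width);

    // KUNPCKWD (halfWidth = 16) and KUNPCKDQ (halfWidth = 32): concatenate two masks.
    // ASSUMPTION: `high` supplies the upper half, following Intel's _mm512_kunpackw.
    public static ulong Unpack(ulong high, ulong low, int halfWidth) =>
        ((high & WidthMask(halfWidth)) << halfWidth) | (low & WidthMask(halfWidth));
}
```

Under that assumed ordering, `MaskUnpack32` over two `Vector512<uint>` masks corresponds to `Unpack(high, low, 16)`, which yields the 32-bit mask of a `Vector512<ushort>`.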

API Usage

Merge-masking and zero-masking could then be written as follows:

zmm0 = Avx512BW.Subtract(zmm1, zmm2).MergeWith(zmm0, k1);   // vpsubb zmm0 {k1}, zmm1, zmm2
zmm0 = Avx512BW.Subtract(zmm3, zmm4).ZeroIfNot(k2);         // vpsubb zmm0 {k2}{z}, zmm3, zmm4
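For comparison, the same merge- and zero-masking semantics can already be expressed with the existing `Vector512.ConditionalSelect` API, provided the mask is a vector whose lanes are either all bits set or zero. The sketch below is a hypothetical pair of extension methods (`MaskingSketch` is not part of `System.Runtime.Intrinsics`, and the proposed `[MaskExpected]` attribute does not exist, so it is omitted). Whether the JIT compiles these into a single masked `vpsubb` depends on its pattern recognition, which is exactly what this proposal questions.

```csharp
using System.Runtime.Intrinsics;

// Hypothetical helpers mirroring the proposed MergeWith/ZeroIfNot on top of the
// existing cross-platform API. ASSUMPTION: every lane of `mask` is either
// all bits set or zero, because ConditionalSelect selects bit by bit.
public static class MaskingSketch
{
    // Merge-masking ({k}): selected lanes take newValue, the rest keep destinationOperand.
    public static Vector512<T> MergeWith<T>(this Vector512<T> newValue, Vector512<T> destinationOperand, Vector512<T> mask)
        where T : struct
        => Vector512.ConditionalSelect(mask, newValue, destinationOperand);

    // Zero-masking ({k}{z}): selected lanes take newValue, the rest become zero.
    public static Vector512<T> ZeroIfNot<T>(this Vector512<T> newValue, Vector512<T> mask)
        where T : struct
        => Vector512.ConditionalSelect(mask, newValue, Vector512<T>.Zero);
}
```

With this sketch, `Avx512BW.Subtract(zmm1, zmm2).MergeWith(zmm0, mask)` reads the same as the first usage line, except that `mask` lives in a vector rather than being guaranteed to stay in a k-register.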

Alternative Designs

There may be better names for each identifier.

Risks

None I can come up with.

ghost commented 8 months ago

Tagging subscribers to this area: @dotnet/area-system-runtime-intrinsics. See info in area-owners.md if you want to be subscribed.

Author: MineCake147E
Assignees: -
Labels: `api-suggestion`, `area-System.Runtime.Intrinsics`
Milestone: -
tannergooding commented 8 months ago

Which I don't really think it to be reliable, predictable, and optimal for now.

This is a point-in-time problem and was a known limitation around the release of AVX-512 support in .NET 8, due to the massive scope and size of the feature. Several masking-related areas didn't make the cut in the first implementation and are being worked on for .NET 9. That includes implementing the necessary pattern recognition and enabling optimizations or alternative instruction emission for various cases, particularly for the xplat APIs.

> Hardware Intrinsics should be consistent forever, while it should adopt many new instructions and much new optimization techniques.

The intrinsics space needs to evolve and fit the needs of a growing number of platforms while continuing to avoid massive amounts of bloat, complexity, and other problematic considerations that come about from newer concepts and how they impact the overall JIT, whether they only add cost if used vs overhead even if not used, etc.

The intrinsics space has always used and relied on pattern recognition in a number of scenarios, especially when it comes to "embedded operations" as are common on x86/x64 (such as embedded loads/stores). Embedded masking is no different.


This proposal, particularly as it applies to the xplat APIs proposed, isn't going to help with codegen. It's only going to complicate the required recognition and make it harder to use the APIs on platforms where masking doesn't exist.

For the platform-specific APIs, there are a couple concepts that might be worth exposing such as adding or unpacking masks given that the behavior is different from doing the same with two vectors that are known to be allbitsset or zero per-element. However, there are also ones proposed like MaskXnor which do not make sense to expose because there is no difference in behavior from a regular Xnor operation.
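The difference between mask operations that reduce to ordinary bitwise operations and those that do not can be checked with a small portable sketch. It assumes an 8-lane Vector256<int> "vector mask" (each lane all-ones or zero) and models an 8-bit k register as a uint with one bit per lane; kadd-style addition is modeled as plain integer addition of those bits. The class and method names are hypothetical.

```csharp
using System;
using System.Runtime.Intrinsics;

// Hypothetical helpers comparing "vector mask" operations (one all-ones/zero
// lane per element) with "bitmask" operations (one bit per element, like k1-k7).
public static class MaskSemantics
{
    // Xnor on vector masks: every lane stays all-ones or zero.
    public static Vector256<int> VectorXnor(Vector256<int> a, Vector256<int> b) => ~(a ^ b);

    // Xnor on an 8-bit bitmask (one bit per int lane of a Vector256<int>).
    public static uint BitmaskXnor(uint a, uint b) => ~(a ^ b) & 0xFFu;

    // Lane-wise addition of vector masks: -1 + -1 = -2, which is no longer a mask.
    public static Vector256<int> VectorAdd(Vector256<int> a, Vector256<int> b) => a + b;

    // kadd-style addition: the mask bits are added as one integer, so a carry
    // moves from one element's bit into the next element's bit.
    public static uint BitmaskAdd(uint a, uint b) => (a + b) & 0xFFu;

    // Convert a vector mask to its bitmask form (most significant bit of each lane).
    public static uint ToBitmask(Vector256<int> mask) => Vector256.ExtractMostSignificantBits(mask);
}
```

With lanes 0 and 1 set in a and only lane 0 set in b, both Xnor forms give 0xFD, while the kadd-style sum is 0b100 (only lane 2 set) and the lane-wise sum produces lanes of -2 and -1, whose most significant bits read back as 0b11. That is the sense in which mask addition is a distinct operation and MaskXnor is not.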

I'd recommend opening issues to track individual cases where the pattern recognition around masking doesn't work as expected (such as no embedded masking support today or the case where ConditionalSelect could emit vpblendm instead of vpternlog).

I'd then recommend opening a standalone proposal for the mask like concepts which can't be trivially handled, such as addition and unpacking.

MineCake147E commented 8 months ago

> I'd then recommend opening a standalone proposal for the mask like concepts which can't be trivially handled, such as addition and unpacking.

> I'd recommend opening issues to track individual cases where the pattern recognition around masking doesn't work as expected (such as no embedded masking support today or the case where ConditionalSelect could emit vpblendm instead of vpternlog).

Sure. I will.

> This proposal, particularly as it applies to the xplat APIs proposed, isn't going to help with codegen. It's only going to complicate the required recognition and make it harder to use the APIs on platforms where masking doesn't exist.

I'm afraid I forgot to mention that MergeWith and ZeroIfNot were mainly motivated by readability. CreateMask* were for kmov* instructions, MaskifyMostSignificantBits was for vpmov*2m, and VectorifyMask was for vpmovm2*. MaskShift* could be necessary because the alternative approaches, such as one or more AlignRight calls on different element types, or Permute*, can't be recognized very easily. The same applies to CreateMask*.

I personally think that they don't hurt anything on platforms where masking doesn't exist, as most of their fallback code could easily be implemented using non-mask instructions.
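As a rough illustration of that point, the two conversions can be written with only the cross-platform Vector256 APIs. VectorifyMask and MaskifyMostSignificantBits are the hypothetical names from the proposal; the 32-lane byte layout and a little-endian target are assumptions. On AVX-512 hardware one would ideally want these to become vpmovm2b and vpmovb2m, which this sketch does not guarantee.

```csharp
using System;
using System.Runtime.Intrinsics;

public static class MaskFallbacks
{
    // Hypothetical VectorifyMask (vpmovm2b-like): bit i of `mask` becomes byte
    // lane i set to 0xFF; clear bits become 0x00.
    // Assumes a little-endian target (true for x86/x64 and typical Arm64).
    public static Vector256<byte> VectorifyMask(uint mask)
    {
        // Copy the 32-bit mask into every 4-byte group of the vector.
        Vector256<byte> broadcast = Vector256.Create(mask).AsByte();

        // Lane i needs mask byte i / 8, which sits at byte index i / 8 of the vector.
        Vector256<byte> byteIndex = Vector256.Create(
            (byte)0, 0, 0, 0, 0, 0, 0, 0,
            1, 1, 1, 1, 1, 1, 1, 1,
            2, 2, 2, 2, 2, 2, 2, 2,
            3, 3, 3, 3, 3, 3, 3, 3);
        Vector256<byte> spread = Vector256.Shuffle(broadcast, byteIndex);

        // Lane i then tests bit (i % 8) of that byte.
        Vector256<byte> bit = Vector256.Create(
            (byte)1, 2, 4, 8, 16, 32, 64, 128,
            1, 2, 4, 8, 16, 32, 64, 128,
            1, 2, 4, 8, 16, 32, 64, 128,
            1, 2, 4, 8, 16, 32, 64, 128);
        return Vector256.Equals(spread & bit, bit);
    }

    // Hypothetical MaskifyMostSignificantBits (vpmovb2m-like): one bit per lane.
    public static uint MaskifyMostSignificantBits(Vector256<byte> vector)
        => Vector256.ExtractMostSignificantBits(vector);
}
```

The round trip MaskifyMostSignificantBits(VectorifyMask(m)) returns m, so the pair behaves like moving a value between a k register and a vector register, just with more instructions.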

MineCake147E commented 8 months ago

Also, covering all known/potential workarounds with pattern matching doesn't seem to be a good idea anyway. It bloats the RyuJIT up pretty quickly, and could even end up missing optimization opportunities in large methods because of the extra time RyuJIT spends optimizing each method. CreateMask*, MaskifyMostSignificantBits, VectorifyMask, and MaskShift* are the ones that can't be implemented trivially without the corresponding instructions.

MineCake147E commented 8 months ago

I forgot to mention, though, that the optimal code for CPUs without masking support could be suboptimal on CPUs with masking support.

Consider subtracting the short values in ymm1 from those in ymm0, only where the corresponding mask element is true.

Code optimized for AVX2 looks like this, if ymm2 holds the mask:

```asm
vpand ymm1, ymm1, ymm2
vpsubw ymm0, ymm0, ymm1
```

Code optimized for AVX-512 looks like this, if k1 holds the mask:

```asm
vpsubw ymm0 {k1}, ymm0, ymm1
```

Will a future RyuJIT be able to recognize ymm0 -= k1 & ymm1 and emit the code above? Or, if ymm2 were given as the mask instead, here's what optimized code looks like:

```asm
vpand ymm1, ymm1, ymm2
vpsubw ymm0, ymm0, ymm1
```

Will RyuJIT be able to recognize ymm0 = Avx512BW.BlendVariable(ymm0, ymm0 - ymm1, ymm2) and emit the code above as well?
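The two C# formulations in question can be written with the cross-platform Vector256 APIs as follows. The class and method names are hypothetical; both give the same result for a mask whose short lanes are all-ones or zero, and which instructions the JIT actually emits for each is exactly the open question, so no particular codegen is implied.

```csharp
using System;
using System.Runtime.Intrinsics;

public static class MaskedSubtract
{
    // AVX2-style: clear the lanes of `right` that are not selected, then subtract.
    public static Vector256<short> SubtractAnd(Vector256<short> left, Vector256<short> right, Vector256<short> mask)
        => left - (right & mask);

    // Blend-style: compute left - right everywhere, keep `left` where the mask is clear.
    public static Vector256<short> SubtractSelect(Vector256<short> left, Vector256<short> right, Vector256<short> mask)
        => Vector256.ConditionalSelect(mask, left - right, left);
}
```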

Consider multiplying the short values in ymm0 by 16, only where the corresponding mask element is true.

Code optimized for AVX2 looks like this, if ymm1 holds the mask:

```asm
vpsllw ymm2, ymm0, 4
vpblendvb ymm0, ymm0, ymm2, ymm1
```

Code optimized for AVX-512 looks like this, if k1 holds the mask:

```asm
vpsllw ymm0 {k1}, ymm0, 4
```

Can RyuJIT recognize ymm0 <<= k1 & Vector256.Create((short)4) and emit the code above in the future? Or, if ymm1 were given as the mask instead, here's what optimized code looks like:

```asm
vpandd ymm1, ymm1, dword ptr [rip + .DISPLACEMENT]{1to8} ; The memory address stores 0x0004_0004
vpsllvw ymm0, ymm0, ymm1
```

Can RyuJIT recognize ymm0 = Avx512BW.BlendVariable(ymm0, Avx512BW.ShiftLeftLogical(ymm0, 4), ymm1) and emit the code above in the future?
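The same comparison for the conditional multiply-by-16 (a left shift by 4) can be sketched as below, assuming a vector mask with all-ones or zero short lanes. The names are hypothetical. The second form swaps the variable shift for a per-lane factor of 16 or 1, which keeps the sketch within the cross-platform operators; for 16-bit lanes the wrapped product x * 16 equals x << 4, so both forms agree.

```csharp
using System;
using System.Runtime.Intrinsics;

public static class MaskedShift
{
    // Blend-style: shift every lane, keep the original value where the mask is clear.
    public static Vector256<short> ShiftSelect(Vector256<short> value, Vector256<short> mask)
        => Vector256.ConditionalSelect(mask, Vector256.ShiftLeft(value, 4), value);

    // Per-lane factor: 16 where the mask is set, 1 elsewhere. The 16-bit product
    // wraps exactly like a 16-bit shift, so the result matches ShiftSelect.
    public static Vector256<short> MultiplySelect(Vector256<short> value, Vector256<short> mask)
        => value * Vector256.ConditionalSelect(mask, Vector256.Create((short)16), Vector256.Create((short)1));
}
```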

I think that the user should be aware of masking availability anyway. No matter how far RyuJIT evolves, the optimal C# code should vary from platform to platform in some cases, as I showed above. Even LLVM sometimes fails to generate optimal code today. It is nearly impossible to create a compiler that optimizes the code perfectly at all times.

tannergooding commented 8 months ago

> I'm afraid I forgot to mention that MergeWith and ZeroIfNot were ideas mainly from readability issues.

I don't think they help that much with readability and there are definitely some ambiguities in how the names can be interpreted.

On the other hand, ConditionalSelect is very clear on what it does and fits the pattern that people have already been following for years. Users that "really" want to have some kind of helper like MergeWith can trivially define it over ConditionalSelect as an extension method.
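For instance, the proposal's MergeWith and ZeroIfNot could be defined on the user side along these lines. This is a sketch, not an existing or proposed BCL API; it assumes the mask is a vector whose elements are all-ones or zero, and whether the JIT turns it into embedded {k} or {k}{z} masking depends on the pattern recognition discussed in this thread.

```csharp
using System;
using System.Runtime.Intrinsics;

public static class MaskingExtensions
{
    // Hypothetical helper: keep `value` where `mask` is set, otherwise take `fallback`.
    public static Vector512<T> MergeWith<T>(this Vector512<T> value, Vector512<T> fallback, Vector512<T> mask)
        where T : struct
        => Vector512.ConditionalSelect(mask, value, fallback);

    // Hypothetical helper: keep `value` where `mask` is set, otherwise zero.
    public static Vector512<T> ZeroIfNot<T>(this Vector512<T> value, Vector512<T> mask)
        where T : struct
        => Vector512.ConditionalSelect(mask, value, Vector512<T>.Zero);
}
```

With these in scope, the proposal's example zmm0 = Avx512BW.Subtract(zmm1, zmm2).MergeWith(zmm0, k1) compiles as written, assuming k1 is a Vector512<byte> holding the mask.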

> CreateMask* were for kmov* instructions. ...

As indicated, any APIs which can't be trivially handled via pattern recognition, such as because the naive operation isn't clear or 1-to-1, can and likely should have platform specific APIs exposed. So exposing an API like Avx512F.AddMask is fine and an appropriate proposal should be opened. Having an xplat API like Vector128.AddMask is likely not a good idea on the other hand and will lead to pessimizations in xplat code.

> It bloats the RyuJIT up pretty quickly

It does not and is significantly less expensive than the alternative, which is one of the reasons why we went with it.

> Also, covering all known/potential workarounds with pattern matching doesn't seem to be a good idea anyway.

It is not a goal to cover every potential pattern, no compiler does this. Even LLVM doesn't cover "everything".

It is instead a goal of the compiler to cover common/typical patterns. Users can request that a new pattern be recognized as well, but it all comes down to cost, complexity, and benefit.

There is some responsibility on the developer, especially in perf critical code, to write their code in a way that fits the well-established and recognized patterns to ensure the underlying compiler can do its job.

> could be suboptimal on CPUs with masking support.

It is not the goal of the compiler to generate "ideal" codegen in every possible scenario. This is effectively impossible and even LLVM when you're targeting a specific micro-architecture gets it wrong in many cases.

At the end of the day, losing 1-3 cycles isn't going to matter for many code patterns and the smallest code isn't necessarily the fastest code. Many of the examples you've provided aren't necessarily standard/common patterns that would be encountered for typical SIMD code, and where they are encountered, the difference between the given and suggested codegen isn't significant. It's within the realm of noise, especially when compared to the general performance deltas caused by inherent latency (cache, ram, disk, etc), variable processor speeds, resource contention between cores/hyper-threads, etc.

You might be able to measure the difference in a micro-benchmark, but the difference is unlikely to surface in most real world applications, unless they are highly specialized and those lines happen to be the bottleneck over gigabytes of data.

> I think that the user should be aware of masking availability anyway. No matter how far RyuJIT evolves, the optimal C# code should vary from platform to platform in some cases, as I showed above.

Most code doesn't need to be "optimal" and as the number of platforms and scenarios needing to be supported expands, so does the need to write portable and reusable code.

The BCL explicitly uses the xplat intrinsics, with selective usage of the platform-specific intrinsics, in most code paths, because losing a nanosecond here or there is acceptable in exchange for the massively reduced complexity, the increased confidence that the code works as expected, and the ability to rapidly bring online 2-16x perf gains for new platforms. Even if a 20x perf gain is possible with hand-tuned code, it's a diminishing return, especially when viewed in the context of typical applications.
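A minimal sketch of that style, reusing the conditional multiply-by-16 from earlier in the thread: the loop is written once against Vector256, with a scalar path for the tail and for hardware where Vector256 is not accelerated. The method name and the "shift where the flag is non-zero" task are illustrative assumptions, not BCL code.

```csharp
using System;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;

public static class PortableMasking
{
    // Hypothetical example: multiply values[i] by 16 wherever flags[i] != 0.
    public static void ShiftLeftWhereFlagged(Span<short> values, ReadOnlySpan<short> flags)
    {
        if (flags.Length < values.Length)
            throw new ArgumentException("flags must be at least as long as values", nameof(flags));

        int i = 0;
        if (Vector256.IsHardwareAccelerated)
        {
            ref short v = ref MemoryMarshal.GetReference(values);
            ref short f = ref MemoryMarshal.GetReference(flags);
            int lastBlock = values.Length - Vector256<short>.Count;
            for (; i <= lastBlock; i += Vector256<short>.Count)
            {
                Vector256<short> x = Vector256.LoadUnsafe(ref v, (nuint)i);
                // All-ones where the flag is non-zero, zero elsewhere.
                Vector256<short> mask = ~Vector256.Equals(Vector256.LoadUnsafe(ref f, (nuint)i), Vector256<short>.Zero);
                Vector256.ConditionalSelect(mask, Vector256.ShiftLeft(x, 4), x).StoreUnsafe(ref v, (nuint)i);
            }
        }

        // Scalar tail, and the whole span when Vector256 is not accelerated.
        for (; i < values.Length; i++)
        {
            if (flags[i] != 0)
                values[i] = (short)(values[i] << 4);
        }
    }
}
```

The same source runs on AVX2, AVX-512, and Arm64 without per-platform branches; whether the select becomes a blend, a vpternlog, or masked code is left to the JIT.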

> Even LLVM sometimes fails to generate optimal code today.

Yes, and there are places where RyuJIT provides better controls/guarantees than LLVM provides. There is also the inverse case where LLVM does it better.

At the end of the day, we are confident we will be able to get the most common patterns supported, such that you will be able to get nearly ideal code for the vast majority of scenarios. There will be some places over time that users will surface as needing improvements, and we will look at those as they come in, determining whether the pattern can/should be supported or whether the user should modify their code slightly to fit the already recognized patterns.

MineCake147E commented 8 months ago

Oh, I get it now.

> As indicated, any APIs which can't be trivially handled via pattern recognition, such as because the naive operation isn't clear or 1-to-1, can and likely should have platform specific APIs exposed. So exposing an API like Avx512F.AddMask is fine and an appropriate proposal should be opened.

So am I allowed to open one for MaskShift*, VectorifyMask, MaskifyMostSignificantBits, and CreateMask* to be included in Avx512* instead then?

I opened one for MaskAdd and MaskUnpack* yesterday by the way.

tannergooding commented 8 months ago

> So am I allowed to open one for MaskShift*, VectorifyMask, MaskifyMostSignificantBits, and CreateMask* to be included in Avx512* instead then?

Most of these APIs should be *Mask rather than Mask*. We already have Mask* APIs (such as Sse.MaskStore) and using it as a postfix also better matches other APIs where the type appears at the end (LoadVector128, ConvertToDouble, AsInt64, etc).

As for the ones listed: