dotnet / runtime

.NET is a cross-platform runtime for cloud, mobile, desktop, and IoT apps.
https://docs.microsoft.com/dotnet/core/
MIT License
15.41k stars 4.76k forks source link

[API Proposal]: Add dot product intrinsics to AVX10v2 API #110032

Open DeepakRajendrakumaran opened 1 day ago

DeepakRajendrakumaran commented 1 day ago

Background and motivation

This is a follow-up to https://github.com/dotnet/runtime/issues/109083.

This proposal adds the VPDPB[SS,SU,UU]D[S] and VPDPW[SU,US,UU]D[S] dot-product instructions to the already-approved AVX10v2 API.

See the AVX10.2 spec; Section 10 of the spec covers these instructions.

API Proposal

namespace System.Runtime.Intrinsics.X86
{
    /// <summary>Provides access to X86 AVX10.2 hardware instructions via intrinsics</summary>
    /// <remarks>
    /// Every method body below simply calls itself with the same arguments. Presumably the JIT
    /// substitutes the instruction named in the comment above each method, as suggested by the
    /// <c>[Intrinsic]</c> attribute -- confirm against the runtime's intrinsic conventions.
    /// NOTE(review): the VPDP* instructions accumulate into their destination operand, and the
    /// existing AvxVnni API exposes that as a leading <c>addend</c> parameter. Confirm whether
    /// these overloads should take one as well.
    /// </remarks>
    [Intrinsic]
    [CLSCompliant(false)]
    public abstract class Avx10v2 : Avx10v1
    {
        // ---- Byte dot product, 128-bit and 256-bit ----

        // VPDPBSSD xmm1{k1}{z}, xmm2, xmm3/m128/m32bcst
        public static Vector128<int> MultiplyWideningAndAdd(Vector128<sbyte> left, Vector128<sbyte> right) => MultiplyWideningAndAdd(left, right);

        // VPDPBSUD xmm1{k1}{z}, xmm2, xmm3/m128/m32bcst
        public static Vector128<int> MultiplyWideningAndAdd(Vector128<sbyte> left, Vector128<byte> right) => MultiplyWideningAndAdd(left, right);

        // VPDPBUUD xmm1{k1}{z}, xmm2, xmm3/m128/m32bcst
        public static Vector128<int> MultiplyWideningAndAdd(Vector128<byte> left, Vector128<byte> right) => MultiplyWideningAndAdd(left, right);

        // VPDPBSSD ymm1{k1}{z}, ymm2, ymm3/m256/m32bcst
        public static Vector256<int> MultiplyWideningAndAdd(Vector256<sbyte> left, Vector256<sbyte> right) => MultiplyWideningAndAdd(left, right);

        // VPDPBSUD ymm1{k1}{z}, ymm2, ymm3/m256/m32bcst
        public static Vector256<int> MultiplyWideningAndAdd(Vector256<sbyte> left, Vector256<byte> right) => MultiplyWideningAndAdd(left, right);

        // VPDPBUUD ymm1{k1}{z}, ymm2, ymm3/m256/m32bcst
        public static Vector256<int> MultiplyWideningAndAdd(Vector256<byte> left, Vector256<byte> right) => MultiplyWideningAndAdd(left, right);

        // ---- Byte dot product with saturation, 128-bit and 256-bit ----

        // VPDPBSSDS xmm1{k1}{z}, xmm2, xmm3/m128/m32bcst
        public static Vector128<int> MultiplyWideningAndAddSaturate(Vector128<sbyte> left, Vector128<sbyte> right) => MultiplyWideningAndAddSaturate(left, right);

        // VPDPBSUDS xmm1{k1}{z}, xmm2, xmm3/m128/m32bcst
        public static Vector128<int> MultiplyWideningAndAddSaturate(Vector128<sbyte> left, Vector128<byte> right) => MultiplyWideningAndAddSaturate(left, right);

        // VPDPBUUDS xmm1{k1}{z}, xmm2, xmm3/m128/m32bcst
        public static Vector128<int> MultiplyWideningAndAddSaturate(Vector128<byte> left, Vector128<byte> right) => MultiplyWideningAndAddSaturate(left, right);

        // VPDPBSSDS ymm1{k1}{z}, ymm2, ymm3/m256/m32bcst
        public static Vector256<int> MultiplyWideningAndAddSaturate(Vector256<sbyte> left, Vector256<sbyte> right) => MultiplyWideningAndAddSaturate(left, right);

        // VPDPBSUDS ymm1{k1}{z}, ymm2, ymm3/m256/m32bcst
        public static Vector256<int> MultiplyWideningAndAddSaturate(Vector256<sbyte> left, Vector256<byte> right) => MultiplyWideningAndAddSaturate(left, right);

        // VPDPBUUDS ymm1{k1}{z}, ymm2, ymm3/m256/m32bcst
        public static Vector256<int> MultiplyWideningAndAddSaturate(Vector256<byte> left, Vector256<byte> right) => MultiplyWideningAndAddSaturate(left, right);

        // ---- Word dot product, 128-bit and 256-bit ----

        // VPDPWSUD xmm1{k1}{z}, xmm2, xmm3/m128/m32bcst
        public static Vector128<int> MultiplyWideningAndAdd(Vector128<short> left, Vector128<ushort> right) => MultiplyWideningAndAdd(left, right);

        // VPDPWUSD xmm1{k1}{z}, xmm2, xmm3/m128/m32bcst
        public static Vector128<int> MultiplyWideningAndAdd(Vector128<ushort> left, Vector128<short> right) => MultiplyWideningAndAdd(left, right);

        // VPDPWUUD xmm1{k1}{z}, xmm2, xmm3/m128/m32bcst
        public static Vector128<int> MultiplyWideningAndAdd(Vector128<ushort> left, Vector128<ushort> right) => MultiplyWideningAndAdd(left, right);

        // VPDPWSUD ymm1{k1}{z}, ymm2, ymm3/m256/m32bcst
        public static Vector256<int> MultiplyWideningAndAdd(Vector256<short> left, Vector256<ushort> right) => MultiplyWideningAndAdd(left, right);

        // VPDPWUSD ymm1{k1}{z}, ymm2, ymm3/m256/m32bcst
        public static Vector256<int> MultiplyWideningAndAdd(Vector256<ushort> left, Vector256<short> right) => MultiplyWideningAndAdd(left, right);

        // VPDPWUUD ymm1{k1}{z}, ymm2, ymm3/m256/m32bcst
        public static Vector256<int> MultiplyWideningAndAdd(Vector256<ushort> left, Vector256<ushort> right) => MultiplyWideningAndAdd(left, right);

        // ---- Word dot product with saturation, 128-bit and 256-bit ----

        // VPDPWSUDS xmm1{k1}{z}, xmm2, xmm3/m128/m32bcst
        public static Vector128<int> MultiplyWideningAndAddSaturate(Vector128<short> left, Vector128<ushort> right) => MultiplyWideningAndAddSaturate(left, right);

        // VPDPWUSDS xmm1{k1}{z}, xmm2, xmm3/m128/m32bcst
        public static Vector128<int> MultiplyWideningAndAddSaturate(Vector128<ushort> left, Vector128<short> right) => MultiplyWideningAndAddSaturate(left, right);

        // VPDPWUUDS xmm1{k1}{z}, xmm2, xmm3/m128/m32bcst
        public static Vector128<int> MultiplyWideningAndAddSaturate(Vector128<ushort> left, Vector128<ushort> right) => MultiplyWideningAndAddSaturate(left, right);

        // VPDPWSUDS ymm1{k1}{z}, ymm2, ymm3/m256/m32bcst
        public static Vector256<int> MultiplyWideningAndAddSaturate(Vector256<short> left, Vector256<ushort> right) => MultiplyWideningAndAddSaturate(left, right);

        // VPDPWUSDS ymm1{k1}{z}, ymm2, ymm3/m256/m32bcst
        public static Vector256<int> MultiplyWideningAndAddSaturate(Vector256<ushort> left, Vector256<short> right) => MultiplyWideningAndAddSaturate(left, right);

        // VPDPWUUDS ymm1{k1}{z}, ymm2, ymm3/m256/m32bcst
        public static Vector256<int> MultiplyWideningAndAddSaturate(Vector256<ushort> left, Vector256<ushort> right) => MultiplyWideningAndAddSaturate(left, right);

        /// <summary>512-bit overloads of the dot-product intrinsics declared on the enclosing class.</summary>
        // `new` makes the hiding of the inherited Avx10v1.V512 nested class explicit.
        [Intrinsic]
        public new abstract class V512 : Avx10v1.V512
        {
            // ---- Word dot product, 512-bit ----

            // VPDPWSUD zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<int> MultiplyWideningAndAdd(Vector512<short> left, Vector512<ushort> right) => MultiplyWideningAndAdd(left, right);

            // VPDPWUSD zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<int> MultiplyWideningAndAdd(Vector512<ushort> left, Vector512<short> right) => MultiplyWideningAndAdd(left, right);

            // VPDPWUUD zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<int> MultiplyWideningAndAdd(Vector512<ushort> left, Vector512<ushort> right) => MultiplyWideningAndAdd(left, right);

            // ---- Word dot product with saturation, 512-bit ----
            // Operand signedness follows the mnemonic (S = signed, U = unsigned), matching the
            // 128-bit and 256-bit overloads on the enclosing class.

            // VPDPWSUDS zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<int> MultiplyWideningAndAddSaturate(Vector512<short> left, Vector512<ushort> right) => MultiplyWideningAndAddSaturate(left, right);

            // VPDPWUSDS zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<int> MultiplyWideningAndAddSaturate(Vector512<ushort> left, Vector512<short> right) => MultiplyWideningAndAddSaturate(left, right);

            // VPDPWUUDS zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<int> MultiplyWideningAndAddSaturate(Vector512<ushort> left, Vector512<ushort> right) => MultiplyWideningAndAddSaturate(left, right);

            // ---- Byte dot product, 512-bit ----

            // VPDPBSSD zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<int> MultiplyWideningAndAdd(Vector512<sbyte> left, Vector512<sbyte> right) => MultiplyWideningAndAdd(left, right);

            // VPDPBSUD zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<int> MultiplyWideningAndAdd(Vector512<sbyte> left, Vector512<byte> right) => MultiplyWideningAndAdd(left, right);

            // VPDPBUUD zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<int> MultiplyWideningAndAdd(Vector512<byte> left, Vector512<byte> right) => MultiplyWideningAndAdd(left, right);

            // ---- Byte dot product with saturation, 512-bit ----

            // VPDPBSSDS zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<int> MultiplyWideningAndAddSaturate(Vector512<sbyte> left, Vector512<sbyte> right) => MultiplyWideningAndAddSaturate(left, right);

            // VPDPBSUDS zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<int> MultiplyWideningAndAddSaturate(Vector512<sbyte> left, Vector512<byte> right) => MultiplyWideningAndAddSaturate(left, right);

            // VPDPBUUDS zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<int> MultiplyWideningAndAddSaturate(Vector512<byte> left, Vector512<byte> right) => MultiplyWideningAndAddSaturate(left, right);
        }
    }
}

API Usage

// Broadcast each parameter across all lanes of a 128-bit vector.
Vector128<byte> v1 = Vector128.Create((byte)someParam1);
Vector128<byte> v2 = Vector128.Create((byte)someParam2);
// IsSupported is a property; guard the call so unsupported hardware takes another path.
if (Avx10v2.IsSupported) {
  // Uses the two-operand (byte, byte) overload from the proposal above.
  Vector128<int> v3 = Avx10v2.MultiplyWideningAndAdd(v1, v2);
  // etc
}

Alternative Designs

No response

Risks

No response

dotnet-policy-service[bot] commented 1 day ago

Tagging subscribers to this area: @dotnet/area-system-runtime-intrinsics See info in area-owners.md if you want to be subscribed.