KhronosGroup / SPIRV-Cross

SPIRV-Cross is a practical tool and library for performing reflection on SPIR-V and disassembling SPIR-V back to high level languages.
Apache License 2.0
1.96k stars 549 forks source link

SPIR-V to MSL: OpSMulExtended translation incorrectly casts int64 to int32 before mul and mulhi #2300

Closed aitor-lunarg closed 2 months ago

aitor-lunarg commented 3 months ago

Given the following SPIR-V code:

; SPIR-V
; Version: 1.0
; Generator: Khronos Glslang Reference Front End; 10
; Bound: 50
; Schema: 0
               OpCapability Shader
               OpCapability Int64
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint GLCompute %main "main"
               OpExecutionMode %main LocalSize 1 1 1
               OpSource GLSL 430
               OpName %main "main"
               OpName %i "i"
               OpName %block0 "block0"
               OpMemberName %block0 0 "inputs"
               OpName %_ ""
               OpName %block1 "block1"
               OpMemberName %block1 0 "outputs"
               OpName %__0 ""
               OpName %ResType "ResType"
               OpDecorate %_runtimearr_v2int64 ArrayStride 16
               OpMemberDecorate %block0 0 Offset 0
               OpDecorate %block0 BufferBlock
               OpDecorate %_ DescriptorSet 0
               OpDecorate %_ Binding 0
               OpDecorate %_runtimearr_v2int64_0 ArrayStride 16
               OpMemberDecorate %block1 0 Offset 0
               OpDecorate %block1 BufferBlock
               OpDecorate %__0 DescriptorSet 0
               OpDecorate %__0 Binding 1
       %void = OpTypeVoid
          %3 = OpTypeFunction %void
        %int = OpTypeInt 32 1
%_ptr_Function_int = OpTypePointer Function %int
      %int_0 = OpConstant %int 0
       %uint = OpTypeInt 32 0
       %int64 = OpTypeInt 64 1
 %v2int64 = OpTypeVector %int64 2
%_runtimearr_v2int64 = OpTypeRuntimeArray %v2int64
     %block0 = OpTypeStruct %_runtimearr_v2int64
%_ptr_Uniform_block0 = OpTypePointer Uniform %block0
          %_ = OpVariable %_ptr_Uniform_block0 Uniform
       %bool = OpTypeBool
     %uint_0 = OpConstant %uint 0
%_ptr_Uniform_int64 = OpTypePointer Uniform %int64
     %uint_1 = OpConstant %uint 1
%_runtimearr_v2int64_0 = OpTypeRuntimeArray %v2int64
     %block1 = OpTypeStruct %_runtimearr_v2int64_0
%_ptr_Uniform_block1 = OpTypePointer Uniform %block1
        %__0 = OpVariable %_ptr_Uniform_block1 Uniform
    %ResType = OpTypeStruct %int64 %int64
      %int_1 = OpConstant %int 1
       %main = OpFunction %void None %3
          %5 = OpLabel
          %i = OpVariable %_ptr_Function_int Function
               OpStore %i %int_0
               OpBranch %10
         %10 = OpLabel
               OpLoopMerge %12 %13 None
               OpBranch %14
         %14 = OpLabel
         %15 = OpLoad %int %i
         %22 = OpArrayLength %uint %_ 0
         %23 = OpBitcast %int %22
         %25 = OpSLessThan %bool %15 %23
               OpBranchConditional %25 %11 %12
         %11 = OpLabel
         %26 = OpLoad %int %i
         %29 = OpAccessChain %_ptr_Uniform_int64 %_ %int_0 %26 %uint_0
         %30 = OpLoad %int64 %29
         %31 = OpLoad %int %i
         %33 = OpAccessChain %_ptr_Uniform_int64 %_ %int_0 %31 %uint_1
         %34 = OpLoad %int64 %33
         %39 = OpLoad %int %i
         %40 = OpAccessChain %_ptr_Uniform_int64 %__0 %int_0 %39 %uint_0
         %41 = OpLoad %int %i
         %42 = OpAccessChain %_ptr_Uniform_int64 %__0 %int_0 %41 %uint_1
         %44 = OpSMulExtended %ResType %30 %34
         %45 = OpCompositeExtract %int64 %44 0
               OpStore %42 %45
         %46 = OpCompositeExtract %int64 %44 1
               OpStore %40 %46
               OpBranch %13
         %13 = OpLabel
         %47 = OpLoad %int %i
         %49 = OpIAdd %int %47 %int_1
               OpStore %i %49
               OpBranch %10
         %12 = OpLabel
               OpReturn
               OpFunctionEnd

The following MSL is generated:

#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

struct block0
{
    long2 inputs[1];
};

struct block1
{
    long2 outputs[1];
};

struct ResType
{
    long _m0;
    long _m1;
};

kernel void main0(constant uint* spvBufferSizeConstants [[buffer(29)]], device block0& _7 [[buffer(0)]], device block1& _9 [[buffer(1)]])
{
    constant uint& _7BufferSize = spvBufferSizeConstants[0];
    for (int i = 0; i < int((_7BufferSize - 0) / 16); i++)
    {
        ResType _44;
        _44._m0 = long(int(((device long*)&_7.inputs[i])[0u]) * int(((device long*)&_7.inputs[i])[1u]));
        _44._m1 = long(mulhi(int(((device long*)&_7.inputs[i])[0u]), int(((device long*)&_7.inputs[i])[1u])));
        ((device long*)&_9.outputs[i])[1u] = _44._m0;
        ((device long*)&_9.outputs[i])[0u] = _44._m1;
    }
}

However, the generated MSL incorrectly casts both operands to 32-bit integers before the multiply and the mulhi call, even though this OpSMulExtended operates on 64-bit integers.

Found while running CTS over MoltenVK. Failing tests:

dEQP-VK.spirv_assembly.instruction.compute.mul_extended.signed_64bit
dEQP-VK.spirv_assembly.instruction.compute.mul_extended.unsigned_64bit
HansKristian-Work commented 3 months ago

So SPIR-V is actually supposed to support 64x64 -> 128-bit multiplication here?

aitor-lunarg commented 3 months ago

So SPIR-V is actually supposed to support 64x64 -> 128-bit multiplication here?

That is my understanding. According to the spec (https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSMulExtended), the type needs to be an integer type, but there is no mention of a limit on its width, so I understand that 64x64 -> 128-bit multiplication is allowed.