KhronosGroup / SPIRV-Cross

SPIRV-Cross is a practical tool and library for performing reflection on SPIR-V and disassembling SPIR-V back to high level languages.
Apache License 2.0
2.02k stars, 557 forks

[MSL] Translate `SPV_NV_mesh_shader` Mesh shaders to Metal 3 mesh shaders #1962

Open expenses opened 2 years ago

expenses commented 2 years ago

SPV_NV_mesh_shader seems to map quite well onto Metal 3 mesh shaders, and we need to be able to translate them for https://github.com/KhronosGroup/MoltenVK/issues/1618.

I looked into implementing this for naga and got fairly far, but ran into what may be a driver issue that I had little ability to debug. Here are some thoughts on how SPIRV-Cross could translate mesh shaders.

Here's a sample mesh shader, adapted from http://zone.dog/braindump/mesh_shaders/:

#version 450
#extension GL_NV_mesh_shader : require
layout(local_size_x=1) in;
layout(max_vertices=4, max_primitives=2) out;
layout(triangles) out;

out gl_MeshPerVertexNV {
    vec4 gl_Position;
} gl_MeshVerticesNV[];

out uint gl_PrimitiveIndicesNV[];
layout(location = 0) out float vertexColor[];

void main()
{
    gl_MeshVerticesNV[0].gl_Position = vec4(-1.0, -1.0, 0.0, 1.0); // Upper Left
    gl_MeshVerticesNV[1].gl_Position = vec4( 1.0, -1.0, 0.0, 1.0); // Upper Right
    gl_MeshVerticesNV[2].gl_Position = vec4(-1.0,  1.0, 0.0, 1.0); // Bottom Left
    gl_MeshVerticesNV[3].gl_Position = vec4( 1.0,  1.0, 0.0, 1.0); // Bottom Right

    vertexColor[0] = 0.0;
    vertexColor[1] = 0.25;
    vertexColor[2] = 0.5;
    vertexColor[3] = 0.75;

    gl_PrimitiveIndicesNV[0] = 0;
    gl_PrimitiveIndicesNV[1] = 1;
    gl_PrimitiveIndicesNV[2] = 2;
    gl_PrimitiveIndicesNV[3] = 2;
    gl_PrimitiveIndicesNV[4] = 1;
    gl_PrimitiveIndicesNV[5] = 3;
    gl_PrimitiveCountNV = 2;
}

Compiled with glslc, the SPIR-V disassembly for this shader looks like this (abridged):

; SPIR-V
; Version: 1.0
; Generator: Google Shaderc over Glslang; 10
; Bound: 62
; Schema: 0
               OpCapability MeshShadingNV
               OpExtension "SPV_NV_mesh_shader"
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint MeshNV %main "main" %gl_MeshVerticesNV %vertexColor %gl_PrimitiveIndicesNV %gl_PrimitiveCountNV
               OpExecutionMode %main LocalSize 1 1 1
               OpExecutionMode %main OutputVertices 4
               OpExecutionMode %main OutputPrimitivesNV 2
               OpExecutionMode %main OutputTrianglesNV
               OpSource GLSL 450
               OpSourceExtension "GL_GOOGLE_cpp_style_line_directive"
               OpSourceExtension "GL_GOOGLE_include_directive"
               OpSourceExtension "GL_NV_mesh_shader"
               OpName %main "main"
               OpName %gl_MeshPerVertexNV "gl_MeshPerVertexNV"
               OpMemberName %gl_MeshPerVertexNV 0 "gl_Position"
               OpName %gl_MeshVerticesNV "gl_MeshVerticesNV"
               OpName %vertexColor "vertexColor"
               OpName %gl_PrimitiveIndicesNV "gl_PrimitiveIndicesNV"
               OpName %gl_PrimitiveCountNV "gl_PrimitiveCountNV"
               OpMemberDecorate %gl_MeshPerVertexNV 0 BuiltIn Position
               OpDecorate %gl_MeshPerVertexNV Block
               OpDecorate %vertexColor Location 0
               OpDecorate %gl_PrimitiveIndicesNV BuiltIn PrimitiveIndicesNV
               OpDecorate %gl_PrimitiveCountNV BuiltIn PrimitiveCountNV
               OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize
...
%gl_PrimitiveCountNV = OpVariable %_ptr_Output_uint Output
     %v3uint = OpTypeVector %uint 3
%gl_WorkGroupSize = OpConstantComposite %v3uint %uint_1 %uint_1 %uint_1
       %main = OpFunction %void None %3
          %5 = OpLabel
...
               OpStore %gl_PrimitiveCountNV %uint_2
               OpReturn
               OpFunctionEnd

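Metal mesh shaders are marked with [[mesh]] and take a metal::mesh<V, P, NV, NP, T> parameter, where V is the per-vertex output struct (position plus varyings), P is the per-primitive output struct (void if there are no per-primitive outputs), NV and NP are the maximum vertex and primitive counts, and T is the primitive topology.

For this shader (OutputVertices 4, OutputPrimitivesNV 2, OutputTrianglesNV), that gives a parameter of metal::mesh<Vertex, void, 4, 2, metal::topology::triangle> mesh.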
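Deriving those template arguments from the SPIR-V is mostly mechanical: OutputVertices gives NV, OutputPrimitivesNV gives NP, and OutputTrianglesNV selects the topology. Below is a minimal C++ sketch of that mapping (MSL is itself C++-based). MeshTopology, MeshModes, msl_topology and msl_mesh_type are hypothetical names, not SPIRV-Cross API, and the sketch assumes the shader has no per-primitive outputs, so P is always void.

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>

// Topology selected by OutputPointsNV / OutputLinesNV / OutputTrianglesNV.
enum class MeshTopology { Points, Lines, Triangles };

// Hypothetical summary of the mesh-related OpExecutionMode values.
struct MeshModes {
    uint32_t output_vertices;    // OpExecutionMode OutputVertices
    uint32_t output_primitives;  // OpExecutionMode OutputPrimitivesNV
    MeshTopology topology;
};

// Maps the SPIR-V output topology to the Metal topology enumerator.
std::string msl_topology(MeshTopology t) {
    switch (t) {
    case MeshTopology::Points:    return "metal::topology::point";
    case MeshTopology::Lines:     return "metal::topology::line";
    case MeshTopology::Triangles: return "metal::topology::triangle";
    }
    throw std::invalid_argument("unknown mesh topology");
}

// Builds the metal::mesh<V, P, NV, NP, T> parameter type. P is always void
// here because this sketch assumes the shader has no per-primitive outputs.
std::string msl_mesh_type(const MeshModes& modes, const std::string& vertex_struct) {
    return "metal::mesh<" + vertex_struct + ", void, " +
           std::to_string(modes.output_vertices) + ", " +
           std::to_string(modes.output_primitives) + ", " +
           msl_topology(modes.topology) + ">";
}
```

For the shader above, msl_mesh_type({4, 2, MeshTopology::Triangles}, "Vertex") produces exactly metal::mesh<Vertex, void, 4, 2, metal::topology::triangle>.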

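The current output of SPIRV-Cross looks like this (truncated for brevity). It is not valid MSL: the entry point is qualified as unknown, and the PrimitiveIndicesNV and PrimitiveCountNV builtins come out as the raw enumerant names gl_BuiltIn_5276 and gl_BuiltIn_5275: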

#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wmissing-braces"

#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

template<typename T, size_t Num>
struct spvUnsafeArray
{
...
};

constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);

struct main0_out
{
    float vertexColor_0;
    float vertexColor_1;
    float vertexColor_2;
    float vertexColor_3;
    float4 gl_Position;
    float4 gl_Position_1;
    float4 gl_Position_2;
    float4 gl_Position_3;
};

unknown main0_out main0()
{
    main0_out out = {};
    spvUnsafeArray<float, 4> vertexColor = {};
    _RESERVED_IDENTIFIER_FIXUP_gl_MeshVerticesNV[0].out.gl_Position = float4(-1.0, -1.0, 0.0, 1.0);
    _RESERVED_IDENTIFIER_FIXUP_gl_MeshVerticesNV[1].out.gl_Position = float4(1.0, -1.0, 0.0, 1.0);
    _RESERVED_IDENTIFIER_FIXUP_gl_MeshVerticesNV[2].out.gl_Position = float4(-1.0, 1.0, 0.0, 1.0);
    _RESERVED_IDENTIFIER_FIXUP_gl_MeshVerticesNV[3].out.gl_Position = float4(1.0, 1.0, 0.0, 1.0);
    vertexColor[0] = 0.0;
    vertexColor[1] = 0.25;
    vertexColor[2] = 0.5;
    vertexColor[3] = 0.75;
    gl_BuiltIn_5276[0] = 0u;
    gl_BuiltIn_5276[1] = 1u;
    gl_BuiltIn_5276[2] = 2u;
    gl_BuiltIn_5276[3] = 2u;
    gl_BuiltIn_5276[4] = 1u;
    gl_BuiltIn_5276[5] = 3u;
    gl_BuiltIn_5275 = 2u;
    out.vertexColor_0 = vertexColor[0];
    out.vertexColor_1 = vertexColor[1];
    out.vertexColor_2 = vertexColor[2];
    out.vertexColor_3 = vertexColor[3];
    return out;
}
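The names gl_BuiltIn_5276 and gl_BuiltIn_5275 are the raw SPIR-V BuiltIn enumerants for PrimitiveIndicesNV (5276) and PrimitiveCountNV (5275), which the MSL backend currently doesn't recognize. A minimal sketch of how stores to those two builtins could instead be lowered into calls on the Metal mesh object; lower_builtin_store is a hypothetical helper, not SPIRV-Cross code, and it assumes a flat index list where gl_PrimitiveIndicesNV[i] maps one-to-one onto mesh.set_index(i, ...).

```cpp
#include <cstdint>
#include <string>

// SPIR-V BuiltIn enumerant values defined by SPV_NV_mesh_shader.
constexpr uint32_t BuiltInPrimitiveCountNV = 5275;
constexpr uint32_t BuiltInPrimitiveIndicesNV = 5276;

// Returns the MSL statement replacing a store of `value` to a mesh builtin,
// or an empty string if the builtin is not one this sketch handles.
// `element` is the index expression for the array-valued PrimitiveIndicesNV.
std::string lower_builtin_store(uint32_t builtin, const std::string& element,
                                const std::string& value) {
    switch (builtin) {
    case BuiltInPrimitiveIndicesNV:
        // gl_PrimitiveIndicesNV[element] = value;  ->  mesh.set_index(element, value);
        return "mesh.set_index(" + element + ", " + value + ");";
    case BuiltInPrimitiveCountNV:
        // gl_PrimitiveCountNV = value;  ->  mesh.set_primitive_count(value);
        return "mesh.set_primitive_count(" + value + ");";
    default:
        return "";
    }
}
```

Taking the index as an expression string rather than an integer lets the same path handle dynamically indexed stores.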

You probably want to end up with an output that looks like:

struct main_Output
{
    float4 gl_Position [[position]];
    float vertexColor [[user(locn0)]];
};

[[mesh]] void main0(
  metal::mesh<main_Output, void, 4, 2, metal::topology::triangle> mesh
) {
    spvUnsafeArray<float, 4> vertexColor = {};
    spvUnsafeArray<float4, 4> gl_MeshPerVertexNV = {};
    gl_MeshPerVertexNV[0] = float4(-1.0, -1.0, 0.0, 1.0);
    gl_MeshPerVertexNV[1] = float4(1.0, -1.0, 0.0, 1.0);
    gl_MeshPerVertexNV[2] = float4(-1.0, 1.0, 0.0, 1.0);
    gl_MeshPerVertexNV[3] = float4(1.0, 1.0, 0.0, 1.0);
    vertexColor[0] = 0.0;
    vertexColor[1] = 0.25;
    vertexColor[2] = 0.5;
    vertexColor[3] = 0.75;
    mesh.set_index(0, 0u);
    mesh.set_index(1, 1u);
    mesh.set_index(2, 2u);
    mesh.set_index(3, 2u);
    mesh.set_index(4, 1u);
    mesh.set_index(5, 3u);
    mesh.set_primitive_count(2u);

    // Metal takes each vertex as one struct, so position and varyings are merged on write-out
    for (uint i = 0; i < 4u; i++) {
        mesh.set_vertex(i, { gl_MeshPerVertexNV[i], vertexColor[i] });
    }
}

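Since MSL is C++-based, the semantics of that desired output can be checked on the CPU with a small mock of the mesh object. MockMesh below is a hypothetical stand-in that only mirrors the method names set_vertex, set_index and set_primitive_count; it is not Metal's implementation. It assembles the emitted triangles, showing the index list 0 1 2 / 2 1 3 producing the two triangles that cover the quad.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

struct float4 { float x, y, z, w; };

// Per-vertex output: position plus the single varying from the shader.
struct main_Output {
    float4 gl_Position;
    float vertexColor;
};

// Hypothetical CPU stand-in for
// metal::mesh<V, void, NV, NP, metal::topology::triangle>.
template <typename V, uint32_t NV, uint32_t NP>
struct MockMesh {
    std::array<V, NV> vertices{};
    std::array<uint8_t, NP * 3> indices{};  // three indices per triangle
    uint32_t primitive_count = 0;

    void set_vertex(uint32_t i, const V& v) { assert(i < NV); vertices[i] = v; }
    void set_index(uint32_t i, uint8_t v) { assert(i < NP * 3 && v < NV); indices[i] = v; }
    void set_primitive_count(uint32_t n) { assert(n <= NP); primitive_count = n; }

    // Vertex-index triples of the triangles that would be rasterized.
    std::vector<std::array<uint8_t, 3>> triangles() const {
        std::vector<std::array<uint8_t, 3>> tris;
        for (uint32_t p = 0; p < primitive_count; ++p)
            tris.push_back({indices[3 * p], indices[3 * p + 1], indices[3 * p + 2]});
        return tris;
    }
};

// Same stores as the desired output above, run for its single thread.
void run_quad_shader(MockMesh<main_Output, 4, 2>& mesh) {
    const float4 positions[4] = {{-1.0f, -1.0f, 0.0f, 1.0f}, {1.0f, -1.0f, 0.0f, 1.0f},
                                 {-1.0f, 1.0f, 0.0f, 1.0f},  {1.0f, 1.0f, 0.0f, 1.0f}};
    const float colors[4] = {0.0f, 0.25f, 0.5f, 0.75f};
    const uint8_t idx[6] = {0, 1, 2, 2, 1, 3};
    for (uint32_t i = 0; i < 6; ++i) mesh.set_index(i, idx[i]);
    mesh.set_primitive_count(2);
    for (uint32_t i = 0; i < 4; ++i) mesh.set_vertex(i, {positions[i], colors[i]});
}
```

The asserts inside the mock also catch out-of-range writes, which is a cheap way to sanity-check a translated shader's index math before running it on a GPU.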
HansKristian-Work commented 2 years ago

Given that EXT mesh shaders are in the works, I'd rather wait for that.

oscarbg commented 2 years ago

Yes, VK_EXT_mesh_shader is here! Just joining to get updates. :-)

Try commented 1 year ago

Made a prototype over a weekend: https://github.com/Try/SPIRV-Cross/commit/d83153a1d85277128f4dd66c763b9f4af110ca38. Proper varyings and per-primitive outputs are still TODO, as are many other things.

Here is a code-gen example: https://shader-playground.timjones.io/85579774a9971f99157fb3c41f8d215e

Some design questions:

  1. Unlike GLSL, a Metal mesh shader must provide each vertex's output in a single write, mesh.set_vertex(i, all_vertex). My current workaround is to declare all outputs as threadgroup variables and then run a post-processing step:

    // all threads must finish writing the threadgroup outputs before any thread reads them back
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // ideally max_vertex<=thread_count and loop should have exactly one iteration
    for (uint spvI = gl_LocalInvocationIndex, spvThreadCount = (gl_WorkGroupSize.x*gl_WorkGroupSize.y*gl_WorkGroupSize.z);
         spvI < 24 /* max_vertex */; spvI += spvThreadCount)
    {
        spvPerVertex spvV = {}; // merged struct with gl_Position and varyings
        spvV.gl_MeshVerticesEXT = gl_MeshVerticesEXT[spvI];
        spvV.gl_MeshVerticesEXT.gl_Position.y = -spvV.gl_MeshVerticesEXT.gl_Position.y;
        spvV.perVertex = perVertex[spvI];
        spvMesh.set_vertex(spvI, spvV);
    }

    I know a post-process write-out is not ideal, but so far it is the only approach I have.

  2. The copy loop doesn't really work: it assumes that all threads are still active when the loop needs to run, which is not the case if some threads have already returned. This can be fixed by extracting the application code into a mesh_main function (as the HLSL backend does) and running the post-process in the real main. I noticed that the MSL backend doesn't use a proxy main the way the HLSL backend does. Is there a reason for this design? In other words, what else could break?
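The early-return problem from point 2, and why a proxy main fixes it, can be sketched on the CPU. Threads are simulated sequentially in two phases (on the GPU a threadgroup_barrier would separate them). user_code, write_out and simulate are hypothetical names; the workgroup size of 4 is arbitrary, 24 is the max_vertex from the snippet in point 1, and "odd threads return early" is a contrived stand-in for any divergent return in the application code.

```cpp
#include <array>
#include <cstdint>

constexpr uint32_t kThreads = 4;     // assumed workgroup size (x * y * z)
constexpr uint32_t kMaxVertex = 24;  // max_vertex, as in the snippet in point 1

struct Vertex { float x, y; };
using VertexArray = std::array<Vertex, kMaxVertex>;

// Hypothetical application code (what would become mesh_main). Each thread
// fills a strided share of the threadgroup outputs, then odd threads return
// early. Returns true if this thread took the early return.
bool user_code(uint32_t tid, VertexArray& tg) {
    for (uint32_t i = tid; i < kMaxVertex; i += kThreads)
        tg[i] = {float(i), -float(i)};
    return tid % 2 == 1;
}

// Strided write-out for one thread, mirroring the loop in point 1:
// copy from threadgroup storage, negate y, hand the vertex to the mesh.
void write_out(uint32_t tid, const VertexArray& tg, VertexArray& mesh,
               std::array<int, kMaxVertex>& writes) {
    for (uint32_t i = tid; i < kMaxVertex; i += kThreads) {
        Vertex v = tg[i];
        v.y = -v.y;
        mesh[i] = v;  // stands in for spvMesh.set_vertex(i, v)
        ++writes[i];
    }
}

// Simulates one threadgroup and returns how many vertex slots were written
// exactly once. With proxy_main, the write-out runs in a wrapper after
// mesh_main returns, so early returns cannot skip it.
int simulate(bool proxy_main) {
    VertexArray tg{}, mesh{};
    std::array<int, kMaxVertex> writes{};
    bool returned_early[kThreads] = {};
    // Phase 1: every thread runs the application code.
    for (uint32_t t = 0; t < kThreads; ++t) returned_early[t] = user_code(t, tg);
    // Phase 2: inlined at the end of main, the loop is skipped by threads
    // that already returned; in a proxy main every thread reaches it.
    for (uint32_t t = 0; t < kThreads; ++t)
        if (proxy_main || !returned_early[t]) write_out(t, tg, mesh, writes);
    int complete = 0;
    for (int w : writes) complete += (w == 1);
    return complete;
}
```

With the loop inlined, simulate(false) writes only the 12 slots owned by threads 0 and 2, while simulate(true) writes all 24, which is the behavior the proxy-main split is meant to guarantee.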

Try commented 1 year ago

Drafted PR: https://github.com/KhronosGroup/SPIRV-Cross/pull/2074