shader-slang / slang

Making it easier to work with shaders
MIT License
1.78k stars 159 forks source link

Add support for matrix variants of `select` #4442

Open chaoticbob opened 1 week ago

chaoticbob commented 1 week ago

This may depend on #4395 being completed first.

The `select` intrinsic has a matrix variant that behaves like the vector version of `select`, with this signature:

T select(S cond, T a, T b);

where `T` is a matrix of a supported arithmetic type and `S` is a `bool` matrix with the same dimensions as `T`.

Currently, Slang produces errors like these when compiling the shader below:

shader.hlsl(56): error 30019: expected an expression of type 'bool', got 'matrix<bool,2,2>'
    float2x2 c2x2 = select(cond2x2, a2x2, b2x2);
                           ^~~~~~~
shader.hlsl(61): error 30019: expected an expression of type 'bool', got 'matrix<bool,2,3>'
    float2x3 c2x3 = select(cond2x3, a2x3, b2x3);
                           ^~~~~~~
shader.hlsl(66): error 30019: expected an expression of type 'bool', got 'matrix<bool,2,4>'
    float2x4 c2x4 = select(cond2x4, a2x4, b2x4);
                           ^~~~~~~

I couldn't find any documentation for this variant of the intrinsic. The shader below compiles with DXC and appears to generate valid DXIL and SPIR-V. The input matrices are declared `row_major` because DXC reports that `column_major` `bool` matrices are not supported.

Shader

// Constant buffer supplying the `cond` argument for every select() call in
// main(): one bool matrix per shape, from 1x1 through 4x4 (condRxC has R rows
// and C columns, matching the floatRxC operands it is paired with).
// Every member is declared row_major because DXC reports column_major bool
// matrices as unsupported.
// NOTE(review): member order determines the constant-buffer layout; keep it
// unchanged if this repro is used to compare generated DXIL/SPIR-V.
cbuffer InputVars {
    // 1-row condition matrices.
    row_major bool1x1 cond1x1;
    row_major bool1x2 cond1x2;
    row_major bool1x3 cond1x3;
    row_major bool1x4 cond1x4;

    // 2-row condition matrices.
    row_major bool2x1 cond2x1;
    row_major bool2x2 cond2x2;
    row_major bool2x3 cond2x3;
    row_major bool2x4 cond2x4;

    // 3-row condition matrices.
    row_major bool3x1 cond3x1;
    row_major bool3x2 cond3x2;
    row_major bool3x3 cond3x3;
    row_major bool3x4 cond3x4;

    // 4-row condition matrices.
    row_major bool4x1 cond4x1;
    row_major bool4x2 cond4x2;
    row_major bool4x3 cond4x3;
    row_major bool4x4 cond4x4;
};

// Repro entry point for the matrix overload of select():
//     T select(S cond, T a, T b)   // T = float matrix, S = bool matrix
// It exercises every shape from 1x1 to 4x4. For each shape RxC:
//   aRxC is built by casting the scalar 0, and bRxC by casting the scalar 1;
//   cRxC = select(condRxC, aRxC, bRxC) uses the condition from the InputVars
//   cbuffer.
// NOTE(review): assuming HLSL 2021 semantics, each element of cRxC comes from
//   aRxC where condRxC is true and from bRxC otherwise; confirm against DXC.
// Only element [0][0] of each result is added into `s`, so no other result
// element contributes to the returned value.
// NOTE(review): the `index` input (semantic A) is declared but never read.
float4 main(uint index : A) : SV_TARGET{
    // Running sum of the [0][0] element of every select() result.
    float s = 0;

    // --- 1-row matrices ---
    float1x1 a1x1 = (float1x1)0;
    float1x1 b1x1 = (float1x1)1;
    float1x1 c1x1 = select(cond1x1, a1x1, b1x1);
    s += c1x1[0][0];

    float1x2 a1x2 = (float1x2)0;
    float1x2 b1x2 = (float1x2)1;
    float1x2 c1x2 = select(cond1x2, a1x2, b1x2);
    s += c1x2[0][0];

    float1x3 a1x3 = (float1x3)0;
    float1x3 b1x3 = (float1x3)1;
    float1x3 c1x3 = select(cond1x3, a1x3, b1x3);
    s += c1x3[0][0];

    float1x4 a1x4 = (float1x4)0;
    float1x4 b1x4 = (float1x4)1;
    float1x4 c1x4 = select(cond1x4, a1x4, b1x4);
    s += c1x4[0][0];

    // --- 2-row matrices ---
    float2x1 a2x1 = (float2x1)0;
    float2x1 b2x1 = (float2x1)1;
    float2x1 c2x1 = select(cond2x1, a2x1, b2x1);
    s += c2x1[0][0];

    float2x2 a2x2 = (float2x2)0;
    float2x2 b2x2 = (float2x2)1;
    float2x2 c2x2 = select(cond2x2, a2x2, b2x2);
    s += c2x2[0][0];

    float2x3 a2x3 = (float2x3)0;
    float2x3 b2x3 = (float2x3)1;
    float2x3 c2x3 = select(cond2x3, a2x3, b2x3);
    s += c2x3[0][0];

    float2x4 a2x4 = (float2x4)0;
    float2x4 b2x4 = (float2x4)1;
    float2x4 c2x4 = select(cond2x4, a2x4, b2x4);
    s += c2x4[0][0];

    // --- 3-row matrices ---
    float3x1 a3x1 = (float3x1)0;
    float3x1 b3x1 = (float3x1)1;
    float3x1 c3x1 = select(cond3x1, a3x1, b3x1);
    s += c3x1[0][0];

    float3x2 a3x2 = (float3x2)0;
    float3x2 b3x2 = (float3x2)1;
    float3x2 c3x2 = select(cond3x2, a3x2, b3x2);
    s += c3x2[0][0];

    float3x3 a3x3 = (float3x3)0;
    float3x3 b3x3 = (float3x3)1;
    float3x3 c3x3 = select(cond3x3, a3x3, b3x3);
    s += c3x3[0][0];

    float3x4 a3x4 = (float3x4)0;
    float3x4 b3x4 = (float3x4)1;
    float3x4 c3x4 = select(cond3x4, a3x4, b3x4);
    s += c3x4[0][0];

    // --- 4-row matrices ---
    float4x1 a4x1 = (float4x1)0;
    float4x1 b4x1 = (float4x1)1;
    float4x1 c4x1 = select(cond4x1, a4x1, b4x1);
    s += c4x1[0][0];

    float4x2 a4x2 = (float4x2)0;
    float4x2 b4x2 = (float4x2)1;
    float4x2 c4x2 = select(cond4x2, a4x2, b4x2);
    s += c4x2[0][0];

    float4x3 a4x3 = (float4x3)0;
    float4x3 b4x3 = (float4x3)1;
    float4x3 c4x3 = select(cond4x3, a4x3, b4x3);
    s += c4x3[0][0];

    float4x4 a4x4 = (float4x4)0;
    float4x4 b4x4 = (float4x4)1;
    float4x4 c4x4 = select(cond4x4, a4x4, b4x4);
    s += c4x4[0][0];

    // The accumulated sum goes in the first component; the rest are zero.
    return float4(s, 0, 0, 0);    
}
chaoticbob commented 4 days ago

FYI: there also appears to be a pre-HLSL 2021 form of this operation that uses the ternary operator. I'm not sure whether Slang intends to support it; the syntax looks like this:

float2x3 result2x3 = cond2x3 ? true2x3 : false2x3;