shader-slang / slang

Making it easier to work with shaders
MIT License

Capability system causes issues with `__dispatch_kernel()` for slang-torch use-cases #4517

Open: saipraveenb25 opened this issue 4 days ago

saipraveenb25 commented 4 days ago

We have two ways to use slang-torch: `[AutoPyBindCUDA]`, which generates the kernel launch code automatically, and a manual approach that lets the user launch kernels directly through `__dispatch_kernel()`.
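For comparison, here is a minimal sketch of the first approach, assuming the conventions in the slang-torch documentation: the kernel carries both `[AutoPyBindCUDA]` and `[CudaKernel]`, slang-torch generates the host-side binding, and the block and grid sizes are supplied from Python (for example through the module's `launchRaw(blockSize=..., gridSize=...)` call) instead of through `__dispatch_kernel()`. The kernel name `mul_kernel_auto` is hypothetical.

```slang
// Hypothetical [AutoPyBindCUDA] version of the same operation.
// slang-torch generates the binding; the launch dimensions are chosen on the Python side.
[AutoPyBindCUDA]
[CudaKernel]
void mul_kernel_auto(TensorView<float> A, TensorView<float> result)
{
    // Global 2D thread index, computed the same way as in the manual-mode kernel.
    uint2 location = (cudaBlockDim() * cudaBlockIdx() + cudaThreadIdx()).xy;

    // Skip threads outside the tensor in case the launch grid is rounded up.
    if (location.x >= A.size(0) || location.y >= A.size(1))
        return;

    result[location] = A[location] * 2;
}
```

Because no `[TorchEntryPoint]` host function calls the kernel here, this path never places a CUDA kernel inside a host-side (C++) function.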

Currently, the capability system is causing all the manual-mode tests in the slang-torch repository to fail. Here is a minimal smoke test that reproduces the problem:

```
float computeOutputValue(TensorView<float> A, uint2 loc)
{
    return A[loc] * 2;
}

[CudaKernel]
void mul_kernel(TensorView<float> A, TensorView<float> result)
{
    uint2 location = (cudaBlockDim() * cudaBlockIdx() + cudaThreadIdx()).xy;
    result[location] = computeOutputValue(A, location);
}

[TorchEntryPoint]
TorchTensor<float> multiply(TorchTensor<float> A)
{
    var result = TorchTensor<float>.zerosLike(A);
    let blockCount = uint3(1);
    let groupSize = uint3(A.size(0), A.size(1), 1);

    __dispatch_kernel(mul_kernel, blockCount, groupSize)(A, result);
    return result;
}
```

which fails with the following error:

```
error 36100: 'mul_kernel' requires capability 'textualTarget + cuda + vertex | textualTarget + cuda + fragment | textualTarget + cuda + compute | textualTarget + cuda + hull | textualTarget + cuda + domain | textualTarget + cuda + geometry | textualTarget + cuda + raygen | textualTarget + cuda + intersection | textualTarget + cuda + anyhit | textualTarget + cuda + closesthit | textualTarget + cuda + miss | textualTarget + cuda + mesh | textualTarget + cuda + amplification | textualTarget + cuda + callable' that is conflicting with the 'multiply''s current capability requirement 'textualTarget + cpp + vertex | textualTarget + cpp + fragment | textualTarget + cpp + compute | textualTarget + cpp + hull | textualTarget + cpp + domain | textualTarget + cpp + geometry | textualTarget + cpp + raygen | textualTarget + cpp + intersection | textualTarget + cpp + anyhit | textualTarget + cpp + closesthit | textualTarget + cpp + miss | textualTarget + cpp + mesh | textualTarget + cpp + amplification | textualTarget + cpp + callable'.
    __dispatch_kernel(mul_kernel, blockCount, groupSize)(A, result);
                      ^~~~~~~~~~
```