shader-slang / slang

Making it easier to work with shaders
MIT License

Slangc crashes when using `[AutoPyBindCUDA]` and `[shader("compute")]` in the same file #4289

Closed: eliemichel closed this issue 1 month ago

eliemichel commented 1 month ago

Using slangtorch's loadModule on a Slang file that defines both an [AutoPyBindCUDA] entry point and a [shader("compute")] entry point (in different functions) causes slangc to crash. The Python traceback reports "Compilation failed with error 3221225477", and slangc writes nothing to stderr.
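The error number in the traceback is easier to interpret in hexadecimal. 3221225477 is 0xC0000005, which on Windows is the NTSTATUS code STATUS_ACCESS_VIOLATION. That is consistent with slangc crashing rather than exiting with a normal compile diagnostic (assuming slangc ran on Windows, which is where this code applies). A minimal sketch of the conversion; the helper name as_ntstatus is just for illustration:

```python
# Decimal exit code reported in the Python traceback.
EXIT_CODE = 3221225477

# NTSTATUS value for an access violation on Windows.
STATUS_ACCESS_VIOLATION = 0xC0000005


def as_ntstatus(code: int) -> str:
    """Format a process exit code as a 32-bit hexadecimal NTSTATUS value."""
    return f"0x{code & 0xFFFFFFFF:08X}"


print(as_ntstatus(EXIT_CODE))                # prints 0xC0000005
print(EXIT_CODE == STATUS_ACCESS_VIOLATION)  # prints True
```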

Example:

[AutoPyBindCUDA]
[CUDAKernel]
void square(TensorView<float> input, TensorView<float> output)
{}

[shader("compute")]
[numthreads(1,1,1)]
void computeMain(uint3 threadId: SV_DispatchThreadID)
{}

ArielG-NV commented 1 month ago

Background on what is wrong:

  1. [AutoPyBindCUDA] and [CUDAKernel] add an HLSLExportDecoration, so the compiler does not discard the function even though no entry point uses it.
  2. square uses TensorView but is not reachable from an entry point, so the capability system never reports that HLSL/GLSL/etc. do not support TensorView.
  3. Because the function is not discarded, the GLSL/HLSL/etc. backends then try to emit TensorView at the very end of compilation. This triggers an assert, since not all targets support TensorView.
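The three steps above can be illustrated with a toy Python model. This is not Slang's implementation: the Function class, EmitAssertionError, the capability names and compile_for are all hypothetical, and each compiler pass is reduced to a single loop. The model only shows why an exported function that no entry point reaches slips past the capability check and fails later, at emission.

```python
from dataclasses import dataclass, field


class EmitAssertionError(Exception):
    """Stands in for the internal assert hit while emitting target code."""


@dataclass
class Function:
    name: str
    uses: set = field(default_factory=set)  # types/capabilities the body needs
    exported: bool = False  # stands in for HLSLExportDecoration
    entry_point: bool = False


def compile_for(target_caps: set, functions: list) -> list:
    # Step 1: dead-code elimination keeps entry points and exported functions.
    kept = [f for f in functions if f.entry_point or f.exported]

    # Step 2: the capability check only looks at entry points, so an exported
    # function that no entry point calls is never checked.
    for f in kept:
        if f.entry_point and not f.uses <= target_caps:
            raise ValueError(f"{f.name}: target lacks {sorted(f.uses - target_caps)}")

    # Step 3: emission walks every kept function and fails on unsupported types.
    emitted = []
    for f in kept:
        missing = f.uses - target_caps
        if missing:
            raise EmitAssertionError(f"cannot emit {sorted(missing)} in {f.name}")
        emitted.append(f.name)
    return emitted


square = Function("square", uses={"TensorView"}, exported=True)
compute_main = Function("computeMain", entry_point=True)

try:
    compile_for({"StructuredBuffer"}, [square, compute_main])
except EmitAssertionError as err:
    print(err)  # prints: cannot emit ['TensorView'] in square
```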

Possible solution from what I see:

eliemichel commented 1 month ago

Thanks for the insights! Note that the slangc crash also occurs within the call to slangtorch's loadModule, i.e. when targeting CUDA, which does have the TensorView capability.

My current workaround is to use preprocessor defines. That works, although I end up with boilerplate code that I could not find a way to abstract away or automate. Something like:

// pseudo-code
interface ITensor { /* [...] */ }
// [...] Implementations of ITensor interface for TensorView and my custom StructuredBuffer-based tensor

void squareImpl(ITensor input, ITensor output) {
    // Here is the actual implementation of my kernel
    // [...]
}

// How to "automate" the generation of what's below?
#ifdef USE_PYTORCH
[AutoPyBindCUDA]
[CUDAKernel]
void square(TensorView<float> input, TensorView<float> output)
{
    squareImpl(/* [...] */);
}
#endif

#ifdef USE_COMPUTE
[shader("compute")]
[numthreads(1,1,1)]
void computeMain(uint3 threadId: SV_DispatchThreadID, StructuredBuffer<float> inputData, /* [...] */)
{
    squareImpl(/* [...] */);
}
#endif
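One hypothetical way to avoid writing the wrappers by hand is to generate them outside Slang, for example with a small Python script that writes both #ifdef blocks from a per-kernel description. This is only a sketch of external code generation, not a Slang or slangtorch feature: KernelSpec and emit_wrappers are made-up names, and the parameter lists and call arguments are supplied by the caller because they are elided in the pseudo-code above.

```python
from dataclasses import dataclass


@dataclass
class KernelSpec:
    impl_name: str       # shared implementation, e.g. "squareImpl"
    pytorch_name: str    # name of the [AutoPyBindCUDA] wrapper
    pytorch_params: str  # parameter list of the PyTorch wrapper
    pytorch_args: str    # arguments it forwards to the implementation
    compute_name: str    # name of the compute entry point
    compute_params: str  # parameter list of the compute entry point
    compute_args: str    # arguments it forwards to the implementation


# Literal Slang braces are doubled so str.format leaves them alone.
TEMPLATE = """\
#ifdef USE_PYTORCH
[AutoPyBindCUDA]
[CUDAKernel]
void {pytorch_name}({pytorch_params})
{{
    {impl_name}({pytorch_args});
}}
#endif

#ifdef USE_COMPUTE
[shader("compute")]
[numthreads(1,1,1)]
void {compute_name}({compute_params})
{{
    {impl_name}({compute_args});
}}
#endif
"""


def emit_wrappers(spec: KernelSpec) -> str:
    """Return Slang source for both wrappers around spec.impl_name."""
    return TEMPLATE.format(**vars(spec))


spec = KernelSpec(
    impl_name="squareImpl",
    pytorch_name="square",
    pytorch_params="TensorView<float> input, TensorView<float> output",
    pytorch_args="/* ... */",
    compute_name="computeMain",
    compute_params="uint3 threadId: SV_DispatchThreadID, /* ... */",
    compute_args="/* ... */",
)
print(emit_wrappers(spec))
```

The generated text could be written to a file that the main shader includes; the shared implementation stays hand-written, and only the thin per-target wrappers are generated.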