shader-slang / slang

Making it easier to work with shaders
MIT License

Sanity check for compiling multiple shaders in a single file to SPIR-V library with Slang API #4325

Closed: chaoticbob closed this issue 3 weeks ago

chaoticbob commented 3 weeks ago

I was wondering if I could get a quick sanity check on the correctness of my Slang API usage when compiling a single source file containing multiple shaders to a SPIR-V library.

The specific case is compiling several ray tracing shaders that live in the same source file (or, in this case, the same string). I looked through the ray-tracing-pipeline sample and found that it ultimately compiles each entry point separately into its own module.

What I was trying to accomplish is effectively the same as passing -profile lib_6_3 to slangc without specifying any entry points. After stepping through slangc in a debugger to see how it handles this case, I was able to cobble something together that does what I want. However, I'm not sure it's entirely correct.

Could someone familiar with the Slang API sanity check the code and suggest any corrections? The library compile path starts at if (isTargetLibrary).
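The library decision in the code keys off the profile name prefix via std::string::starts_with, which requires C++20. If the surrounding project builds as C++17, a minimal sketch of the same check follows; isLibraryProfile is a hypothetical helper name, not part of the Slang API, and it matches only the "lib_6_" prefix that the original code tests for.

```cpp
#include <string>

// Hypothetical helper (not part of the Slang API): a C++17-friendly version of
// profile.starts_with("lib_6_"), which needs C++20.
bool isLibraryProfile(const std::string& profile)
{
    static const std::string prefix = "lib_6_";
    // compare(pos, count, str) compares profile[0, prefix.size()) against prefix.
    return profile.size() >= prefix.size() &&
           profile.compare(0, prefix.size(), prefix) == 0;
}
```

For example, isLibraryProfile("lib_6_3") is true, while isLibraryProfile("sm_6_5") is false.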

Related question: if I create a compile request using ISession::createCompileRequest, do I need to destroy it manually with spDestroyCompileRequest?
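This sketch does not settle whether Slang expects spDestroyCompileRequest for a request obtained from ISession::createCompileRequest. It only isolates the ownership pattern the code below relies on, a std::unique_ptr whose deleter is a C-style destroy function, so that the "released exactly once on every return path" behaviour can be checked without Slang. FakeRequest, createFakeRequest, destroyFakeRequest and g_liveRequests are all hypothetical stand-ins.

```cpp
#include <memory>

// Hypothetical C-style create/destroy pair used only for illustration; it
// stands in for something like a compile-request constructor and destructor.
struct FakeRequest { int translationUnits = 0; };

static int g_liveRequests = 0; // number of requests not yet destroyed

FakeRequest* createFakeRequest() { ++g_liveRequests; return new FakeRequest{}; }
void destroyFakeRequest(FakeRequest* request) { --g_liveRequests; delete request; }

// Owning handle: unique_ptr calls the destroy function once, when the handle
// goes out of scope (including early returns), and never for a null pointer.
using FakeRequestPtr = std::unique_ptr<FakeRequest, decltype(&destroyFakeRequest)>;

FakeRequestPtr makeRequest()
{
    return FakeRequestPtr(createFakeRequest(), &destroyFakeRequest);
}
```

Declaring the alias once also avoids repeating the full unique_ptr type at both the declaration and the assignment, as the original code does.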

Thanks.

CompileResult CompileSlang(
    const std::string&     shaderSource,
    const std::string&     entryPoint,
    const std::string&     profile,
    const CompilerOptions& options,
    std::vector<uint32_t>* pSPIRV,
    std::string*           pErrorMsg,
    const std::string&     bytecodeFilePrefix)
{
    // Bail if entry point is empty and we're not compiling to a library
    bool isTargetLibrary = profile.starts_with("lib_6_");
    if (!isTargetLibrary && entryPoint.empty()) {
        return COMPILE_ERROR_INVALID_ENTRY_POINT;
    }

    Slang::ComPtr<slang::IGlobalSession> globalSession;
    if (SLANG_FAILED(slang::createGlobalSession(globalSession.writeRef())))
    {
        return COMPILE_ERROR_INTERNAL_COMPILER_ERROR;
    }

    slang::TargetDesc targetDesc = {};
    targetDesc.format            = SLANG_SPIRV;
    targetDesc.profile           = globalSession->findProfile(profile.c_str());
    targetDesc.flags             = SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY;

    targetDesc.flags |= SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM;

    // Must be set in target desc for now
    targetDesc.forceGLSLScalarBufferLayout = true;

    // Compiler options for Slang
    std::vector<slang::CompilerOptionEntry> compilerOptions;
    {    
        // Force Slang language to prevent any accidental interpretations as GLSL or HLSL
        {
            slang::CompilerOptionEntry entry = {slang::CompilerOptionName::Language};
            entry.value.stringValue0         = "slang";

            compilerOptions.push_back(entry);
        }

        // Keep the original entry point names in the SPIR-V output unless "main" is forced
        if (!options.ForceEntryPointMain)
        {
            compilerOptions.push_back(
                slang::CompilerOptionEntry{
                    slang::CompilerOptionName::VulkanUseEntryPointName,
                    slang::CompilerOptionValue{slang::CompilerOptionValueKind::Int, 1}
            });
        }

        // Force scalar block layout - this gets overwritten by forceGLSLScalarBufferLayout in
        // the target desc currently. So we just set it there.
        //
        {
            compilerOptions.push_back(
                slang::CompilerOptionEntry{
                    slang::CompilerOptionName::GLSLForceScalarLayout,
                    slang::CompilerOptionValue{slang::CompilerOptionValueKind::Int, 1}
            });           
        }
    }

    slang::SessionDesc sessionDesc       = {};
    sessionDesc.targets                  = &targetDesc;
    sessionDesc.targetCount              = 1;
    sessionDesc.defaultMatrixLayoutMode  = SLANG_MATRIX_LAYOUT_COLUMN_MAJOR;
    sessionDesc.compilerOptionEntries    = compilerOptions.data();
    sessionDesc.compilerOptionEntryCount = static_cast<uint32_t>(compilerOptions.size());

    Slang::ComPtr<slang::ISession> compileSession;
    if (SLANG_FAILED(globalSession->createSession(sessionDesc, compileSession.writeRef())))
    {
        return COMPILE_ERROR_INTERNAL_COMPILER_ERROR;
    }

    Slang::ComPtr<slang::IBlob> spirvCode;
    if (isTargetLibrary) {
        //
        // NOTE: This may not be the most correct way to do it, but it works for now
        //

        // Create compile request
        std::unique_ptr<SlangCompileRequest, void(*)(SlangCompileRequest*)> compileRequest(nullptr, nullptr);
        {        
            SlangCompileRequest* pCompileRequest = nullptr;
            auto slangRes = compileSession->createCompileRequest(&pCompileRequest);
            if (SLANG_FAILED(slangRes)) {
                return COMPILE_ERROR_INTERNAL_COMPILER_ERROR;
            }

            compileRequest = std::unique_ptr<SlangCompileRequest, void(*)(SlangCompileRequest*)>(pCompileRequest, spDestroyCompileRequest);
        }

        // Add translation unit
        auto index = compileRequest->addTranslationUnit(SLANG_SOURCE_LANGUAGE_SLANG, nullptr);
        compileRequest->addTranslationUnitSourceString(index, "grex-path", shaderSource.c_str());

        // Compile
        auto slangRes = compileRequest->compile();
        if (SLANG_FAILED(slangRes)) {
            if (pErrorMsg != nullptr) {
                Slang::ComPtr<slang::IBlob> diagBlob;

                slangRes = compileRequest->getDiagnosticOutputBlob(diagBlob.writeRef());
                if (SLANG_SUCCEEDED(slangRes)) 
                {
                    *pErrorMsg = std::string(static_cast<const char*>(diagBlob->getBufferPointer()), diagBlob->getBufferSize());
                }
                else {
                    // Something has gone really wrong
                    assert(false && "failed to get diagnostic output blob");
                }
            }                

            return COMPILE_ERROR_COMPILE_FAILED;
        }

        // Get SPIR-V 
        slangRes = compileRequest->getTargetCodeBlob(0, spirvCode.writeRef());
        if (SLANG_FAILED(slangRes)) {
            if (pErrorMsg != nullptr) {
                *pErrorMsg = "unable to retrieve SPIR-V blob for library";
            }

            return COMPILE_ERROR_INTERNAL_COMPILER_ERROR;
        }
    }
    else {
        // Load source
        slang::IModule* pSlangModule = nullptr;
        {
            Slang::ComPtr<slang::IBlob> diagBlob;

            pSlangModule = compileSession->loadModuleFromSourceString("grex-module", nullptr, shaderSource.c_str(), diagBlob.writeRef());
            if (pSlangModule == nullptr)
            {
                if (pErrorMsg != nullptr && diagBlob.get() != nullptr)
                {
                    *pErrorMsg = std::string(static_cast<const char*>(diagBlob->getBufferPointer()), diagBlob->getBufferSize());
                }

                return COMPILE_ERROR_INTERNAL_COMPILER_ERROR;
            }
        }

        // Components
        std::vector<slang::IComponentType*> components;
        components.push_back(pSlangModule);

        // Entry points. `components` stores raw pointers, so keep owning
        // references alive until the composite component type has been created.
        std::vector<Slang::ComPtr<slang::IEntryPoint>> entryPoints;
        if (!entryPoint.empty()) {
            Slang::ComPtr<slang::IEntryPoint> slangEntryPoint;
            if (SLANG_FAILED(pSlangModule->findEntryPointByName(entryPoint.c_str(), slangEntryPoint.writeRef())))
            {
                return COMPILE_ERROR_INTERNAL_COMPILER_ERROR;
            }
            entryPoints.push_back(slangEntryPoint);
        }
        else {
            SlangInt32 slangEntryPointCount = pSlangModule->getDefinedEntryPointCount();
            for (SlangInt32 i = 0; i < slangEntryPointCount; ++i) {
                Slang::ComPtr<slang::IEntryPoint> slangEntryPoint;
                if (SLANG_FAILED(pSlangModule->getDefinedEntryPoint(i, slangEntryPoint.writeRef())))
                {
                    return COMPILE_ERROR_INTERNAL_COMPILER_ERROR;
                }
                entryPoints.push_back(slangEntryPoint);
            }
        }
        for (auto& ep : entryPoints) {
            components.push_back(ep.get());
        }

        Slang::ComPtr<slang::IComponentType> composedProgram;
        {
            Slang::ComPtr<slang::IBlob> diagBlob;

            auto slangRes = compileSession->createCompositeComponentType(
                components.data(),
                components.size(),
                composedProgram.writeRef(),
                diagBlob.writeRef());
            if (SLANG_FAILED(slangRes))
            {
                if (pErrorMsg != nullptr && diagBlob.get() != nullptr)
                {
                    *pErrorMsg = std::string(static_cast<const char*>(diagBlob->getBufferPointer()), diagBlob->getBufferSize());
                }

                return COMPILE_ERROR_COMPILE_FAILED;
            }
        }

        // Get SPIR-V 
        {
            Slang::ComPtr<slang::IBlob> diagBlob;

            auto slangRes = composedProgram->getEntryPointCode(
                0, // entryPointIndex,
                0, // targetIndex,
                spirvCode.writeRef(),
                diagBlob.writeRef());
            if (SLANG_FAILED(slangRes))
            {
                if (pErrorMsg != nullptr && diagBlob.get() != nullptr)
                {
                    *pErrorMsg = std::string(static_cast<const char*>(diagBlob->getBufferPointer()), diagBlob->getBufferSize());
                }

                return COMPILE_ERROR_LINK_FAILED;
            }
        }
    }

    const char* pBuffer    = static_cast<const char*>(spirvCode->getBufferPointer());
    size_t      bufferSize = static_cast<size_t>(spirvCode->getBufferSize());
    size_t      wordCount  = bufferSize / 4;

    pSPIRV->resize(wordCount);
    memcpy(pSPIRV->data(), pBuffer, wordCount * sizeof(uint32_t));

    return COMPILE_SUCCESS;
}
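The final copy above assumes the blob size is a whole number of 32-bit words. A slightly more defensive version is sketched below; copySpirvWords is a hypothetical helper, not part of the Slang API. It rejects sizes that are not a multiple of four and checks the SPIR-V magic number 0x07230203 in the first word, assuming the blob is in host byte order.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>

// Hypothetical helper (not part of the Slang API): copies a raw SPIR-V byte
// buffer, such as the contents of a code blob, into 32-bit words.
// Returns false, leaving *outWords untouched, when the buffer is null, is not
// a whole number of 32-bit words, or does not start with the SPIR-V magic
// number 0x07230203. The magic check assumes host byte order.
bool copySpirvWords(const void* data, std::size_t byteCount, std::vector<std::uint32_t>* outWords)
{
    if (data == nullptr || outWords == nullptr) {
        return false;
    }
    if (byteCount < sizeof(std::uint32_t) || byteCount % sizeof(std::uint32_t) != 0) {
        return false;
    }

    // Copy into a temporary so a rejected buffer never modifies the output.
    std::vector<std::uint32_t> words(byteCount / sizeof(std::uint32_t));
    std::memcpy(words.data(), data, byteCount);
    if (words[0] != 0x07230203u) {
        return false;
    }

    *outWords = std::move(words);
    return true;
}
```

With this helper, the last few statements of CompileSlang could become `if (!copySpirvWords(spirvCode->getBufferPointer(), spirvCode->getBufferSize(), pSPIRV)) { return COMPILE_ERROR_INTERNAL_COMPILER_ERROR; }`.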
csyonghe commented 3 weeks ago

The compile request API is deprecated. I just realized that the IComponentType interface is missing a method: we need to add getWholeProgramCode so that you can do this:

auto module = compileSession->loadModule(...);
auto linkedModule = module->link();
auto spvLibCode = linkedModule->getWholeProgramCode(..);