shader-slang / slang-rhi

Slang Render Hardware Interface

Shader object refactor #87

Open skallweitNV opened 1 month ago

Introduction

This document describes a new API for shader objects in slang-rhi.

The main goals are:

IShaderObject

Similar to the implementation in gfx, we have an IShaderObject interface that represents shader objects:

class IShaderObject
{
public:
    /// Return the associated element type layout.
    virtual slang::TypeLayoutReflection* getElementTypeLayout() = 0;
    /// Return the container type.
    virtual ShaderObjectContainerType getContainerType() = 0;
    /// Return the number of entry points (if this is a root shader object).
    virtual Count getEntryPointCount() = 0;
    /// Return an entry point by index.
    virtual Result getEntryPoint(GfxIndex index, IShaderObject** entryPoint) = 0;

    /// Set uniform data.
    virtual Result setData(ShaderOffset offset, const void *data, Size size) = 0;
    /// Get uniform data.
    virtual Result getData(ShaderOffset offset, void *data, Size size) = 0;

    /// Set a binding.
    virtual Result setBinding(ShaderOffset offset, Binding binding) = 0;
    /// Get a binding.
    virtual Result getBinding(ShaderOffset offset, Binding* binding) = 0;

    /// Set a sub-object.
    virtual Result setObject(ShaderOffset offset, IShaderObject* object) = 0;
    /// Get a sub-object.
    virtual Result getObject(ShaderOffset offset, IShaderObject** object) = 0;

    /// Freeze the shader object, making it immutable.
    /// Any calls to modify the shader object after this will result in an error.
    virtual Result freeze() = 0;
};

The main difference from gfx is that an IShaderObject becomes immutable once freeze() has been called. A shader object can be bound to a command encoder, or assigned as a sub-object of another shader object, only after it has been frozen.

Frozen shader objects cannot under any circumstances be unfrozen.
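
The freeze rules above can be illustrated with a small mock. This is only a sketch under stated assumptions: `MockShaderObject` and the `Result` enum below are hypothetical stand-ins written for this example, not slang-rhi types, and a single sub-object slot replaces the offset-based lookup.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for the Result codes used by the interface above.
enum class Result { Ok, Fail };

// Minimal mock of a freezable shader object; not the slang-rhi implementation.
class MockShaderObject
{
public:
    explicit MockShaderObject(std::size_t uniformSize) : m_data(uniformSize) {}

    // Writes uniform bytes; rejected once the object is frozen.
    Result setData(std::size_t offset, const void* data, std::size_t size)
    {
        if (m_frozen)
            return Result::Fail;
        if (offset > m_data.size() || size > m_data.size() - offset)
            return Result::Fail; // write would run past the uniform buffer
        if (size != 0)
            std::memcpy(m_data.data() + offset, data, size);
        return Result::Ok;
    }

    // Assigns the (single) sub-object; only frozen objects are accepted.
    Result setObject(std::shared_ptr<MockShaderObject> object)
    {
        if (m_frozen || !object || !object->isFrozen())
            return Result::Fail;
        m_subObject = std::move(object);
        return Result::Ok;
    }

    // Freezing is one-way; freezing twice is reported as an error.
    Result freeze()
    {
        if (m_frozen)
            return Result::Fail;
        m_frozen = true;
        return Result::Ok;
    }

    bool isFrozen() const { return m_frozen; }

private:
    bool m_frozen = false;
    std::vector<uint8_t> m_data;
    std::shared_ptr<MockShaderObject> m_subObject;
};
```

In this mock there is deliberately no `unfreeze()`: the only way to get a mutable object again would be to create a copy.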

To create shader objects, there are a few factory methods on IDevice:

class IDevice
{
public:
    /// Create a new shader object from a given slang type.
    virtual Result createShaderObject(
        slang::ISession* slangSession,
        slang::TypeReflection* type,
        ShaderObjectContainerType container,
        IShaderObject** outObject
    ) = 0;

    /// Create a new shader object by copying an existing shader object.
    /// The new shader object will be mutable.
    virtual Result createShaderObject(IShaderObject* object, IShaderObject** outObject) = 0;

    /// Create a new root shader object for the given shader program.
    virtual Result createRootShaderObject(IShaderProgram* program, IShaderObject** outObject) = 0;
};

Basic example

Slang code:

interface ICamera
{
    Ray getRay(float2 uv);
};

struct PinholeCamera : ICamera
{
    float3 position;
    float3 direction;
    float3 up;
    float2 fov;
    float2 resolution;

    Ray getRay(float2 uv) { /* ... */ }
}

struct Scene
{
    StructuredBuffer<Vertex> vertexBuffer;
    StructuredBuffer<Index> indexBuffer;
    ICamera camera;
};

[shader("compute")]
void render(uniform ParameterBlock<Scene> scene, uniform uint iteration)
{}

Host code:

// Load the shader program.
ComPtr<IShaderProgram> program = device->loadProgram("test.slang", "render");

// Create a shader object for the camera.
// This object will be used to specialize the program.
ComPtr<IShaderObject> cameraObject = device->createShaderObject(program->getReflection()->getType("PinholeCamera"));
{
    ShaderCursor cursor(cameraObject);
    cursor["position"] = float3(0, 0, 0);
    ...
}
// Freeze the camera shader object so we can use it as a sub-object.
cameraObject->freeze();

// Create a shader object for the scene.
ComPtr<IShaderObject> sceneObject = device->createShaderObject(program->getReflection()->getType("Scene"));
{
    ShaderCursor cursor(sceneObject);
    cursor["vertexBuffer"] = vertexBuffer;
    cursor["indexBuffer"] = indexBuffer;
    cursor["camera"] = cameraObject; // NOTE: Error if cameraObject was not frozen!
}
sceneObject->freeze();

// Create the root shader object.
ComPtr<IShaderObject> rootObject = device->createRootShaderObject(program);
{
    ShaderCursor cursor(rootObject);
    cursor["scene"] = sceneObject;
}
rootObject->freeze();

// With the root object done, we can now specialize our program.
ComPtr<IShaderProgram> specializedProgram = program->specialize(program, rootObject);

// Create a compute pipeline.
// NOTE: Creating a pipeline for an unspecialized program would be an error!
ComPtr<IComputePipeline> pipeline = device->createComputePipeline(specializedProgram);

// Submit a single compute dispatch.
ComPtr<ICommandEncoder> encoder = device->getQueue()->createCommandEncoder();
ComputeState state;
state.pipeline = pipeline;
state.rootObject = rootObject; // NOTE: Error if rootObject is not frozen.
encoder->setComputeState(state);
encoder->dispatchCompute(1, 1, 1);
device->getQueue()->submit(encoder->finish());

// Submit multiple dispatches on the same command list, with modified root objects.
encoder = device->getQueue()->createCommandEncoder();
for (int i = 0; i < 100; ++i)
{
    ComPtr<IShaderObject> modifiedRootObject = device->createShaderObject(rootObject);
    {
        ShaderCursor cursor(modifiedRootObject);
        cursor["iteration"] = i;
    }
    modifiedRootObject->freeze();
    ComputeState state;
    state.pipeline = pipeline;
    state.rootObject = modifiedRootObject;
    encoder->setComputeState(state);
    encoder->dispatchCompute(1, 1, 1);
}
device->getQueue()->submit(encoder->finish());

Implementation details

ShaderObject

Shader objects are implemented in a backend-agnostic way. The main purpose of a shader object is to hold all the resources, sub-objects and uniform data assigned to it. Binding ranges are used to map shader offsets to a linear array of binding slots. Each binding slot contains a reference to a resource and additional data (e.g. buffer range, format, etc.).

class ShaderObject : public IShaderObject, public ComObject
{
public:
    void init()
    {
        // 1. Enumerate all binding ranges and populate m_bindingTypeToStartIndex and m_bindings.
        // 2. Enumerate all sub-objects and create a shader object for each sub-object (recursively).
        // 3. Allocate memory for uniform data.
    }

    Result setData(ShaderOffset offset, const void *data, Size size) override
    {
        // 1. Return error if the shader object is frozen.
        // 2. Copy the data into the uniform data buffer.
    }

    Result getData(ShaderOffset offset, void *data, Size size) override
    {
        // Copy the data from the uniform data buffer.
    }

    Result setBinding(ShaderOffset offset, Binding binding) override
    {
        // 1. Return error if the shader object is frozen.
        // 2. Find the binding range for the given offset.
        // 3. Copy the binding into the bindings array.
    }

    Result getBinding(ShaderOffset offset, Binding* binding) override
    {
        // Find the binding for the given offset and return it.
    }

    Result setObject(ShaderOffset offset, IShaderObject* object) override
    {
        // 1. Return error if the shader object is frozen.
        // 2. Find the index of the sub-object for the given offset.
        // 3. Set the object into the sub-objects array.
    }

    Result getObject(ShaderOffset offset, IShaderObject** object) override
    {
        // Find the sub-object for the given offset and return it.
    }

    Result freeze() override
    {
        // 1. Return error if the shader object is already frozen.
        // 2. Freeze all sub-objects (recursively).
        // 3. Freeze the shader object.
    }

private:
    struct BindingSlot
    {
        /// The bound resource.
        RefPtr<Resource> resource;
        /// Additional data.
        union
        {
            struct
            {
                BufferRange range;
                Format format;
            } buffer;
            // ...
        };
    };

    /// True if the shader object is frozen.
    bool m_frozen = false; 
    /// Map from binding type to start index in the bindings array.
    std::array<uint32_t, slang::BindingType::Count> m_bindingTypeToStartIndex;
    /// List of bindings.
    std::vector<BindingSlot> m_bindings;
    /// List of sub-objects.
    std::vector<RefPtr<ShaderObject>> m_objects;
    /// Uniform data.
    std::vector<uint8_t> m_data;
};
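
One way to map a shader offset to a slot in the linear bindings array is to store, per binding range, the index of its first slot and its element count. The sketch below assumes the offset carries a binding range index and an array index (as `ShaderOffset` does in gfx); `BindingOffset`, `BindingRange` and `BindingTable` are hypothetical names introduced for this example only.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical offset, modeled on the binding part of a shader offset.
struct BindingOffset
{
    uint32_t bindingRangeIndex;
    uint32_t bindingArrayIndex;
};

// One binding range occupies `count` consecutive slots starting at `firstSlot`.
struct BindingRange
{
    uint32_t firstSlot;
    uint32_t count;
};

class BindingTable
{
public:
    // Appends a range of `count` slots and returns its range index.
    uint32_t addRange(uint32_t count)
    {
        m_ranges.push_back({m_slotCount, count});
        m_slotCount += count;
        return static_cast<uint32_t>(m_ranges.size() - 1);
    }

    // Resolves an offset to a slot index; returns false if it is out of range.
    bool findSlot(BindingOffset offset, uint32_t& outSlot) const
    {
        if (offset.bindingRangeIndex >= m_ranges.size())
            return false;
        const BindingRange& range = m_ranges[offset.bindingRangeIndex];
        if (offset.bindingArrayIndex >= range.count)
            return false;
        outSlot = range.firstSlot + offset.bindingArrayIndex;
        return true;
    }

    // Total number of slots the bindings array needs.
    uint32_t slotCount() const { return m_slotCount; }

private:
    std::vector<BindingRange> m_ranges;
    uint32_t m_slotCount = 0;
};
```

With such a table, `setBinding` and `getBinding` would reduce to a bounds-checked lookup followed by a read or write of `m_bindings[slot]`.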

Shallow copy

When copying a shader object, we initially copy only the root object; its sub-objects keep referencing the same memory as in the original shader object. When a sub-object is about to be modified (i.e. when it is requested through getObject), we replace it with a copy. This way, we avoid copying the entire shader object tree when only a small part of it is modified. We also track the original shader object we copied from, so that backends can reuse work already done for previous shader objects.

class ShaderObject : public IShaderObject, public ComObject
{
public:
    Result initFromOther(IShaderObject* other)
    {
        // 1. Copy all data from `other`.
        // 2. Set `m_is_copy` to true.
        // 3. Assign the original shader object to `m_original`.
    }

    Result getObject(ShaderOffset offset, IShaderObject** object) override
    {
        // 1. Find the sub-object for the given offset.
        // 2. If the sub-object is frozen (i.e. still referencing the original), replace the sub-object with a copy.
        // 3. Return the new sub-object.
    }

private:
    /// True if the shader object is a copy.
    bool m_is_copy = false;
    /// The shader object we copied from.
    RefPtr<ShaderObject> m_original;
};
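
The copy-on-write behaviour described in this section can be sketched with `std::shared_ptr`. The class below is a simplified model, not the proposed implementation: `CowShaderObject` is a hypothetical name, a single `uint32_t` stands in for the uniform data, and sub-objects are addressed by plain index instead of `ShaderOffset`.

```cpp
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

class CowShaderObject
{
public:
    using Ptr = std::shared_ptr<CowShaderObject>;

    // Shallow copy: the result is mutable and shares all sub-objects with `original`.
    static Ptr copyOf(const Ptr& original)
    {
        auto copy = std::make_shared<CowShaderObject>();
        copy->m_value = original->m_value;
        copy->m_objects = original->m_objects; // copies pointers only
        copy->m_original = original;           // remembered for reuse by backends
        return copy;
    }

    bool setValue(uint32_t value)
    {
        if (m_frozen)
            return false;
        m_value = value;
        return true;
    }

    // Appends a sub-object; only frozen objects are accepted.
    bool addObject(const Ptr& object)
    {
        if (m_frozen || !object || !object->m_frozen)
            return false;
        m_objects.push_back(object);
        return true;
    }

    // On a mutable parent, a frozen (still shared) sub-object is first
    // replaced by a private mutable copy; a frozen parent returns it as is.
    Ptr getObject(std::size_t index)
    {
        if (index >= m_objects.size())
            return nullptr;
        Ptr& slot = m_objects[index];
        if (!m_frozen && slot->m_frozen)
            slot = copyOf(slot);
        return slot;
    }

    // Freezes this object and any sub-objects that were made mutable.
    void freeze()
    {
        for (Ptr& object : m_objects)
            if (!object->m_frozen)
                object->freeze();
        m_frozen = true;
    }

    uint32_t value() const { return m_value; }
    bool isFrozen() const { return m_frozen; }
    Ptr original() const { return m_original; }

private:
    bool m_frozen = false;
    uint32_t m_value = 0;
    std::vector<Ptr> m_objects;
    Ptr m_original;
};
```

Note that in this model `getObject` copies eagerly, even if the caller only reads from the returned sub-object; a real implementation might want a separate read-only accessor.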

Backend data

Each backend has a different way of binding shader objects. Once a shader object is frozen, we can create a per-backend data structure that contains all the information needed to bind the shader object to a command encoder. These objects can be lightweight, as they only need to store the information needed to bind the shader object.
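
Because a frozen object can no longer change, its backend data could be built once and cached. The following is a rough sketch of that idea, assuming lazy creation on first use; `BackendBindingData`, `FreezableObject` and the `Builder` callback are hypothetical names for this example, and the actual proposal may create the data at freeze time or elsewhere.

```cpp
#include <cstdint>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical backend-side record: only what is needed at bind time.
struct BackendBindingData
{
    std::vector<uint8_t> uniformBytes; // e.g. bytes ready for a constant buffer upload
};

class FreezableObject
{
public:
    using BackendDataPtr = std::shared_ptr<const BackendBindingData>;
    // A backend supplies this callback to translate a frozen object.
    using Builder = std::function<BackendDataPtr(const FreezableObject&)>;

    bool setUniformBytes(std::vector<uint8_t> bytes)
    {
        if (m_frozen)
            return false;
        m_uniformBytes = std::move(bytes);
        return true;
    }

    const std::vector<uint8_t>& uniformBytes() const { return m_uniformBytes; }

    void freeze() { m_frozen = true; }

    // Returns the cached backend data, building it on first use.
    // Returns nullptr while the object is still mutable.
    BackendDataPtr getBackendData(const Builder& build)
    {
        if (!m_frozen)
            return nullptr;
        if (!m_backendData)
            m_backendData = build(*this);
        return m_backendData;
    }

private:
    bool m_frozen = false;
    std::vector<uint8_t> m_uniformBytes;
    BackendDataPtr m_backendData;
};
```

Binding the same frozen object many times (for example across the 100 dispatches in the host code above) would then invoke the builder only once per object.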