[vulkan] Support VK_KHR_buffer_device_address and PhysicalStorageBuffer #13945

Open powderluv opened 1 year ago

powderluv commented 1 year ago

Request description

Branching off issue to see what it would take to implement VK_KHR_buffer_device_address and PhysicalStorageBuffer-based access to Vulkan devices. Increasingly we are dealing with very large tensors (>4GB), and maxStorageBufferRange is limited to 4GB. While we explore other options, this feature request is to see what it would take to move us to using VK_KHR_buffer_device_address.

We are seeing increasing model sizes (Stable Diffusion at 768x768, LLaMA up to 65B, etc.) that we are unable to run on our Vulkan backend today without doing a multi-process hack. We are also starting to see 16GB+ VRAM allocations on mobile SoC devices, so this is a requirement across the board for Vulkan devices.

Some other references: ? There were some references to it from

@antiagainst @benvanik @stellaraccident

allieculp commented 1 year ago

Adding @antiagainst to take a look when you can.

antiagainst commented 1 year ago

This would be a major change w.r.t. how we handle buffers across runtime and kernel. What the buffer device address extension does is enable querying int64 physical GPU addresses for storage buffers so that we can populate them in uniform buffers or push constants and then let kernels directly load them as buffer pointers and do load/store thereafter. So it's pretty substantial, at least at the conceptual level.
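To make the "populate them in push constants" part concrete, here is a minimal host-side sketch in C. It is not IREE or Vulkan code: `device_address_t` is a stand-in for `VkDeviceAddress` (a 64-bit integer), the addresses are made-up numbers rather than results of a real address query, and `push_constants_t` and its helpers are hypothetical names used only for illustration.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

// Stand-in for VkDeviceAddress, which Vulkan defines as a 64-bit integer.
typedef uint64_t device_address_t;

// 128 bytes is the minimum maxPushConstantsSize guaranteed by the Vulkan spec.
enum { PUSH_CONSTANT_BYTES = 128 };

// Hypothetical host-side mirror of a push constant block.
typedef struct {
  uint8_t bytes[PUSH_CONSTANT_BYTES];
  size_t used;
} push_constants_t;

// Appends one 64-bit address at the next 8-byte aligned offset.
// Returns the byte offset it was written at, or -1 if the block is full.
static int push_constants_append_address(push_constants_t* pc,
                                         device_address_t address) {
  size_t offset = (pc->used + 7u) & ~(size_t)7u;
  if (offset + sizeof(address) > PUSH_CONSTANT_BYTES) return -1;
  memcpy(pc->bytes + offset, &address, sizeof(address));
  pc->used = offset + sizeof(address);
  return (int)offset;
}

// Reads an address back, the way a kernel would load the pointer value
// before using it for loads and stores.
static device_address_t push_constants_load_address(const push_constants_t* pc,
                                                    size_t offset) {
  device_address_t address;
  assert(offset + sizeof(address) <= pc->used);
  memcpy(&address, pc->bytes + offset, sizeof(address));
  return address;
}
```

With 8-byte addresses a 128-byte block holds at most 16 bindings, which is one reason a uniform or storage buffer is the more scalable place to put them.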

Now, regarding what needs to be changed to support this extension, a few big parts:

I'd need to talk with @benvanik to get a more detailed picture w.r.t. what needs to be changed, esp. on the runtime side.

benvanik commented 1 year ago

Yeah, lots tangled up here - we'll need to break it down. There are some easier solutions and some harder ones for sure :) I suspect we'll build an ArgumentBuffer-like thing in the Vulkan HAL driver and map descriptor sets into that instead of native descriptor sets.

benvanik commented 1 year ago

(this was how I was imagining enabling secondary command buffer buffer substitution in iree_hal_command_buffer_execute - so doing it may get us that too!)

allieculp commented 1 year ago

Setting as P2 for now but leaving open for continued conversations and task lists etc. Please edit as needed.

benvanik commented 1 year ago

Took a look at what would be needed. @antiagainst articulated the major parts and then there's a few details:

The idea would be to have each command buffer have a growable set of staging buffers with uniform-buffer-upload semantics (something we'd have to benchmark, but usually device-local|host-visible|host-coherent). As the command buffer is recorded and push_descriptor_sets is called, we'd slice off some of the current staging buffer, scribble in our buffer info, and then bind the staging buffer with a dynamic offset as a normal descriptor set operation using DescriptorSetArena as today. The shaders would have a buffer declared for the descriptors and access all descriptors indirectly through it. Something like:

iree_hal_vulkan_direct_command_buffer_push_descriptor_set(...) {
  if (iree_hal_vulkan_native_pipeline_layout_has_indirect_access(pipeline_layout, set)) {
    ... stash on command buffer mirror of descriptor state ...
    ... mark parameters as dirty ...
  } else {
    // existing descriptor set binding path
  }
}

iree_hal_vulkan_direct_command_buffer_dispatch(...) {
  if (parameters are dirty) {
    if (iree_hal_vulkan_native_pipeline_layout_has_indirect_access(pipeline_layout, set)) {
      // upload a new parameters chunk by flushing the command buffer descriptor state for the pipeline layout
      // this may allocate a new staging buffer if the prior one is exhausted
      iree_hal_vulkan_direct_command_buffer_append_dispatch_parameters(pipeline_layout, &staging_descriptor_set, &staging_offset);
      // bind the root descriptor with the dynamic offset of the dispatch - should be cheap
      vkCmdBindDescriptorSets(staging_descriptor_set, staging_offset);
      ... reset parameters dirty flag ...
    }
  }
}
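
The stash/dirty/flush flow in the pseudocode above can be modeled on the host alone. The following C sketch is only an illustration under assumptions: the staging buffers are plain malloc'd chunks instead of Vulkan buffers, the chunk size and the 64-byte dynamic-offset alignment are invented values, and all type and function names (`command_buffer_t`, `push_binding`, `flush_parameters`) are hypothetical.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

enum {
  STAGING_CHUNK_BYTES = 256,  // tiny on purpose so exhaustion is easy to hit
  STAGING_ALIGNMENT = 64,     // assumed dynamic-offset alignment
  MAX_BINDINGS = 4,
  MAX_CHUNKS = 8,
};

typedef struct {
  uint8_t* chunks[MAX_CHUNKS];      // host stand-ins for staging buffers
  int chunk_count;
  size_t offset;                    // bump offset within the newest chunk
  size_t last_offset;               // offset of the most recent slice
  uint64_t bindings[MAX_BINDINGS];  // mirror of the pushed device addresses
  bool dirty;
} command_buffer_t;

// push_descriptor_set path: only stash the address and mark state dirty.
static void push_binding(command_buffer_t* cb, int binding, uint64_t address) {
  assert(binding >= 0 && binding < MAX_BINDINGS);
  cb->bindings[binding] = address;
  cb->dirty = true;
}

// dispatch path: if dirty, copy the mirror into a fresh aligned slice,
// growing the chunk list when the current chunk is exhausted.
// Returns false if nothing was ever pushed or if allocation fails.
static bool flush_parameters(command_buffer_t* cb, int* out_chunk,
                             size_t* out_offset) {
  if (cb->dirty) {
    size_t aligned = (cb->offset + STAGING_ALIGNMENT - 1) &
                     ~(size_t)(STAGING_ALIGNMENT - 1);
    if (cb->chunk_count == 0 ||
        aligned + sizeof(cb->bindings) > STAGING_CHUNK_BYTES) {
      if (cb->chunk_count == MAX_CHUNKS) return false;
      uint8_t* chunk = (uint8_t*)malloc(STAGING_CHUNK_BYTES);
      if (!chunk) return false;
      cb->chunks[cb->chunk_count++] = chunk;
      aligned = 0;
    }
    memcpy(cb->chunks[cb->chunk_count - 1] + aligned, cb->bindings,
           sizeof(cb->bindings));
    cb->last_offset = aligned;
    cb->offset = aligned + sizeof(cb->bindings);
    cb->dirty = false;
  }
  if (cb->chunk_count == 0) return false;
  *out_chunk = cb->chunk_count - 1;  // which staging buffer to bind
  *out_offset = cb->last_offset;     // dynamic offset for the bind
  return true;
}

static void command_buffer_deinit(command_buffer_t* cb) {
  for (int i = 0; i < cb->chunk_count; ++i) free(cb->chunks[i]);
  cb->chunk_count = 0;
}
```

Back-to-back dispatches that do not change any binding reuse the previous slice, so only the dynamic offset bind is repeated.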

And the shader:

layout(buffer_reference, std430, buffer_reference_align = 16) buffer binding_f32_t {
  float data[];
};
layout(set = 3, binding = 0) buffer root_set_0_t {
  binding_f32_t binding_0;
  binding_f32_t binding_1;
  binding_f32_t binding_2;
} root_set_0;
void main() {
  root_set_0.binding_0.data[0];  // access...
}

benvanik commented 1 year ago

Oh, the other thing this intersects with is secondary indirect command buffers (iree_hal_command_buffer_execute_commands) - those are recorded with placeholders, and when executed the placeholders get updated to the bindings passed in via iree_hal_buffer_binding_table_t. There's a bit more bookkeeping required such that we map each command-buffer-global binding slot to the locations in the staging buffer that need to have that value populated. In the above we can write the device addresses in when flushing the parameters, but here we'd instead stash the offset into the parameter buffer where the address should be written, as we don't actually have it at the time of recording. When a secondary command buffer is scheduled with iree_hal_command_buffer_execute_commands, the hosting primary command buffer would use one or more vkCmdUpdateBuffer calls to populate the parameters in stream order, and so long as we didn't allow multiple overlapping executions of the same secondary buffer (we'll need to track that) we should be safe. I call this out because the compiler side will look identical and it's just additional tracking at runtime on top of the above work to also get reusable command buffers!
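
A rough C model of that bookkeeping, under stated assumptions: the parameter buffer is a host byte array, the memcpy stands in for the vkCmdUpdateBuffer that the primary command buffer would record, the binding table is reduced to an array of 64-bit addresses, and every name here (`patch_t`, `record_placeholder`, `apply_binding_table`) is hypothetical.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

enum { MAX_PATCHES = 16 };

// One placeholder: which binding-table slot feeds which parameter location.
typedef struct {
  uint32_t slot;    // command-buffer-global binding slot
  uint32_t offset;  // byte offset in the parameter buffer to patch
} patch_t;

typedef struct {
  patch_t patches[MAX_PATCHES];
  int patch_count;
} secondary_command_buffer_t;

// Recording time: the address is not known yet, so only remember where it
// has to be written later. Returns 0 on success, -1 if the list is full.
static int record_placeholder(secondary_command_buffer_t* cb, uint32_t slot,
                              uint32_t offset) {
  if (cb->patch_count == MAX_PATCHES) return -1;
  cb->patches[cb->patch_count].slot = slot;
  cb->patches[cb->patch_count].offset = offset;
  cb->patch_count++;
  return 0;
}

// Execution time: resolve each placeholder against the binding table.
// Returns -1 without guaranteeing a full update if any patch is invalid.
static int apply_binding_table(const secondary_command_buffer_t* cb,
                               const uint64_t* table, uint32_t table_count,
                               uint8_t* parameters, size_t parameters_size) {
  for (int i = 0; i < cb->patch_count; ++i) {
    const patch_t* p = &cb->patches[i];
    if (p->slot >= table_count) return -1;
    if ((size_t)p->offset + sizeof(uint64_t) > parameters_size) return -1;
    // Stand-in for a vkCmdUpdateBuffer recorded in stream order.
    memcpy(parameters + p->offset, &table[p->slot], sizeof(uint64_t));
  }
  return 0;
}
```

Applying a second binding table to the same recorded patches rewrites the same locations, which is why overlapping executions of one secondary command buffer would need to be tracked.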

benvanik commented 1 year ago

One thing that may need some fiddling is how to communicate the pipeline layout mode, or whether we want to make it per descriptor set (I think we want to make it per descriptor set). We can add a bit to iree_hal_descriptor_set_layout_flags_t for whether it's a native descriptor (default) or an indirect one, have the compiler emit the flag when creating such sets, and then have that be queried during dispatch. The flags are set based on the executable target, which we'd know needs the flag set. When binding we'd then ignore any indirect descriptor sets as those are covered by the parameter upload path.
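
As a sketch of the per-descriptor-set flag idea, assuming a single flag bit: the type name, enum names and bit values below are hypothetical stand-ins, not the actual iree_hal_descriptor_set_layout_flags_t definition, and `plan_bindings` is an invented helper that only shows how binding could split sets between the two paths.

```c
#include <assert.h>
#include <stdint.h>

// Hypothetical stand-in for the layout flags; bit values are assumptions.
typedef uint32_t descriptor_set_layout_flags_t;
enum {
  DESCRIPTOR_SET_LAYOUT_FLAG_NONE = 0u,           // native descriptors (default)
  DESCRIPTOR_SET_LAYOUT_FLAG_INDIRECT = 1u << 0,  // bindings come from a parameter buffer
};

// Bit i of each mask corresponds to descriptor set i of the pipeline layout.
typedef struct {
  uint32_t native_set_mask;    // sets bound as regular descriptor sets
  uint32_t indirect_set_mask;  // sets covered by the parameter upload path
} binding_plan_t;

// Splits the sets of a pipeline layout by their flags so the binding code
// can skip the indirect ones.
static binding_plan_t plan_bindings(const descriptor_set_layout_flags_t* set_flags,
                                    uint32_t set_count) {
  binding_plan_t plan = {0u, 0u};
  assert(set_count <= 32);
  for (uint32_t i = 0; i < set_count; ++i) {
    if (set_flags[i] & DESCRIPTOR_SET_LAYOUT_FLAG_INDIRECT) {
      plan.indirect_set_mask |= 1u << i;
    } else {
      plan.native_set_mask |= 1u << i;
    }
  }
  return plan;
}
```

Keeping the decision per set lets a single pipeline layout mix native and indirect sets.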

benvanik commented 1 year ago

#14777 has disabled VMA by default and hopefully it sticks (there may be some issues). We'll let that soak a bit and get into shark before fully removing VMA.

After that #14778 makes the runtime Vulkan HAL detect support for buffer device addresses and enables the feature on the allocations we make.

The next step is to implement the indirect parameter buffer in the Vulkan HAL in preparation for the compiler using it. I've got some sketches that dovetail with indirect command buffers and will see if I can piece them apart for some incremental work.

benvanik commented 1 year ago

Following up from the discussions yesterday, here's the spec I'm going to be shooting for on the compiler/runtime side outside of codegen:

maybe - not sure I like the flag approach, but the below in-memory format is not likely to change

So what would have been:

#version 460
layout(set = 0, binding = 0, std430) buffer set_0_binding_0_t { float data[]; } set_0_binding_0;
layout(set = 0, binding = 1, std430) buffer set_0_binding_1_t { float data[]; } set_0_binding_1;
layout(set = 0, binding = 2, std430) buffer set_0_binding_2_t { float data[]; } set_0_binding_2;
// note no binding 0 used
layout(set = 1, binding = 1, std430) buffer set_1_binding_1_t { float data[]; } set_1_binding_1;
void main() {
  set_0_binding_0.data[0];
  set_0_binding_1.data[0];
  set_0_binding_2.data[0];
  set_1_binding_1.data[1] = 1.0f;
  // ...
}


becomes:

#version 460
#extension GL_EXT_buffer_reference : require
layout(buffer_reference, std430, buffer_reference_align = 16) buffer binding_f32_t {
  float data[];
};
layout(set = 3, binding = 0) buffer set_0_t {
  binding_f32_t binding_0;
  binding_f32_t binding_1;
  binding_f32_t binding_2;
} set_0;
layout(set = 3, binding = 1) buffer set_1_t {
  binding_f32_t unused_binding_0;  // note here for alignment
  binding_f32_t binding_1;
} set_1;
void main() {
  set_0.binding_0.data[0];  // access original set(0) binding(0)
  set_0.binding_1.data[0];  // access original set(0) binding(1)
  set_0.binding_2.data[0];  // access original set(0) binding(2)
  set_1.binding_1.data[1] = 1.0f;  // access original set(1) binding(1)
  // ...
}
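
On the host side, the in-memory format implied by this shader is a dense table of 64-bit pointers per original descriptor set, indexed by original binding number, with unused bindings left as holes. The C sketch below assumes exactly that (8 bytes per entry, matching the Offset 0/8/16 member decorations in the SPIR-V dump in this comment); the helper names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

// Each entry is one 64-bit PhysicalStorageBuffer pointer.
enum { BINDING_POINTER_BYTES = 8 };

// Byte offset of a binding inside its set's table.
static size_t binding_offset(uint32_t binding) {
  return (size_t)binding * BINDING_POINTER_BYTES;
}

// Size of a set's table: it spans bindings 0..max_binding, including unused
// ones, so that entry positions always equal the original binding numbers.
static size_t set_table_size(uint32_t max_binding) {
  return ((size_t)max_binding + 1) * BINDING_POINTER_BYTES;
}

// Writes one device address into the table; returns 0 on success or -1 if
// the binding does not fit.
static int write_binding(uint8_t* table, size_t table_size, uint32_t binding,
                         uint64_t address) {
  size_t offset = binding_offset(binding);
  if (offset + BINDING_POINTER_BYTES > table_size) return -1;
  memcpy(table + offset, &address, sizeof(address));
  return 0;
}
```

For the example above, set 0 needs a 24-byte table and set 1 needs 16 bytes, with the first 8 bytes of set 1 unused.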


and the SPIR-V that glslang generates for the buffer-reference version:

; Version: 1.6
; Generator: Khronos Glslang Reference Front End; 11
; Bound: 34
; Schema: 0
               OpCapability Shader
               OpCapability PhysicalStorageBufferAddresses
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel PhysicalStorageBuffer64 GLSL450
               OpEntryPoint GLCompute %main "main" %set_0 %set_1
               OpExecutionModeId %main LocalSizeId %uint_1 %uint_1 %uint_1
               OpSource GLSL 460
               OpSourceExtension "GL_EXT_buffer_reference"
               OpName %main "main"
               OpName %set_0_t "set_0_t"
               OpMemberName %set_0_t 0 "binding_0"
               OpMemberName %set_0_t 1 "binding_1"
               OpMemberName %set_0_t 2 "binding_2"
               OpName %binding_f32_t "binding_f32_t"
               OpMemberName %binding_f32_t 0 "data"
               OpName %set_0 "set_0"
               OpName %set_1_t "set_1_t"
               OpMemberName %set_1_t 0 "unused_binding_0"
               OpMemberName %set_1_t 1 "binding_1"
               OpName %set_1 "set_1"
               OpMemberDecorate %set_0_t 0 Offset 0
               OpMemberDecorate %set_0_t 1 Offset 8
               OpMemberDecorate %set_0_t 2 Offset 16
               OpDecorate %set_0_t Block
               OpDecorate %_runtimearr_float ArrayStride 4
               OpMemberDecorate %binding_f32_t 0 Offset 0
               OpDecorate %binding_f32_t Block
               OpDecorate %set_0 DescriptorSet 3
               OpDecorate %set_0 Binding 0
               OpMemberDecorate %set_1_t 0 Offset 0
               OpMemberDecorate %set_1_t 1 Offset 8
               OpDecorate %set_1_t Block
               OpDecorate %set_1 DescriptorSet 3
               OpDecorate %set_1 Binding 1
       %void = OpTypeVoid
          %3 = OpTypeFunction %void
       %uint = OpTypeInt 32 0
     %uint_1 = OpConstant %uint 1
               OpTypeForwardPointer %_ptr_PhysicalStorageBuffer_binding_f32_t PhysicalStorageBuffer
    %set_0_t = OpTypeStruct %_ptr_PhysicalStorageBuffer_binding_f32_t %_ptr_PhysicalStorageBuffer_binding_f32_t %_ptr_PhysicalStorageBuffer_binding_f32_t
      %float = OpTypeFloat 32
%_runtimearr_float = OpTypeRuntimeArray %float
%binding_f32_t = OpTypeStruct %_runtimearr_float
%_ptr_PhysicalStorageBuffer_binding_f32_t = OpTypePointer PhysicalStorageBuffer %binding_f32_t
%_ptr_StorageBuffer_set_0_t = OpTypePointer StorageBuffer %set_0_t
      %set_0 = OpVariable %_ptr_StorageBuffer_set_0_t StorageBuffer
        %int = OpTypeInt 32 1
      %int_0 = OpConstant %int 0
%_ptr_StorageBuffer__ptr_PhysicalStorageBuffer_binding_f32_t = OpTypePointer StorageBuffer %_ptr_PhysicalStorageBuffer_binding_f32_t
      %int_1 = OpConstant %int 1
      %int_2 = OpConstant %int 2
    %set_1_t = OpTypeStruct %_ptr_PhysicalStorageBuffer_binding_f32_t %_ptr_PhysicalStorageBuffer_binding_f32_t
%_ptr_StorageBuffer_set_1_t = OpTypePointer StorageBuffer %set_1_t
      %set_1 = OpVariable %_ptr_StorageBuffer_set_1_t StorageBuffer
    %float_1 = OpConstant %float 1
%_ptr_PhysicalStorageBuffer_float = OpTypePointer PhysicalStorageBuffer %float
       %main = OpFunction %void None %3
          %5 = OpLabel
         %18 = OpAccessChain %_ptr_StorageBuffer__ptr_PhysicalStorageBuffer_binding_f32_t %set_0 %int_0
         %19 = OpLoad %_ptr_PhysicalStorageBuffer_binding_f32_t %18
         %21 = OpAccessChain %_ptr_StorageBuffer__ptr_PhysicalStorageBuffer_binding_f32_t %set_0 %int_1
         %22 = OpLoad %_ptr_PhysicalStorageBuffer_binding_f32_t %21
         %24 = OpAccessChain %_ptr_StorageBuffer__ptr_PhysicalStorageBuffer_binding_f32_t %set_0 %int_2
         %25 = OpLoad %_ptr_PhysicalStorageBuffer_binding_f32_t %24
         %29 = OpAccessChain %_ptr_StorageBuffer__ptr_PhysicalStorageBuffer_binding_f32_t %set_1 %int_1
         %30 = OpLoad %_ptr_PhysicalStorageBuffer_binding_f32_t %29
         %33 = OpAccessChain %_ptr_PhysicalStorageBuffer_float %30 %int_0 %int_1
               OpStore %33 %float_1 Aligned 4
               OpReturn
               OpFunctionEnd

benvanik commented 1 year ago

Update: I decided flags are fine for now as this can be experimental. In #14977 I've added the --iree-vulkan-experimental-indirect-bindings=true compiler flag that changes the executable format to that required by the runtime (vulkan-spirv-fb-ptr) and sets the Indirect flag on the descriptor set layouts on the exported executable variant functions.

Next steps on the runtime side are to route iree_hal_command_buffer_push_descriptor_set calls down a special parameter buffer path when the IREE_HAL_DESCRIPTOR_SET_LAYOUT_FLAG_INDIRECT flag is set (along with some other goo), while on the compiler side the codegen lowerings will need to inspect the layout flags and, when IREE::HAL::DescriptorSetLayoutFlags::Indirect is set, lower to the above SPIR-V binding style.

benvanik commented 1 year ago

(there's a lot I don't like about this approach but it's not worth me stalling any longer - we've got enough other cleanup around SPIR-V executables and extensions pending and this is such a big switch that it may help ground out discussions on next steps by being so hideous :)

antiagainst commented 11 months ago

Thanks @benvanik for the details! @kuhar will help to flesh out the SPIR-V part:

kuhar commented 7 months ago

I opened a PR with the compiler support and landed another one with HAL device queries for the related Vulkan extension. There's also a landed MLIR PR for memref to SPIR-V conversion.

With these three PRs in the tree and 64-bit indexing enabled, the following e2e test compiles but fails at runtime:

$ ninja iree-compile && tools/iree-compile ~/iree/iree/tests/e2e/stablehlo_ops/add.mlir --iree-hal-target-backends=vulkan-spirv --iree-vulkan-target-triple=rdna3-7900-linux --iree-vulkan-experimental-indirect-bindings=true -o add.vmfb --mlir-disable-threading --mlir-print-ir-after-all 2>add_all.log
$ tools/iree-check-module --module=add.vmfb --device=vulkan://0