gfx-rs / wgpu

A cross-platform, safe, pure-Rust graphics API.
https://wgpu.rs
Apache License 2.0

Speed up zero initialization of workgroup memory #4592

Open raphlinus opened 8 months ago

raphlinus commented 8 months ago

This is related to #4591. When forcing spv::ZeroInitializeWorkgroupMemoryMode::Polyfill in device_from_raw(), we observe very slow (but correct!) behavior when zeroing the workgroup shared array: all the work is done on one thread. It would be better to distribute this. In this case the array size and the workgroup size match, so having each invocation zero one array element would be simple and efficient.
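
As a CPU-side analogy only, here is a sketch of the distributed scheme for the case where the array length equals the workgroup size. OS threads stand in for invocations, and `std::sync::Barrier` stands in for `workgroupBarrier()`. The helper name `distributed_zero_init` is hypothetical and does not come from wgpu or naga.

```rust
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Barrier;
use std::thread;

/// CPU analogy of one compute workgroup (hypothetical helper, not wgpu code):
/// `workgroup_size` threads stand in for invocations and `shared` stands in
/// for the `var<workgroup>` array. Assumes one array element per invocation.
fn distributed_zero_init(shared: &[AtomicU32], workgroup_size: usize) {
    assert_eq!(shared.len(), workgroup_size, "sketch assumes array len == workgroup size");
    let barrier = Barrier::new(workgroup_size);
    thread::scope(|s| {
        for local_index in 0..workgroup_size {
            let barrier = &barrier;
            s.spawn(move || {
                // Each invocation performs exactly one store, instead of
                // invocation 0 performing all `shared.len()` stores.
                shared[local_index].store(0, Ordering::Relaxed);
                // Stand-in for workgroupBarrier(): no thread continues until
                // every thread has issued its store.
                barrier.wait();
            });
        }
    });
    // `thread::scope` joins every thread before returning, so all stores are visible here.
}
```

With 256 threads this mirrors a `@workgroup_size(16, 16)` dispatch. Each invocation issues a single store, rather than one invocation issuing all of them.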

zerooooooo.zip

The repro case is the same as in the linked bug, but with line 1307 of vulkan/adapter.rs changed to Polyfill.

cwfitzgerald commented 8 months ago

For reference, the zero-init code:

SPIR-V: https://github.com/gfx-rs/wgpu/blob/trunk/naga/src/back/spv/writer.rs#L1327
MSL: https://github.com/gfx-rs/wgpu/blob/trunk/naga/src/back/msl/writer.rs#L4441-L4549
HLSL: https://github.com/gfx-rs/wgpu/blob/trunk/naga/src/back/hlsl/writer.rs#L1280-L1305
GLSL: https://github.com/gfx-rs/wgpu/blob/trunk/naga/src/back/glsl/mod.rs#L1688-L1718

I think the easiest thing to do is this: for top-level arrays, use the local invocation index to initialize the corresponding element of the array. Mask off invocations whose index is at or beyond the array length, and use a compile-time (unrolled) loop for arrays longer than the workgroup's invocation count.
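
The planning step in this comment can be sketched for a single top-level array as follows. The sketch assumes the invocation count is known at compile time. Full chunks of `invocations` elements get unconditional stores, and a partial final chunk gets one store masked to the first `len % invocations` invocations. `Store`, `plan_array_zero_init` and `covers_exactly_once` are hypothetical names, not naga APIs.

```rust
/// One strided store for a top-level workgroup array (hypothetical type):
/// `array[local_index + offset] = zero`, guarded by `local_index < limit` if `guard` is set.
#[derive(Debug, PartialEq, Eq)]
struct Store {
    offset: u32,
    guard: Option<u32>,
}

/// Plans the zero-init stores for an array of `len` elements in a workgroup of
/// `invocations` threads. This is the compile-time loop, unrolled into a list.
fn plan_array_zero_init(len: u32, invocations: u32) -> Vec<Store> {
    assert!(invocations > 0, "a workgroup has at least one invocation");
    let full_chunks = len / invocations;
    let remainder = len % invocations;
    // Every invocation participates in each full chunk.
    let mut stores: Vec<Store> = (0..full_chunks)
        .map(|chunk| Store { offset: chunk * invocations, guard: None })
        .collect();
    // Only the first `remainder` invocations touch the tail.
    if remainder != 0 {
        stores.push(Store { offset: full_chunks * invocations, guard: Some(remainder) });
    }
    stores
}

/// Returns true if the plan writes every element exactly once and never out of bounds.
fn covers_exactly_once(plan: &[Store], len: u32, invocations: u32) -> bool {
    let mut hits = vec![0u32; len as usize];
    for store in plan {
        let active = store.guard.unwrap_or(invocations);
        for local_index in 0..active {
            let i = (local_index + store.offset) as usize;
            if i >= hits.len() {
                return false;
            }
            hits[i] += 1;
        }
    }
    hits.iter().all(|&h| h == 1)
}
```

For a 652-element array and 256 invocations, this yields unconditional stores at offsets 0 and 256. It also yields a store at offset 512 guarded by `local_index < 140`.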

cwfitzgerald commented 8 months ago

To be clear, I think the init should look like this:

var<workgroup> array1: array<u32, 652>;
var<workgroup> array2: array<u32, 256>;
var<workgroup> array3: array<u32, 45>;
var<workgroup> non_array: u32;

@compute @workgroup_size(16, 16) // 256 invocations
fn main(@builtin(local_invocation_index) local_index: u32) {
    // All unconditional array init first.
    // The loop is done at compile time: long arrays just get multiple writes.
    // `0u` is the zero value for u32; other types would use their zero-value constructor, e.g. `T()`.
    array1[local_index] = 0u;
    array1[local_index + 256] = 0u;
    array2[local_index] = 0u;
    if local_index < 140 {
        // The tail of array1 (elements 512..651) only needs the first 140 invocations.
        array1[local_index + 512] = 0u;
        if local_index < 45 {
            array3[local_index] = 0u;
            if local_index < 1 {
                non_array = 0u;
            }
        }
    }
    workgroupBarrier();
    // ... the rest of the shader body follows here.
}

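
The nesting works because the guards shrink monotonically (140, 45, 1), so each inner condition implies the outer one. Below is a hypothetical generator (not naga's writer) that reproduces this layout. It assumes every variable is either `u32` or `array<u32, N>` and that the invocation count is known. It sorts guarded stores by descending limit and opens one nested `if` per distinct limit.

```rust
/// Formats one strided store; offset 0 is written without the `+ 0`.
fn store(name: &str, offset: u32) -> String {
    if offset == 0 {
        format!("{name}[local_index] = 0u;")
    } else {
        format!("{name}[local_index + {offset}] = 0u;")
    }
}

/// Hypothetical emitter for the zero-init prologue. Each entry is
/// (name, Some(len)) for `array<u32, len>` or (name, None) for a scalar `u32`.
fn emit_zero_init(vars: &[(&str, Option<u32>)], invocations: u32) -> String {
    let mut unconditional = Vec::new();
    // (limit, store): only invocations with local_index < limit execute the store.
    let mut guarded: Vec<(u32, String)> = Vec::new();
    for &(name, len) in vars {
        match len {
            Some(len) => {
                let full = len / invocations;
                for chunk in 0..full {
                    unconditional.push(store(name, chunk * invocations));
                }
                let rem = len % invocations;
                if rem != 0 {
                    guarded.push((rem, store(name, full * invocations)));
                }
            }
            // A scalar is written by invocation 0 only.
            None => guarded.push((1, format!("{name} = 0u;"))),
        }
    }
    // Largest limit outermost, so each nested `if` only narrows the active set.
    guarded.sort_by(|a, b| b.0.cmp(&a.0));

    let mut out = String::new();
    for line in &unconditional {
        out.push_str(&format!("    {line}\n"));
    }
    let mut depth = 1;
    let mut current: Option<u32> = None;
    for (limit, line) in &guarded {
        if current != Some(*limit) {
            // Open a new, narrower guard inside the previous one.
            out.push_str(&format!("{}if local_index < {limit} {{\n", "    ".repeat(depth)));
            depth += 1;
            current = Some(*limit);
        }
        out.push_str(&format!("{}{line}\n", "    ".repeat(depth)));
    }
    // Close every guard that was opened.
    while depth > 1 {
        depth -= 1;
        out.push_str(&format!("{}}}\n", "    ".repeat(depth)));
    }
    out.push_str("    workgroupBarrier();\n");
    out
}
```

Calling `emit_zero_init` with the four variables from the snippet and 256 invocations reproduces the body of `main` up to the barrier, using `0u` as the zero value.
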
FL33TW00D commented 5 days ago

This can be closed, right? @teoxoy @cwfitzgerald

teoxoy commented 2 days ago

I don't think we implemented a better approach.