gfx-rs / metal-rs

Rust bindings for Metal
Apache License 2.0
567 stars 112 forks source link

How does one launch a 2D grid of threads using dispatch_threads? #301

Closed nihalpasham closed 8 months ago

nihalpasham commented 8 months ago

As a discussions tab doesn't exist, I'm hoping this is the right place for this question.

Context: I have been trying to figure out how one would go about launching a 2D grid of threads. I put together a basic kernel that computes the dotprod of 2 arrays (each 1 million elements long). The results are a bit confusing.

When I launch a 1D grid with a width of 1 million, height of 1 and depth of 1, I get the expected result.

// Host-side dispatch for the 1D case: one thread per element of the 1_000_000-element arrays.
// Both values below are queried from the compute pipeline and printed for reference.
let thread_execution_width = compute_pipeline.thread_execution_width();
            let threads_per_threadgroup = compute_pipeline.max_total_threads_per_threadgroup();

            println!("  simd_width: {:?}", thread_execution_width);
            println!("  threads_per_threadgroup: {:?}", threads_per_threadgroup);

            // Grid is width x height x depth = 1_000_000 x 1 x 1.
            let grid_size = MTLSize::new(1_000_000 as u64, 1, 1);
            // Threadgroup is (simd width) x (max threads / simd width) x 1.
            let threadgroup_size = MTLSize::new(
                thread_execution_width, // 32 (per the author's note)
                threads_per_threadgroup / thread_execution_width,  // NOTE(review): presumably 1024 / 32 = 32; the original "1024" looks like threads_per_threadgroup itself - confirm
                1,
            );
            compute_encoder.dispatch_threads(grid_size, threadgroup_size);

But when I use a 2D grid with width 1000, height 1000 and depth 1, I get a partial result i.e. it computes the dotprod for the first 1000 elements only.

// Host-side dispatch for the 2D case: the same 1_000_000 elements laid out as a 1000 x 1000 grid.
// Only grid_size differs from the 1D snippet; the threadgroup shape is computed the same way.
let thread_execution_width = compute_pipeline.thread_execution_width();
            let threads_per_threadgroup = compute_pipeline.max_total_threads_per_threadgroup();

            println!("  simd_width: {:?}", thread_execution_width);
            println!("  threads_per_threadgroup: {:?}", threads_per_threadgroup);

            // Grid is width x height x depth = 1000 x 1000 x 1.
            let grid_size = MTLSize::new(1000 as u64, 1000 as u64, 1);
            // Threadgroup is (simd width) x (max threads / simd width) x 1.
            let threadgroup_size = MTLSize::new(
                thread_execution_width,
                threads_per_threadgroup / thread_execution_width, 
                1,
            );
            compute_encoder.dispatch_threads(grid_size, threadgroup_size);

compute shader is below

// Kernel as originally posted in the question. Each thread multiplies one pair of
// elements and stores the product at a row-major linear index built from its 2D
// grid position; no summation is performed.
// NOTE(review): the follow-up comment in this thread reports that [[grid_size]] does
// not deliver the dispatched grid dimensions here and that [[threads_per_grid]] is
// the attribute to use instead.
kernel void dotprod(constant uint *arrayA,
                    constant uint *arrayB,
                    device uint *result,
                    uint2 grid_size [[grid_size]],
                    uint2 thread_idx [[thread_position_in_grid]])
{
    // Row stride for the linear index, read from the x component of grid_size.
    uint n = grid_size.x; 
    // Linear index is y * n + x; if n is 0 this collapses to thread_idx.x alone.
    result[thread_idx.y * n + thread_idx.x] = 
        arrayA[thread_idx.y * n + thread_idx.x] * arrayB[thread_idx.y * n + thread_idx.x];

}

I believe the [[grid_size]] is (0,0) in both cases and actually has no effect i.e. the result simply evaluates to

result[thread_idx.x] =  arrayA[thread_idx.x] * arrayB[thread_idx.x];

What am I missing? Any help would be great.

nihalpasham commented 8 months ago

I figured this out. We are supposed to use the [[threads_per_grid]] attribute instead of [[grid_size]] to retrieve grid dimensions.

Leaving this here for completeness' sake, in case anyone finds it useful in the future.


#include <metal_stdlib>
using namespace metal;

// 1D thread grid: every thread handles the single element at its own grid position.
// Despite the name, the kernel stores the element-wise product arrayA[i] * arrayB[i];
// nothing is summed into a scalar here.
// The original author reports this as the fastest variant, assuming a 1D threadgroup
// of width 32.
//
kernel void dotprod(constant uint *arrayA [[buffer(0)]],
                    constant uint *arrayB [[buffer(1)]],
                    device uint *result [[buffer(2)]],
                    uint pos [[thread_position_in_grid]])
{
    // pos is used directly as the element index; no bounds check is performed.
    const uint product = arrayA[pos] * arrayB[pos];
    result[pos] = product;
}

// 2D thread grid variant: each thread derives a row-major linear index from its 2D
// grid position and stores the element-wise product at that index.
// The original author observed that, compared with a parallel CPU implementation,
// the GPU gets slower as the number of elements grows.
//
// Diagnostic: when the grid width is exactly 10000 (e.g. the author's run with
// 10_000_000 elements) every thread exits without writing, so `result` is left
// untouched (all zeroes in the author's run). Seeing that confirms that
// [[threads_per_grid]] really carries the dispatched grid dimensions.
//
kernel void dotprod(constant uint *arrayA [[buffer(0)]],
                    constant uint *arrayB [[buffer(1)]],
                    device uint *result [[buffer(2)]],
                    uint2 grid_size [[threads_per_grid]],
                    uint2 pos [[thread_position_in_grid]])
{
    // Grid width doubles as the row stride.
    const uint width = grid_size.x;

    // Diagnostic guard described above: do no work for a 10000-wide grid.
    if (width != 10000) {
        const uint idx = pos.y * width + pos.x;
        result[idx] = arrayA[idx] * arrayB[idx];
    }
}

// 2D thread grid variant that rebuilds the thread's grid position by hand from the
// threadgroup position, the threadgroup size and the position inside the threadgroup.
// The grid dimensions come from [[threads_per_grid]], not [[grid_size]].
// Per the original author, this variant only works with dispatch_thread_groups().
//
// Buffers are bound explicitly at indices 0, 1 and 2, in declaration order.
//
// NOTE(review): no bounds check is performed. With dispatch_thread_groups() the grid
// presumably spans a whole number of threadgroups, so confirm that the row stride
// (grid_size.x) and the buffer lengths match what the host actually launches.
//
kernel void dotprod(constant uint *arrayA [[buffer(0)]],
                    constant uint *arrayB [[buffer(1)]],
                    device uint *result [[buffer(2)]],
                    uint2 grid_size [[threads_per_grid]],
                    uint2 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
                    uint2 thread_position_in_threadgroup [[thread_position_in_threadgroup]],
                    uint2 threads_per_threadgroup [[threads_per_threadgroup]])
{
    // Thread position in the grid: origin of this threadgroup plus the offset inside it.
    const uint2 pos =
        (threadgroup_position_in_grid * threads_per_threadgroup) +
        thread_position_in_threadgroup;

    // Grid width doubles as the row stride of the row-major layout.
    const uint n = grid_size.x;

    // Compute the linear element index once and reuse it for all three buffers.
    const uint idx = pos.y * n + pos.x;
    result[idx] = arrayA[idx] * arrayB[idx];
}