gfx-rs / metal-rs

Rust bindings for Metal
Apache License 2.0
567 stars 112 forks source link

Discrepancy between GPU timing and CPU timing #288

Open jafioti opened 10 months ago

jafioti commented 10 months ago

Hi, I am testing the timing of different compute kernels. I use the same timing method as in the compute example to get the compute pass time. I also use a simple std::time::Instant timer from just before .commit() to just after .wait_until_completed(), and my CPU timer ends up being around 12x longer than the GPU timer. There really shouldn't be any copying between CPU and GPU here, so the only explanation I can think of is the wait to dispatch the kernel, but I can't imagine that takes 14ms!

Here is my entire reproducible example:

use metal::{
    objc::rc::autoreleasepool, Buffer, BufferRef, CommandBufferRef, CompileOptions,
    ComputeCommandEncoderRef, ComputePassDescriptor, ComputePassDescriptorRef,
    ComputePipelineDescriptor, ComputePipelineState, CounterSampleBuffer, CounterSampleBufferRef,
    Device, MTLCommandBufferStatus, MTLResourceOptions, MTLSize, NSRange,
};
use rand::{rngs::StdRng, Rng, SeedableRng};

/// Capacity (in samples) of the timestamp counter sample buffer, and the number of `u64`
/// entries resolved into, and read back from, the destination buffer.
/// Only samples 0 and 1 (start and end of the compute pass) are actually written.
const NUM_SAMPLES: u64 = 10;

/// Metal source for the baseline `matmul` kernel.
///
/// Each thread computes one element of row-major `C = A·B`, at row `global_pos.y` and
/// column `global_pos.x`, as a dot product over `K`. It reads `A` and `B` directly from
/// device memory.
///
/// `M`, `N` and `K` are scalar references bound at buffer indices 3–5. The `threadgroup(0)`
/// parameter is declared (the other kernels share this signature) but is not used here.
const NAIEVE_SHADER: &str = "
#include <metal_stdlib>
using namespace metal;

kernel void matmul(
    device float *A [[buffer(0)]],
    device float *B [[buffer(1)]],
    device float *C [[buffer(2)]],
    device uint& M [[buffer(3)]],
    device uint& N [[buffer(4)]],
    device uint& K [[buffer(5)]],
    threadgroup float* shared_memory [[threadgroup(0)]],
    uint3 global_pos [[thread_position_in_grid]],
    uint3 local_pos [[thread_position_in_threadgroup]],
    uint3 block_pos [[threadgroup_position_in_grid]],
    uint3 block_size[[threads_per_threadgroup]]
) {
    if (global_pos.x < N && global_pos.y < M) {
        float value = 0.0f;
        for(int i = 0; i < K; ++i) {
            value = fast::fma(A[global_pos.y * K + i], B[i * N + global_pos.x], value);
        }
        C[global_pos.y * N + global_pos.x] = value;
    }
}";

/// Metal source for the `tiled_matmul` kernel.
///
/// On each iteration, every threadgroup stages one tile of `A` and one tile of `B` in
/// threadgroup memory, then accumulates the tile product. It needs
/// `2 * block_size.x * block_size.y` floats bound at `threadgroup(0)`: the A tile comes
/// first, then the B tile. Assumes square threadgroups (`block_size.x == block_size.y`).
///
/// Threads outside the `M`×`N` output must not `return` early. Every thread in a
/// threadgroup has to reach each `threadgroup_barrier`, so edge threads load zeros and
/// only skip the final store.
const TILED_SHADER: &str = "
#include <metal_stdlib>
using namespace metal;

kernel void tiled_matmul(
    device float *A [[buffer(0)]],
    device float *B [[buffer(1)]],
    device float *C [[buffer(2)]],
    device uint& M [[buffer(3)]],
    device uint& N [[buffer(4)]],
    device uint& K [[buffer(5)]],
    threadgroup float* shared_memory [[threadgroup(0)]],
    uint3 global_pos [[thread_position_in_grid]],
    uint3 local_pos [[thread_position_in_threadgroup]],
    uint3 block_pos [[threadgroup_position_in_grid]],
    uint3 block_size[[threads_per_threadgroup]]
) {
    float sum = 0.0f;
    bool row_in_range = global_pos.y < M;
    bool col_in_range = global_pos.x < N;

    uint num_tiles = (K + block_size.x - 1) / block_size.x;
    for (uint m = 0; m < num_tiles; ++m) {
        uint a_col = m * block_size.x + local_pos.x;
        shared_memory[local_pos.y * block_size.x + local_pos.x] =
            (row_in_range && a_col < K) ? A[global_pos.y * K + a_col] : 0.0f;

        uint b_row = m * block_size.y + local_pos.y;
        shared_memory[(block_size.y + local_pos.y) * block_size.x + local_pos.x] =
            (col_in_range && b_row < K) ? B[b_row * N + global_pos.x] : 0.0f;

        threadgroup_barrier(mem_flags::mem_threadgroup);

        for (uint e = 0; e < block_size.x; ++e) {
            sum = fast::fma(shared_memory[local_pos.y * block_size.x + e], shared_memory[(block_size.y + e) * block_size.x + local_pos.x], sum);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    if (row_in_range && col_in_range) {
        C[global_pos.y * N + global_pos.x] = sum;
    }
}";

/// Metal source for the `prefetch_matmul` kernel.
///
/// Like the naive kernel, it reads `B` straight from device memory. It stages `A` through
/// two threadgroup tiles: while accumulating over A tile `m - 1` (in `tile0`), it loads
/// A tile `m` into `tile1`, then swaps the two tiles.
///
/// `tile1` starts `2 * block_size.x * block_size.y` floats into threadgroup memory, so
/// the kernel needs `3 * block_size.x * block_size.y` floats bound at `threadgroup(0)`.
/// Assumes square threadgroups (`block_size.x == block_size.y`).
///
/// The loop runs `num_tiles` times so that the last A tile is also accumulated. No thread
/// returns early, because every thread must reach each `threadgroup_barrier`.
const PREFETCH_SHADER: &str = "
#include <metal_stdlib>
using namespace metal;

kernel void prefetch_matmul(
    device float *A [[buffer(0)]],
    device float *B [[buffer(1)]],
    device float *C [[buffer(2)]],
    device uint& M [[buffer(3)]],
    device uint& N [[buffer(4)]],
    device uint& K [[buffer(5)]],
    threadgroup float* shared_memory [[threadgroup(0)]],
    uint3 global_pos [[thread_position_in_grid]],
    uint3 local_pos [[thread_position_in_threadgroup]],
    uint3 block_pos [[threadgroup_position_in_grid]],
    uint3 block_size[[threads_per_threadgroup]]
) {
    bool row_in_range = global_pos.y < M;
    bool col_in_range = global_pos.x < N;
    uint num_tiles = (K + block_size.x - 1) / block_size.x;

    float sum = 0.0f;

    threadgroup float* tile0 = shared_memory;
    threadgroup float* tile1 = shared_memory + block_size.x * block_size.y * 2;

    tile0[local_pos.y * block_size.x + local_pos.x] =
        (row_in_range && local_pos.x < K) ? A[global_pos.y * K + local_pos.x] : 0.0f;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint m = 1; m <= num_tiles; ++m) {
        if (m < num_tiles) {
            uint a_col = m * block_size.x + local_pos.x;
            tile1[local_pos.y * block_size.x + local_pos.x] =
                (row_in_range && a_col < K) ? A[global_pos.y * K + a_col] : 0.0f;
        }

        if (col_in_range) {
            for (uint e = 0; e < block_size.x; ++e) {
                uint b_row = (m - 1) * block_size.y + e;
                if (b_row >= K) break;
                sum = fast::fma(tile0[local_pos.y * block_size.x + e], B[b_row * N + global_pos.x], sum);
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        threadgroup float* temp = tile0;
        tile0 = tile1;
        tile1 = temp;
    }

    if (row_in_range && col_in_range) {
        C[global_pos.y * N + global_pos.x] = sum;
    }
}";

/// Encodes and executes one `mat_size`×`mat_size` matrix multiply `C = A·B` with `shader`.
///
/// Returns `Some((C, gpu_ms, cpu_ms))` if the command buffer completed, otherwise `None`:
/// * `gpu_ms` is the duration of the compute pass, taken from the timestamp counter
///   samples and scaled into CPU time using `sample_timestamps` readings taken before
///   and after the run.
/// * `cpu_ms` is the host wall time from just before `commit()` to just after
///   `wait_until_completed()`. It therefore also covers submission, scheduling and
///   completion latency, not only kernel execution.
///
/// # Panics
/// Panics if `mat_size` does not fit in a `u32` (the kernels take `uint` dimensions).
fn run(
    a_buffer: &Buffer,
    b_buffer: &Buffer,
    shader: &ComputePipelineState,
    dev: &Device,
    mat_size: usize,
) -> Option<(Vec<f32>, f32, f32)> {
    autoreleasepool(|| {
        let mut cpu_start = 0;
        let mut gpu_start = 0;
        dev.sample_timestamps(&mut cpu_start, &mut gpu_start);

        let counter_sample_buffer = create_counter_sample_buffer(dev);
        // Both buffers below are written by the GPU and then read on the CPU. Shared storage
        // makes those writes CPU-visible without the explicit blit `synchronize_resource`
        // that managed storage requires on Macs with discrete GPUs.
        let destination_buffer = dev.new_buffer(
            (std::mem::size_of::<u64>() * NUM_SAMPLES as usize) as u64,
            MTLResourceOptions::StorageModeShared,
        );
        let c_buffer = dev.new_buffer(
            (mat_size * mat_size * std::mem::size_of::<f32>()) as u64,
            MTLResourceOptions::StorageModeShared,
        );
        let command_queue = dev.new_command_queue();
        let command_buffer = command_queue.new_command_buffer();

        let compute_pass_descriptor = ComputePassDescriptor::new();
        handle_compute_pass_sample_buffer_attachment(
            compute_pass_descriptor,
            &counter_sample_buffer,
        );

        let encoder =
            command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor);

        encoder.set_compute_pipeline_state(shader);
        encoder.set_buffer(0, Some(a_buffer), 0);
        encoder.set_buffer(1, Some(b_buffer), 0);
        encoder.set_buffer(2, Some(&c_buffer), 0);

        // `set_input_u32` takes (value, buffer index). M, N and K go to buffers 3, 4 and 5.
        let dim = u32::try_from(mat_size).expect("mat_size must fit in a u32 kernel argument");
        set_input_u32(encoder, dim, 3);
        set_input_u32(encoder, dim, 4);
        set_input_u32(encoder, dim, 5);

        let thread_block_size: u64 = 32;
        // Threadgroup scratch space, measured in tiles of `thread_block_size`² floats.
        // tiled_matmul stages an A tile and a B tile (2 tiles). prefetch_matmul's second
        // tile starts 2 tiles in, so it touches 3. Size for the largest so that every
        // kernel stays in bounds.
        let tiles_per_threadgroup: u64 = 3;
        encoder.set_threadgroup_memory_length(
            0,
            tiles_per_threadgroup
                * thread_block_size
                * thread_block_size
                * std::mem::size_of::<f32>() as u64,
        );
        encoder.dispatch_thread_groups(
            MTLSize {
                width: (mat_size as u64 + thread_block_size - 1) / thread_block_size,
                height: (mat_size as u64 + thread_block_size - 1) / thread_block_size,
                depth: 1,
            },
            MTLSize {
                width: thread_block_size,
                height: thread_block_size,
                depth: 1,
            },
        );
        encoder.end_encoding();
        resolve_samples_into_buffer(command_buffer, &counter_sample_buffer, &destination_buffer);

        let now = std::time::Instant::now();
        command_buffer.commit();
        command_buffer.wait_until_completed();
        // Keep fractional milliseconds: `as_millis()` would truncate e.g. 1.9 ms to 1 ms.
        let cpu_millis = now.elapsed().as_secs_f32() * 1000.0;

        let mut cpu_end = 0;
        let mut gpu_end = 0;
        dev.sample_timestamps(&mut cpu_end, &mut gpu_end);
        match command_buffer.status() {
            MTLCommandBufferStatus::Completed => Some((
                copy_from_buffer(&c_buffer),
                handle_timestamps(&destination_buffer, cpu_start, cpu_end, gpu_start, gpu_end),
                cpu_millis,
            )),
            _ => None,
        }
    })
}

fn main() {
    autoreleasepool(|| {
        let mat_size = 8192;
        let iters = 100;
        let mut rng = StdRng::seed_from_u64(0);
        let a_data: Vec<f32> = (0..(mat_size * mat_size))
            .map(|_| rng.gen_range(-0.5..0.5))
            .collect();
        let b_data: Vec<f32> = (0..(mat_size * mat_size))
            .map(|_| rng.gen_range(-0.5..0.5))
            .collect();

        let dev = Device::system_default().unwrap();
        let a_buffer = copy_to_buffer(&a_data, &dev);
        let b_buffer = copy_to_buffer(&b_data, &dev);

        let shader = compile_function("matmul", NAIEVE_SHADER, &dev);
        let mut data: Option<Vec<f32>> = None;
        let mut successes = 0;
        let mut total_gpu_time = 0.0;
        let mut total_cpu_time = 0.0;
        for _ in 0..iters {
            let curr_data = run(&a_buffer, &b_buffer, &shader, &dev, mat_size);
            if let Some((curr_data, gpu_time, cpu_time)) = curr_data {
                successes += 1;
                total_gpu_time += gpu_time;
                total_cpu_time += cpu_time;
                match &mut data {
                    Some(d) => {
                        for (i, (a, b)) in d.iter().zip(curr_data.iter()).enumerate() {
                            if (*a - *b).abs() > 1e-5 {
                                println!("Index {i} A: {a} B: {b}");
                            }
                        }
                    }
                    None => {
                        data = Some(curr_data);
                    }
                }
            }
        }
        println!(
            "Naive    CPU: {}ms GPU: {}ms",
            total_cpu_time / successes as f32,
            total_gpu_time / successes as f32
        );

        let shader = compile_function("tiled_matmul", TILED_SHADER, &dev);
        let mut successes = 0;
        let mut total_gpu_time = 0.0;
        let mut total_cpu_time = 0.0;
        for _ in 0..iters {
            let curr_data = run(&a_buffer, &b_buffer, &shader, &dev, mat_size);
            if let Some((curr_data, gpu_time, cpu_time)) = curr_data {
                successes += 1;
                total_gpu_time += gpu_time;
                total_cpu_time += cpu_time;
                match &mut data {
                    Some(d) => {
                        for (i, (a, b)) in d.iter().zip(curr_data.iter()).enumerate() {
                            if (*a - *b).abs() > 1e-5 {
                                println!("Index {i} A: {a} B: {b}");
                            }
                        }
                    }
                    None => {
                        data = Some(curr_data);
                    }
                }
            }
        }

        println!(
            "Tiled    CPU: {}ms GPU: {}ms",
            total_cpu_time / successes as f32,
            total_gpu_time / successes as f32
        );

        let shader = compile_function("prefetch_matmul", PREFETCH_SHADER, &dev);
        let mut successes = 0;
        let mut total_gpu_time = 0.0;
        let mut total_cpu_time = 0.0;
        for _ in 0..iters {
            let curr_data = run(&a_buffer, &b_buffer, &shader, &dev, mat_size);
            if let Some((curr_data, gpu_time, cpu_time)) = curr_data {
                successes += 1;
                total_gpu_time += gpu_time;
                total_cpu_time += cpu_time;
                match &mut data {
                    Some(d) => {
                        for (i, (a, b)) in d.iter().zip(curr_data.iter()).enumerate() {
                            if (*a - *b).abs() > 1e-5 {
                                println!("Index {i} A: {a} B: {b}");
                            }
                        }
                    }
                    None => {
                        data = Some(curr_data);
                    }
                }
            }
        }

        println!(
            "Prefetch CPU: {}ms GPU: {}ms",
            total_cpu_time / successes as f32,
            total_gpu_time / successes as f32
        );
    })
}

/// Copies the single `u32` value `num` into the encoder's argument table at buffer slot
/// `index`, using `set_bytes`.
///
/// Note the argument order: the value comes first, then the buffer index.
fn set_input_u32(encoder: &ComputeCommandEncoderRef, num: u32, index: u64) {
    let value_ptr: *const u32 = &num;
    let byte_len = std::mem::size_of::<u32>() as u64;
    encoder.set_bytes(index, byte_len, value_ptr.cast());
}

/// Creates a new managed-storage Metal buffer of `size_of_val(v)` bytes, initialised
/// from the contents of `v`.
fn copy_to_buffer(v: &[f32], dev: &Device) -> Buffer {
    dev.new_buffer_with_data(
        // A plain pointer cast; the original `unsafe { transmute }` was unnecessary.
        v.as_ptr().cast(),
        std::mem::size_of_val(v) as u64,
        MTLResourceOptions::StorageModeManaged,
    )
}

/// Copies the CPU-visible contents of `buffer` into a new `Vec<f32>`.
///
/// The buffer length is interpreted as a whole number of `f32`s; any trailing partial
/// element is ignored. The caller must ensure the GPU has finished writing the buffer,
/// and that the writes are visible to the CPU.
fn copy_from_buffer(buffer: &Buffer) -> Vec<f32> {
    let len = buffer.length() as usize / std::mem::size_of::<f32>();
    if len == 0 {
        // Never touch `contents()` for an empty buffer.
        return Vec::new();
    }
    // SAFETY: `contents()` points at `buffer.length()` bytes of CPU-accessible storage,
    // which covers `len` f32 values. Like the original per-element reads, this assumes
    // the storage is f32-aligned; Metal buffers are page-aligned in practice.
    let values = unsafe { std::slice::from_raw_parts(buffer.contents().cast::<f32>(), len) };
    values.to_vec()
}

/// Compiles `code` as Metal Shading Language and builds a compute pipeline for the
/// kernel function `name`.
///
/// # Panics
/// Panics if the source fails to compile, if `name` is not found in the compiled
/// library, or if pipeline creation fails. The panic message names the kernel and
/// includes Metal's error.
fn compile_function(name: &str, code: &str, device: &Device) -> ComputePipelineState {
    let library = device
        .new_library_with_source(code, &CompileOptions::new())
        .unwrap_or_else(|e| panic!("failed to compile Metal library for `{name}`: {e:?}"));
    let function = library
        .get_function(name, None)
        .unwrap_or_else(|e| panic!("kernel `{name}` not found in compiled library: {e:?}"));
    // The function can be used directly; wrapping it in a ComputePipelineDescriptor only
    // to read it back out again added nothing.
    device
        .new_compute_pipeline_state_with_function(&function)
        .unwrap_or_else(|e| panic!("failed to create compute pipeline for `{name}`: {e:?}"))
}

/// Attaches `counter_sample_buffer` to sample-buffer attachment 0 of the compute pass.
///
/// The pass records its start timestamp into sample 0 and its end timestamp into sample 1.
fn handle_compute_pass_sample_buffer_attachment(
    compute_pass_descriptor: &ComputePassDescriptorRef,
    counter_sample_buffer: &CounterSampleBufferRef,
) {
    let attachments = compute_pass_descriptor.sample_buffer_attachments();
    let first_attachment = attachments.object_at(0).unwrap();

    first_attachment.set_sample_buffer(counter_sample_buffer);
    first_attachment.set_start_of_encoder_sample_index(0);
    first_attachment.set_end_of_encoder_sample_index(1);
}

/// Encodes a blit pass on `command_buffer` that resolves samples `0..NUM_SAMPLES` of
/// `counter_sample_buffer` into `destination_buffer`, starting at byte offset 0.
fn resolve_samples_into_buffer(
    command_buffer: &CommandBufferRef,
    counter_sample_buffer: &CounterSampleBufferRef,
    destination_buffer: &BufferRef,
) {
    let sample_range = NSRange::new(0_u64, NUM_SAMPLES);
    let blit = command_buffer.new_blit_command_encoder();
    blit.resolve_counters(counter_sample_buffer, sample_range, destination_buffer, 0_u64);
    blit.end_encoding();
}

/// Returns the duration, in milliseconds, of the compute pass whose start and end GPU
/// timestamps were resolved into slots 0 and 1 of `resolved_sample_buffer`.
///
/// The CPU/GPU timestamp pairs sampled before and after the run provide the ratio used to
/// scale the GPU interval into CPU time.
fn handle_timestamps(
    resolved_sample_buffer: &BufferRef,
    cpu_start: u64,
    cpu_end: u64,
    gpu_start: u64,
    gpu_end: u64,
) -> f32 {
    // SAFETY: relies on the caller passing a buffer sized for NUM_SAMPLES u64 values whose
    // resolving GPU work has already completed, as `run` does.
    let samples: &[u64] = unsafe {
        std::slice::from_raw_parts(
            resolved_sample_buffer.contents().cast::<u64>(),
            NUM_SAMPLES as usize,
        )
    };
    let (pass_start, pass_end) = (samples[0], samples[1]);

    let cpu_time_span = cpu_end - cpu_start;
    let gpu_time_span = gpu_end - gpu_start;
    milliseconds_between_begin(pass_start, pass_end, gpu_time_span, cpu_time_span) as f32
}

/// Converts the GPU timestamp interval `begin..end` into milliseconds.
///
/// The interval is scaled by `cpu_time_span / gpu_time_span`, the ratio of a CPU interval
/// to a GPU interval sampled over the same period. The scaled value is treated as
/// nanoseconds and divided by 1e6. A negative result means `end` precedes `begin`.
fn milliseconds_between_begin(begin: u64, end: u64, gpu_time_span: u64, cpu_time_span: u64) -> f64 {
    // Take the difference exactly in integer space first. Converting each large u64
    // timestamp to f64 before subtracting discards low-order bits above 2^53.
    let time_span = (i128::from(end) - i128::from(begin)) as f64;
    let nanoseconds = time_span / (gpu_time_span as f64) * (cpu_time_span as f64);
    nanoseconds / 1_000_000.0
}

/// Creates a shared-storage counter sample buffer with `NUM_SAMPLES` slots, backed by
/// the device's `"timestamp"` counter set.
///
/// # Panics
/// Panics if the device exposes no `"timestamp"` counter set, or if buffer creation fails.
fn create_counter_sample_buffer(device: &Device) -> CounterSampleBuffer {
    let descriptor = metal::CounterSampleBufferDescriptor::new();
    descriptor.set_storage_mode(metal::MTLStorageMode::Shared);
    descriptor.set_sample_count(NUM_SAMPLES);

    let available_sets = device.counter_sets();
    let timestamp_set = available_sets
        .iter()
        .find(|set| set.name() == "timestamp")
        .expect("No timestamp counter found");
    descriptor.set_counter_set(timestamp_set);

    device
        .new_counter_sample_buffer_with_descriptor(&descriptor)
        .unwrap()
}

And my output:

Naive    CPU: 17.67ms GPU: 1.3157867ms
Tiled    CPU: 14.56ms GPU: 1.3267895ms
Prefetch CPU: 14.65ms GPU: 1.2971923ms
grovesNL commented 10 months ago

GPU timestamps with sampleTimestamps are complicated in Metal. They're not actually nanoseconds, so it's up to the application to correlate GPU timing with CPU timing somehow and convert back to a unit of time that makes sense.

More background here: https://feresignum.com/resolving-metal-gpu-timers/ https://github.com/gpuweb/gpuweb/issues/1325

Note that this crate just provides Rust bindings to the Metal API, so this is how it works in Metal in general.

jafioti commented 10 months ago

@grovesNL Thanks for the links. Is it normal to take 6ms to complete a kernel that does nothing? I think I might not be setting something up correctly. When I comment out all the code in my kernel, it still takes 6ms to run.

grovesNL commented 10 months ago

It's hard to say what's going on here for sure without profiling, I'd look at what's happening in a profiler like Metal System Trace or the Xcode profiler. You could try moving all of the buffer creation/copies/other setup outside of run to see if you can tell where the overhead is coming from.

There is definitely non-zero overhead to a GPU dispatch followed by a readback in general (e.g., a lot of programs try to avoid waiting on a GPU read, recognizing all GPU work as completing asynchronously instead), but hard for me to guess what that overhead might be on your system without eliminating everything else here.