gfx-rs / metal-rs

Rust bindings for Metal
Apache License 2.0
567 stars 112 forks source link

Inconsistent Performance across runs #287

Open jafioti opened 10 months ago

jafioti commented 10 months ago

Hi, I'm trying to write a matrix multiplication kernel, and I'm seeing very inconsistent performance across runs of the kernel with the same input. Here is a reproduction where you can see some runs going much faster than others:

Be sure to run in release mode!

Cargo.toml:

[dependencies]
metal = "0.26.0"
rand = "0.8.5"

Code:

use metal::{
    objc::rc::autoreleasepool, Buffer, BufferRef, CommandBuffer, CommandBufferRef, CommandQueue,
    CompileOptions, ComputePassDescriptor, ComputePassDescriptorRef, ComputePipelineDescriptor,
    ComputePipelineState, CounterSampleBuffer, CounterSampleBufferRef, Device,
    MTLCounterSamplingPoint, MTLResourceOptions, MTLSize, NSRange,
};
use rand::{rngs::StdRng, Rng, SeedableRng};

// Number of counter samples recorded per compute pass: index 0 is written at the
// start of the compute encoder and index 1 at its end.
const NUM_SAMPLES: u64 = 2;

/// Encodes, submits and times a single dispatch of `shader` over a square
/// `mat_size` x `mat_size` problem.
///
/// `a_buffer`, `b_buffer` and `c_buffer` are bound to buffer slots 0, 1 and 2;
/// `mat_size` (narrowed to `u32`) is bound to slots 3, 4 and 5. Start/end-of-pass
/// samples are written to `counter_sample_buffer`, resolved into a freshly
/// allocated shared buffer and reported through `handle_timestamps`.
///
/// # Panics
///
/// Panics if the device does not support counter sampling at stage boundaries.
#[inline]
#[allow(clippy::too_many_arguments)]
fn run_naive(
    a_buffer: &Buffer,
    b_buffer: &Buffer,
    c_buffer: &Buffer,
    shader: &ComputePipelineState,
    dev: &Device,
    command_queue: &CommandQueue,
    counter_sample_buffer: &CounterSampleBufferRef,
    mat_size: usize,
) {
    // Edge length of one (square) threadgroup.
    const TILE: u64 = 32;

    // CPU/GPU timestamp pair taken before any work is encoded; paired with the
    // one taken after completion to relate GPU ticks to CPU time.
    let (mut cpu_start, mut gpu_start) = (0, 0);
    dev.sample_timestamps(&mut cpu_start, &mut gpu_start);

    // CPU-visible buffer that receives the resolved samples.
    let resolved_len = (std::mem::size_of::<u64>() * NUM_SAMPLES as usize) as u64;
    let destination_buffer = dev.new_buffer(resolved_len, MTLResourceOptions::StorageModeShared);

    let sampling_point = MTLCounterSamplingPoint::AtStageBoundary;
    assert!(dev.supports_counter_sampling(sampling_point));

    let command_buffer = command_queue.new_command_buffer();

    let pass_descriptor = ComputePassDescriptor::new();
    handle_compute_pass_sample_buffer_attachment(pass_descriptor, counter_sample_buffer);

    let encoder = command_buffer.compute_command_encoder_with_descriptor(pass_descriptor);
    encoder.set_compute_pipeline_state(shader);

    // Matrix operands and result.
    let matrices = [a_buffer, b_buffer, c_buffer];
    for (slot, buffer) in (0u64..).zip(matrices) {
        encoder.set_buffer(slot, Some(buffer), 0);
    }

    // The same dimension is uploaded for each of the three size arguments.
    let dim = mat_size as u32;
    let dim_ptr: *const std::ffi::c_void = (&dim as *const u32).cast();
    let dim_len = std::mem::size_of::<u32>() as u64;
    for slot in 3..=5 {
        encoder.set_bytes(slot, dim_len, dim_ptr);
    }

    encoder.set_threadgroup_memory_length(
        0,
        TILE * TILE * 2 * std::mem::size_of::<f32>() as u64,
    );

    // Round up so the grid covers sizes that are not a multiple of TILE.
    let groups_per_axis = (mat_size as u64 + TILE - 1) / TILE;
    let threadgroup_count = MTLSize {
        width: groups_per_axis,
        height: groups_per_axis,
        depth: 1,
    };
    let threads_per_group = MTLSize {
        width: TILE,
        height: TILE,
        depth: 1,
    };
    encoder.dispatch_thread_groups(threadgroup_count, threads_per_group);
    encoder.end_encoding();

    resolve_samples_into_buffer(command_buffer, counter_sample_buffer, &destination_buffer);
    command_buffer.commit();
    command_buffer.wait_until_completed();

    let (mut cpu_end, mut gpu_end) = (0, 0);
    dev.sample_timestamps(&mut cpu_end, &mut gpu_end);
    handle_timestamps(&destination_buffer, cpu_start, cpu_end, gpu_start, gpu_end);
}

/// Builds two random `mat_size` x `mat_size` matrices, uploads them to the GPU,
/// compiles the `matmul` kernel and runs it `iters` times, printing the measured
/// compute-pass duration of every run.
///
/// All three buffers are sized for the full `mat_size * mat_size` element count.
/// The input lengths are taken from the slices rather than from the `Vec`
/// values themselves: `size_of_val(&vec)` is only the size of the `Vec` header,
/// which would leave the kernel reading far past the end of the buffers.
fn main() {
    autoreleasepool(|| {
        let mat_size = 8192;
        let iters = 100;
        let elem_count = mat_size * mat_size;
        let mut rng = StdRng::seed_from_u64(0);
        let a_data: Vec<f32> = (0..elem_count)
            .map(|_| rng.gen_range(-0.5..0.5))
            .collect();
        let b_data: Vec<f32> = (0..elem_count)
            .map(|_| rng.gen_range(-0.5..0.5))
            .collect();

        let shader = "#include <metal_stdlib>
            using namespace metal;

            kernel void matmul(
                device float *A [[buffer(0)]],
                device float *B [[buffer(1)]],
                device float *C [[buffer(2)]],
                device uint& M [[buffer(3)]],
                device uint& N [[buffer(4)]],
                device uint& K [[buffer(5)]],
                uint3 tid [[thread_position_in_grid]]
            ) {
                uint row = tid.y;
                uint column = tid.x;

                if(row < M && column < N) {
                    float value = 0.0f;
                    for(int i = 0; i < K; ++i) {
                        value = fma(A[row * K + i], B[i * N + column], value);
                    }
                    C[row * N + column] = value;
                }
            }
            ";
        let dev = Device::system_default().unwrap();

        // The kernel writes one float per (row, column) pair, so the output
        // needs mat_size * mat_size elements, not mat_size.
        let c_buffer = dev.new_buffer(
            (elem_count * std::mem::size_of::<f32>()) as u64,
            MTLResourceOptions::StorageModeManaged,
        );

        // `size_of_val` on the slice yields len * size_of::<f32>() bytes.
        let a_buffer = dev.new_buffer_with_data(
            a_data.as_ptr().cast(),
            std::mem::size_of_val(a_data.as_slice()) as u64,
            MTLResourceOptions::StorageModeManaged,
        );
        let b_buffer = dev.new_buffer_with_data(
            b_data.as_ptr().cast(),
            std::mem::size_of_val(b_data.as_slice()) as u64,
            MTLResourceOptions::StorageModeManaged,
        );
        let counter_sample_buffer = create_counter_sample_buffer(&dev);
        let shader = compile_function("matmul", shader, &dev);
        let command_queue = dev.new_command_queue();
        for _ in 0..iters {
            println!("Naieve");
            run_naive(
                &a_buffer,
                &b_buffer,
                &c_buffer,
                &shader,
                &dev,
                &command_queue,
                &counter_sample_buffer,
                mat_size,
            );
        }
    })
}

/// Compiles the Metal source `code` on `device` and returns a compute pipeline
/// state for the kernel function called `name`.
///
/// # Panics
///
/// Panics if the source fails to compile, if the library has no function called
/// `name`, or if pipeline creation fails.
fn compile_function(name: &str, code: &str, device: &Device) -> ComputePipelineState {
    let options = CompileOptions::new();
    let library = device.new_library_with_source(code, &options).unwrap();
    let kernel = library.get_function(name, None).unwrap();

    // The function is routed through a pipeline descriptor and read back from it.
    let descriptor = ComputePipelineDescriptor::new();
    descriptor.set_compute_function(Some(&kernel));
    let bound_kernel = descriptor.compute_function().unwrap();

    device
        .new_compute_pipeline_state_with_function(bound_kernel)
        .unwrap()
}

/// Attaches `counter_sample_buffer` to sample-buffer attachment 0 of
/// `compute_pass_descriptor`, with the start-of-encoder sample stored at index 0
/// and the end-of-encoder sample at index 1.
///
/// # Panics
///
/// Panics if the descriptor has no attachment at index 0.
fn handle_compute_pass_sample_buffer_attachment(
    compute_pass_descriptor: &ComputePassDescriptorRef,
    counter_sample_buffer: &CounterSampleBufferRef,
) {
    let attachments = compute_pass_descriptor.sample_buffer_attachments();
    let attachment = attachments.object_at(0).unwrap();

    attachment.set_sample_buffer(counter_sample_buffer);
    attachment.set_start_of_encoder_sample_index(0);
    attachment.set_end_of_encoder_sample_index(1);
}

/// Encodes a blit pass on `command_buffer` that resolves the first
/// `NUM_SAMPLES` entries of `counter_sample_buffer` into `destination_buffer`,
/// starting at byte offset 0.
fn resolve_samples_into_buffer(
    command_buffer: &CommandBufferRef,
    counter_sample_buffer: &CounterSampleBufferRef,
    destination_buffer: &BufferRef,
) {
    let blit = command_buffer.new_blit_command_encoder();
    let sample_range = NSRange::new(0_u64, NUM_SAMPLES);
    let destination_offset = 0_u64;

    blit.resolve_counters(
        counter_sample_buffer,
        sample_range,
        destination_buffer,
        destination_offset,
    );
    blit.end_encoding();
}

/// Reads the two resolved samples from `resolved_sample_buffer` and prints the
/// duration of the compute pass in milliseconds.
///
/// The CPU and GPU timestamp pairs are reduced to spans and handed to
/// `milliseconds_between_begin`, which uses their ratio to scale the sample
/// difference.
fn handle_timestamps(
    resolved_sample_buffer: &BufferRef,
    cpu_start: u64,
    cpu_end: u64,
    gpu_start: u64,
    gpu_end: u64,
) {
    // SAFETY: relies on the caller passing a buffer that holds at least
    // NUM_SAMPLES u64 values; `run_naive` allocates it with exactly that size.
    let resolved: &[u64] = unsafe {
        let base = resolved_sample_buffer.contents() as *const u64;
        std::slice::from_raw_parts(base, NUM_SAMPLES as usize)
    };
    let (pass_start, pass_end) = (resolved[0], resolved[1]);

    let cpu_time_span = cpu_end - cpu_start;
    let gpu_time_span = gpu_end - gpu_start;

    println!(
        "Compute pass duration: {} ms",
        milliseconds_between_begin(pass_start, pass_end, gpu_time_span, cpu_time_span)
    );
}

/// Converts the difference between two samples into milliseconds.
///
/// The difference `end - begin` is divided by `gpu_time_span`, multiplied by
/// `cpu_time_span`, and the product is divided by 1,000,000. All arithmetic is
/// done in `f64`, so `end < begin` yields a negative result and a zero
/// `gpu_time_span` yields an infinite or NaN result rather than a panic.
fn milliseconds_between_begin(begin: u64, end: u64, gpu_time_span: u64, cpu_time_span: u64) -> f64 {
    const NANOS_PER_MILLI: f64 = 1_000_000.0;

    let elapsed = end as f64 - begin as f64;
    let gpu_span = gpu_time_span as f64;
    let cpu_span = cpu_time_span as f64;

    // Keep the divide-then-multiply order so rounding matches exactly.
    (elapsed / gpu_span * cpu_span) / NANOS_PER_MILLI
}

/// Creates a shared-storage counter sample buffer with room for `NUM_SAMPLES`
/// samples, using the device counter set named `"timestamp"`.
///
/// # Panics
///
/// Panics if the device exposes no counter set named `"timestamp"` or if the
/// sample buffer cannot be created.
fn create_counter_sample_buffer(device: &Device) -> CounterSampleBuffer {
    let descriptor = metal::CounterSampleBufferDescriptor::new();
    descriptor.set_storage_mode(metal::MTLStorageMode::Shared);
    descriptor.set_sample_count(NUM_SAMPLES);

    let available_sets = device.counter_sets();
    let timestamp_set = available_sets
        .iter()
        .find(|set| set.name() == "timestamp")
        .expect("No timestamp counter found");
    descriptor.set_counter_set(timestamp_set);

    device
        .new_counter_sample_buffer_with_descriptor(&descriptor)
        .unwrap()
}
cwfitzgerald commented 10 months ago

You're almost certainly hitting funkiness with regard to the GPU clock speed and power saving. Generally you need to do a large-ish number of runs, then take the last one. You can use the system profiler to check the power state of your chip at any point, making sure it's on "maximum".

jafioti commented 10 months ago

@cwfitzgerald Is there any way to force the program to run on the maximum power mode? I'm hoping to get some sort of reliable performance out of my program

jafioti commented 10 months ago

These performance swings are quite drastic:

Naieve
Compute pass duration: 5455.59 ms
Naieve
Compute pass duration: 6147103160647.369 ms
Naieve
Compute pass duration: 556.000333 ms
Naieve
Compute pass duration: 9777.649291 ms
Naieve
Compute pass duration: 5117.36875 ms
Naieve
Compute pass duration: 3503.177167 ms
Naieve
Compute pass duration: 6147103130642.286 ms
Naieve
Compute pass duration: 10418.711417 ms
Naieve
Compute pass duration: 555.34225 ms
Naieve
Compute pass duration: 6833.369083 ms

I can assume the super high numbers are due to some wrapping issue, but there are still the swings between ~500 ms and ~5000 ms. I feel like it's unlikely that this is caused by the low power mode

cwfitzgerald commented 10 months ago

Run Instruments in Metal System Trace mode, with the target set to all processes - you should see information about the clocks and exactly when things are running.