tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0

Batch matrix multiply leads to vulkan error on WGPU #1865

Open · jungerm2 opened this issue 4 months ago

jungerm2 commented 4 months ago

I wasn't sure whether batch matmul was supported, since this behavior doesn't seem to be documented anywhere except in PyTorch's documentation. It seems to work fine with small tensors but breaks down past a certain size:

// Reproduction on burn 0.13 with the wgpu backend.
use burn::backend::wgpu::{AutoGraphicsApi, JitBackend, WgpuRuntime};
use burn::tensor::{Distribution, Tensor};

type B = JitBackend<WgpuRuntime<AutoGraphicsApi, f32, i32>>;

// 500 * 500 = 250,000 batched [4, 5] x [5, 6] products.
let a: Tensor<B, 4> = Tensor::random([500, 500, 4, 5], Distribution::Normal(-1.0, 1.0), &Default::default());
let b: Tensor<B, 4> = Tensor::random([500, 500, 5, 6], Distribution::Normal(-1.0, 1.0), &Default::default());
let out = a.matmul(b);
println!("{:?}", out);

I'd expect an output tensor of shape [500, 500, 4, 6], but instead I get the following error:

wgpu error: Validation Error

Caused by:
    In a ComputePass
      note: encoder = `<CommandBuffer-(1, 1, Vulkan)>`
    In a dispatch command, indirect:false
      note: compute pipeline = `<ComputePipeline-(4, 1, Vulkan)>`
    Each current dispatch group size dimension ([1, 1, 250000]) must be less or equal to 65535

So it seems there's a maximum dispatch dimension of 65535 for bmm: the 500 × 500 leading dimensions get flattened into a single batch of 250,000, which exceeds the per-dimension limit. I would expect this backend-specific limitation to be abstracted away, i.e. the backend should split the bmm into chunks and recombine them automatically. Is there a current workaround for this?

I'm using burn 0.13.2 with Vulkan 1.3.283 on Fedora 40.
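
For reference, the cap reported in the error corresponds to wgpu's max_compute_workgroups_per_dimension limit. A minimal sketch to confirm it on a given adapter, independent of burn and assuming the wgpu and pollster crates:

// Sketch (not part of burn): query the adapter's compute dispatch limits directly.
fn print_dispatch_limit() {
    let instance = wgpu::Instance::default();
    let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions::default()))
        .expect("no suitable GPU adapter found");
    // The validation error above corresponds to this limit (65535 by default).
    println!(
        "max compute workgroups per dimension: {}",
        adapter.limits().max_compute_workgroups_per_dimension
    );
}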

jungerm2 commented 4 months ago

For the time being, this seems to work, but it's much slower than I'd like:

use burn::tensor::{backend::Backend, Tensor};

/// Perform batch matrix multiplication, splitting the batch (first) dimension
/// into chunks of at most 65535 and recombining the results.
pub fn bmm<B: Backend>(a: Tensor<B, 3>, b: Tensor<B, 3>) -> Tensor<B, 3> {
    let batch_size = 65535;
    let [n1, i, j1] = a.dims();
    let [n2, j2, k] = b.dims();
    assert_eq!(n1, n2, "batch dimensions must match");
    assert_eq!(j1, j2, "inner dimensions must match");

    // Small enough to dispatch in one go.
    if n1 <= batch_size {
        return a.matmul(b);
    }

    // Chunk the batch dimension into ranges of at most `batch_size`.
    let ranges: Vec<_> = (0..n1.div_ceil(batch_size))
        .map(|chunk| (batch_size * chunk)..(batch_size * (chunk + 1)).min(n1))
        .collect();

    // Multiply each chunk separately, then concatenate along the batch dimension.
    let result_parts: Vec<_> = ranges
        .into_iter()
        .map(|r| {
            let a_part = a.clone().slice([r.clone(), 0..i, 0..j1]);
            let b_part = b.clone().slice([r, 0..j2, 0..k]);
            a_part.matmul(b_part)
        })
        .collect();
    Tensor::cat(result_parts, 0)
}
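
The 4D case from the original snippet can go through this helper by flattening the two leading dimensions into one batch and reshaping back afterwards (a sketch, reusing the bmm function above):

// Flatten [500, 500, 4, 5] -> [250000, 4, 5], multiply in chunks, then restore the shape.
let a3: Tensor<B, 3> = a.reshape([250_000, 4, 5]);
let b3: Tensor<B, 3> = b.reshape([250_000, 5, 6]);
let out: Tensor<B, 4> = bmm(a3, b3).reshape([500, 500, 4, 6]);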

The bmm above could be sped up with into_par_iter from rayon, but to my great surprise, FloatTensorPrimitive is Send but not Sync...

EDIT: Fixed some indexing in the above code; now it actually works.

nathanielsimard commented 3 months ago

Well, it's a limitation of the launch configuration used to compute the matrix multiplication: the whole batch currently maps onto a single dispatch dimension. Maybe we could use another dispatch dimension to handle the batch part.
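
A minimal sketch of that idea (hypothetical, not burn's actual launch code): spread the batch count over two dispatch dimensions so each stays under the 65535 cap, and recover the linear batch index inside the kernel.

// Hypothetical sketch of splitting a batch count across two dispatch dimensions.
const MAX_DISPATCH_DIM: u32 = 65_535;

/// Choose (y, z) workgroup counts such that y * z >= batch and both respect the cap.
fn split_batch(batch: u32) -> (u32, u32) {
    if batch <= MAX_DISPATCH_DIM {
        (1, batch)
    } else {
        (batch.div_ceil(MAX_DISPATCH_DIM), MAX_DISPATCH_DIM)
    }
}

fn main() {
    // 500 * 500 = 250,000 batches from the report -> y = 4, z = 65,535.
    let (y, z) = split_batch(250_000);
    assert!(y <= MAX_DISPATCH_DIM && z <= MAX_DISPATCH_DIM && y * z >= 250_000);
    println!("dispatch counts: x = 1, y = {y}, z = {z}");
    // Inside the kernel, the linear batch index would be recovered as
    // workgroup_id.y * MAX_DISPATCH_DIM + workgroup_id.z, with a bounds check
    // since y * z can overshoot the real batch count.
}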