tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
Apache License 2.0

Batch matrix multiply leads to vulkan error on WGPU #1865

Open jungerm2 opened 4 months ago

jungerm2 commented 4 months ago

I wasn't sure whether batch matmul is supported, since it doesn't seem to be documented anywhere except in PyTorch's documentation. It seems to work fine with small tensors but breaks down past a certain size:

type B = burn::backend::wgpu::JitBackend<WgpuRuntime<AutoGraphicsApi, f32, i32>>;
let a: Tensor<B, 4> = Tensor::random([500, 500, 4, 5], Distribution::Normal(-1.0, 1.0), &Default::default());
let b: Tensor<B, 4> = Tensor::random([500, 500, 5, 6], Distribution::Normal(-1.0, 1.0), &Default::default());
let out = a.matmul(b);
println!("{:?}", out);

I'd expect an output tensor of shape [500, 500, 4, 6], but instead I get the following error:

wgpu error: Validation Error

Caused by:
    In a ComputePass
      note: encoder = `<CommandBuffer-(1, 1, Vulkan)>`
    In a dispatch command, indirect:false
      note: compute pipeline = `<ComputePipeline-(4, 1, Vulkan)>`
    Each current dispatch group size dimension ([1, 1, 250000]) must be less or equal to 65535

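So it seems the batch dimension of a bmm is capped at 65535 on this backend (here the flattened batch is 500 × 500 = 250000, which ends up in a single dispatch dimension). I would expect this backend-specific limitation to be abstracted away, i.e., the backend should split the bmm into chunks and recombine the results automatically. Is there a workaround for this at the moment?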
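For reference, the arithmetic behind that limit can be sketched in plain Rust. This is only an illustration: `chunk_ranges` and `MAX_DISPATCH_DIM` are hypothetical names, not part of burn or wgpu, and the 65535 value is simply taken from the error message above.

```rust
use std::ops::Range;

/// Per-dimension dispatch limit quoted in the validation error (assumed here).
const MAX_DISPATCH_DIM: usize = 65_535;

/// Split `0..total` into consecutive ranges of at most `limit` elements.
/// Hypothetical helper for illustration; not a burn or wgpu function.
fn chunk_ranges(total: usize, limit: usize) -> Vec<Range<usize>> {
    assert!(limit > 0, "limit must be positive");
    (0..total)
        .step_by(limit)
        // Clamp the end of the last range to `total`.
        .map(|start| start..(start + limit).min(total))
        .collect()
}
```

With the shapes from this issue, `chunk_ranges(500 * 500, MAX_DISPATCH_DIM)` yields four ranges, the last one covering the remaining 53395 batches, so four separate dispatches would be enough.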

I'm using burn 0.13.2 with Vulkan 1.3.283 on Fedora 40.

jungerm2 commented 4 months ago

For the time being, this seems to work, but it's much slower than I'd like:

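use burn::tensor::{backend::Backend, Tensor};

/// Perform batch matrix multiplication by splitting first dimension
/// if larger than 65535 and recombining results
pub fn bmm<B: Backend>(a: Tensor<B, 3>, b: Tensor<B, 3>) -> Tensor<B, 3> {
    let batch_size = 65535;
    let [n1, i, j1] = a.dims();
    let [n2, j2, k] = b.dims();
    assert_eq!(n1, n2); // batch sizes must match
    assert_eq!(j1, j2); // inner dimensions must match

    // Small enough for a single dispatch: no splitting needed.
    if n1 <= batch_size {
        return a.matmul(b);
    }

    // Index ranges along the batch dimension, each at most `batch_size` long.
    let ranges: Vec<_> = (0..(n1 as u32).div_ceil(batch_size as u32))
        .map(|c| (batch_size * c as usize)..(batch_size * (c + 1) as usize).min(n1))
        .collect();
    // Multiply each chunk separately, then concatenate along the batch dimension.
    let result_parts: Vec<_> = ranges.into_iter().map(|r| {
        let a_part = a.clone().slice([r.clone(), 0..i, 0..j1]);
        let b_part = b.clone().slice([r.clone(), 0..j2, 0..k]);
        a_part.matmul(b_part)
    }).collect();
    Tensor::cat(result_parts, 0)
}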

It could be sped up with rayon's into_par_iter, but to my great surprise, FloatTensorPrimitive is Send but not Sync...

EDIT: Fixed some indexing in the above code, now it actually works.

nathanielsimard commented 3 months ago

Well, it's a limitation in the launching paradigm used to compute matrix multiplication. Maybe we could use another dimension to handle the batch part.
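One way to read "use another dimension" is to spread the batch count over two dispatch dimensions instead of one. The sketch below shows only the index arithmetic in plain Rust; it is not how burn launches its matmul kernels. It assumes a second dispatch dimension is free to use, and `split_batch_dispatch` and `batch_index` are hypothetical names.

```rust
/// Per-dimension dispatch limit quoted in the validation error (assumed here).
const MAX_DISPATCH_DIM: u32 = 65_535;

/// Pick workgroup counts `(y, z)` such that `y * z >= batches` and both
/// stay within the limit. Returns `None` if the batch count is zero or
/// cannot fit in two dimensions. Hypothetical helper, not burn's API.
fn split_batch_dispatch(batches: u32) -> Option<(u32, u32)> {
    if batches == 0 {
        return None;
    }
    // Fewest "rows" needed so that each row holds at most the limit.
    let z = batches.div_ceil(MAX_DISPATCH_DIM);
    if z > MAX_DISPATCH_DIM {
        return None;
    }
    // Spread the batches evenly over those rows.
    let y = batches.div_ceil(z);
    Some((y, z))
}

/// Recover the flat batch index from the two workgroup ids, as a kernel
/// would. Indices >= the real batch count are padding and must be skipped.
fn batch_index(y_id: u32, z_id: u32, y_count: u32) -> u32 {
    z_id * y_count + y_id
}
```

For the 250000 batches in this issue, `split_batch_dispatch` gives `(62500, 4)`, so both counts stay below 65535 and no padding is needed. For batch counts that do not divide evenly, the kernel would need a bounds check on the recovered index.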