jungerm2 opened 4 months ago
For the time being, this seems to work, but it's much slower than I'd like:
/// Perform batch matrix multiplication by splitting the first dimension
/// if larger than 65535 and recombining the results.
pub fn bmm<B: Backend>(a: Tensor<B, 3>, b: Tensor<B, 3>) -> Tensor<B, 3> {
    let batch_size = 65535;
    let [n1, i, j1] = a.dims();
    let [n2, j2, k] = b.dims();
    assert_eq!(n1, n2);
    assert_eq!(j1, j2);
    if n1 <= batch_size {
        return a.matmul(b);
    }
    // Split 0..n1 into contiguous chunks of at most `batch_size`.
    let ranges: Vec<_> = (0..n1.div_ceil(batch_size))
        .map(|idx| (batch_size * idx)..(batch_size * (idx + 1)).min(n1))
        .collect();
    // Multiply each chunk separately, then concatenate along the batch dim.
    let result_parts: Vec<_> = ranges
        .into_iter()
        .map(|r| {
            let a_part = a.clone().slice([r.clone(), 0..i, 0..j1]);
            let b_part = b.clone().slice([r, 0..j2, 0..k]);
            a_part.matmul(b_part)
        })
        .collect();
    Tensor::cat(result_parts, 0)
}
It could be sped up with into_par_iter from rayon, but to my great surprise, FloatTensorPrimitive is Send but not Sync...
EDIT: Fixed some indexing in the above code, now it actually works.
Well, it's a limitation in the launching paradigm used to compute matrix multiplication. Maybe we could use another dimension to handle the batch part.
I wasn't sure if batch matmul was supported, as it seems to be documented nowhere except in PyTorch's documentation. It works fine with small tensors but breaks down past a certain size:
I'd expect an output tensor of shape [500, 500, 4, 6], but instead I get an error. So it seems there's a maximum first-dimension size of 65535 for bmm. I would expect this backend-specific limitation to be abstracted away, i.e. the backend should split the bmm into batches and recombine the results automatically. Is there a current workaround for this?
I'm using burn 0.13.2 with Vulkan version 1.3.283 on Fedora 40.