Open rybakov opened 4 months ago
One should be comparing torch.mm and torch._scaled_mm: torch.mm is also 2D-only and fails with a similar error on 3D inputs, whereas torch.bmm accepts them:
>>> x = torch.rand(16, 32, 128)
>>> w = torch.rand(16, 128, 64)
>>> z = torch.mm(x, w)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: self must be a matrix
>>> z = torch.bmm(x, w)
That said, asking for a bmm version of _scaled_mm stands to reason.
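Until something like that exists, one possible workaround is to apply a 2D-only matmul to each batch slice and stack the results. The sketch below is only an illustration: `bmm_via_2d` and its `mm_fn` parameter are hypothetical names, not part of PyTorch, and it uses torch.mm as the stand-in so it runs on CPU. A wrapper around torch._scaled_mm could presumably be passed as `mm_fn` instead, but that is untested here; _scaled_mm is a private API, so its signature and fp8 hardware requirements should be checked against the installed release.

```python
import torch


def bmm_via_2d(x, w, mm_fn=torch.mm):
    """Emulate a batched matmul with a 2D-only kernel (hypothetical helper).

    x: (B, M, K), w: (B, K, N) -> result: (B, M, N)
    mm_fn is any callable taking two 2D tensors; torch.mm is the stand-in.
    """
    if x.dim() != 3 or w.dim() != 3 or x.shape[0] != w.shape[0]:
        raise ValueError("expected x: (B, M, K) and w: (B, K, N)")
    # One 2D matmul per batch element, stacked back along dim 0.
    return torch.stack([mm_fn(x[i], w[i]) for i in range(x.shape[0])])


torch.manual_seed(0)
x = torch.rand(16, 32, 128)
w = torch.rand(16, 128, 64)
z = bmm_via_2d(x, w)  # same shape as torch.bmm(x, w)
```

The Python loop launches one kernel per batch element, so this is a correctness reference rather than a fast path, which is part of why a native batched version would be worth having.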
I'd second this; I was actually looking for a bmm version of _scaled_mm.
🚀 The feature, motivation and pitch
3D fp8 matrix multiplication would be useful for fp8 models that contain 3D matmuls (it can also be used to improve the accuracy of models with 2D fp8-quantized matrix multiplications). _scaled_mm is designed for fp8 matrix multiplication, but it currently supports only 2D x 2D matrix multiplication. It would be great to extend _scaled_mm to support 3D input matrices.
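For the narrower case where only the activation is 3D and the weight is a single 2D matrix shared across the batch, a 2D-only kernel can already be used by flattening the leading dimensions. The following is a minimal sketch under that assumption; `matmul_3d_by_2d` and `mm_fn` are hypothetical names, and torch.mm stands in for the 2D kernel so the example runs without fp8 hardware. Whether the same reshape is acceptable with _scaled_mm depends on how the scales are defined (for example, one scale per tensor), which is not covered here.

```python
import torch


def matmul_3d_by_2d(x, w, mm_fn=torch.mm):
    """Multiply x: (B, M, K) by a shared w: (K, N) using a 2D-only kernel.

    Hypothetical helper: folds the batch dimension into the rows,
    runs a single 2D matmul, then restores the batch dimension.
    """
    b, m, k = x.shape
    out = mm_fn(x.reshape(b * m, k), w)  # (B*M, K) @ (K, N) -> (B*M, N)
    return out.reshape(b, m, -1)         # back to (B, M, N)


torch.manual_seed(0)
x = torch.rand(16, 32, 128)
w = torch.rand(128, 64)
z = matmul_3d_by_2d(x, w)  # shape (16, 32, 64)
```

This does not help when each batch element has its own weight matrix, which is the true 3D x 3D case this request is about.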
Alternatives
No response
Additional context
No response
cc @yanbing-j @vkuzo @albanD @kadeng @penguinwu