pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

3D matrix support in _scaled_mm #129803

Open rybakov opened 4 months ago

rybakov commented 4 months ago

🚀 The feature, motivation and pitch

3D fp8 matrix multiplication is useful for fp8 models that contain 3D matmuls (it can also be used to improve the accuracy of models with 2D fp8-quantized matrix multiplications). _scaled_mm is designed for fp8 matrix multiplication, but it currently supports only 2D x 2D inputs, as shown below. It would be great to extend _scaled_mm to accept 3D (batched) input matrices.

import torch

device = 'cuda'

# 3D inputs: does not work with _scaled_mm
x_size = (16, 32, 128)
w_size = (16, 64, 128)

# 2D inputs: works ok
# x_size = (32, 128)
# w_size = (64, 128)

dtype = torch.float32
x = torch.rand(x_size, dtype=torch.float32, device=device).to(torch.float8_e4m3fn)
w = torch.rand(w_size, dtype=torch.float32, device=device).to(torch.float8_e4m3fn)
one = torch.tensor([1.0], dtype=torch.float32, device=x.device)

# float32 reference result
y_f32 = torch.matmul(x.to(torch.float32), w.transpose(-1, -2).to(torch.float32))

# fails for 3D inputs: _scaled_mm only accepts 2D matrices
y, _ = torch._scaled_mm(
    x,
    w.transpose(-1, -2),
    scale_a=one,
    scale_b=one,
    out_dtype=dtype,
    use_fast_accum=False,
)
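
In the meantime, the 3D case can be emulated by looping over the batch dimension and calling _scaled_mm on each 2D slice. A minimal sketch under the same setup as above (and assuming the tuple-returning _scaled_mm signature used above; newer builds return only the output tensor):

# workaround sketch: emulate a batched scaled matmul with per-slice _scaled_mm calls
y_slices = []
for x_b, w_b in zip(x.unbind(0), w.unbind(0)):
    y_b, _ = torch._scaled_mm(
        x_b,                      # (32, 128), row-major
        w_b.transpose(-1, -2),    # (128, 64), column-major as _scaled_mm expects
        scale_a=one,
        scale_b=one,
        out_dtype=dtype,
        use_fast_accum=False,
    )
    y_slices.append(y_b)
y_loop = torch.stack(y_slices)    # (16, 32, 64), matches y_f32 up to fp8 quantization error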


cc @yanbing-j @vkuzo @albanD @kadeng @penguinwu

malfet commented 4 months ago

One should be comparing torch.mm and torch._scaled_mm; torch.mm fails with a similar error for 3D inputs:

>>> x = torch.rand(16, 32, 128)
>>> w = torch.rand(16, 128, 64)
>>> z = torch.mm(x, w)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: self must be a matrix
>>> z = torch.bmm(x, w)

But asking for a bmm version of _scaled_mm stands to reason.
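
For clarity, bmm computes an independent mm over each slice of the leading batch dimension, which is presumably the semantics a 3D-aware _scaled_mm (or a hypothetical scaled bmm variant) would follow; a quick sketch:

import torch

x = torch.rand(16, 32, 128)
w = torch.rand(16, 128, 64)

# bmm multiplies matching 2D slices along the leading batch dimension
z = torch.bmm(x, w)  # shape (16, 32, 64)

# equivalent to a stack of per-slice mm calls
z_ref = torch.stack([torch.mm(x_b, w_b) for x_b, w_b in zip(x, w)])
assert torch.allclose(z, z_ref)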

Xynonners commented 3 months ago

I'd second this; I was actually looking for a bmm version of _scaled_mm.