I want to plan for how we are going to add scaling granularities in the Python layer of the float8 code. Today, we only have per-tensor scaling, which is transposable. For other types of scaling such as rowwise, the scaling is not transposable and the user needs to choose what to do between fwd and bwd (see the sketch after this list):
a. keep the bf16 copy to be able to rescale across dim0 and dim1
b. scale bf16 across dim0/dim1, keep that, then requantize along the other dim in the bwd (reduces memory usage, loses some precision)
c. keep some of the gemms in bf16 to avoid the need to scale twice
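A minimal sketch of the difference (illustration only, not the library API; the 16x32 example tensor and variable names are made up):

import torch

x = torch.randn(16, 32, dtype=torch.bfloat16)

# Tensorwise: a single scalar scale computed from the global amax.
# x.t() can reuse it unchanged, so fwd and bwd gemms share one quantization.
tensorwise_scale = x.abs().max().float()

# Rowwise: one scale per row of x, shape (16, 1).
rowwise_scale = x.abs().amax(dim=1, keepdim=True).float()

# x.t() has shape (32, 16), so its rows would need a (32, 1) scale. That scale
# can only come from (a) the saved bf16 copy, (b) requantizing the already
# quantized tensor along the other dim, or (c) not quantizing that gemm at all.
print(tensorwise_scale.shape, rowwise_scale.shape)  # torch.Size([]) torch.Size([16, 1])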
The modeling logic in Float8Linear for a/b would look like:
def forward(self, x):
    if scaling_type == TENSORWISE:
        x_maybe_fp8 = to_fp8_tensorwise(x, ...)
    elif scaling_type == ROWWISE:
        x_maybe_fp8 = to_fp8_rowwise(x, dim=0, ...)
    # repeat for w
    y = float8_mm_op(x_maybe_fp8, w_maybe_fp8, ...)
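For concreteness, a hypothetical fleshed-out version of that forward is below; the to_fp8_* helpers, the TENSORWISE/ROWWISE constants, self.scaling_type, and float8_mm_op are placeholders carried over from the pseudocode above, not a final API:

def forward(self, x):
    # hypothetical sketch; helper names are placeholders, not a final API
    if self.scaling_type == TENSORWISE:
        x_maybe_fp8 = to_fp8_tensorwise(x, ...)
        w_maybe_fp8 = to_fp8_tensorwise(self.weight, ...)
    elif self.scaling_type == ROWWISE:
        x_maybe_fp8 = to_fp8_rowwise(x, dim=0, ...)
        w_maybe_fp8 = to_fp8_rowwise(self.weight, dim=0, ...)
    else:
        # option (c): leave this gemm in bf16
        x_maybe_fp8, w_maybe_fp8 = x, self.weight
    return float8_mm_op(x_maybe_fp8, w_maybe_fp8, ...)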
And, there are at least two choices I see for float8_mm_op:
# Option 1 (current code without this PR): use the torch.mm override
@implements([aten.mm.default, aten.matmul.default])
def float8_mm(aten_op, args, kwargs=None):
    ...

# Option 2 (this PR): use torch.autograd.Function
class float8_mm(torch.autograd.Function):
    ...
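To make Option 2 concrete, here is a minimal, runnable toy autograd.Function that just does the matmuls in high precision (toy_float8_mm is a made-up name and no float8 conversion happens here); the point is that the fwd gemm and the two bwd gemms become explicit call sites where per-gemm scaling decisions and debug hooks could live:

import torch

class toy_float8_mm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        # this is where x / w could be converted to float8 with the chosen granularity
        ctx.save_for_backward(x, w)
        return torch.mm(x, w)

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        # the two bwd gemms; with rowwise scaling these need scales along a
        # different dim than the fwd gemm
        grad_x = torch.mm(grad_out, w.t())
        grad_w = torch.mm(x.t(), grad_out)
        return grad_x, grad_w

x = torch.randn(8, 16, requires_grad=True)
w = torch.randn(16, 4, requires_grad=True)
toy_float8_mm.apply(x, w).sum().backward()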
To support future scaling granularities, whichever choice we go with will have to do something like the following:
def float8_mm(x_maybe_fp8, w_maybe_fp8):
    if isinstance(x_maybe_fp8, Float8Tensor):
        x_fp8 = x_maybe_fp8
    else:
        x_fp8 = to_fp8(x_maybe_fp8, scaling_granularity, ...)
    # repeat for w
    # call torch._scaled_mm
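A hedged sketch of what that body could look like; Float8Tensor, to_fp8, scaling_granularity, and the _data/_scale attributes are placeholders, and the torch._scaled_mm call is illustrative only, since its exact signature has changed across PyTorch versions:

def float8_mm(x_maybe_fp8, w_maybe_fp8, scaling_granularity):
    # sketch only: placeholder helpers, illustrative _scaled_mm call
    if isinstance(x_maybe_fp8, Float8Tensor):
        x_fp8 = x_maybe_fp8
    else:
        x_fp8 = to_fp8(x_maybe_fp8, scaling_granularity)
    if isinstance(w_maybe_fp8, Float8Tensor):
        w_fp8 = w_maybe_fp8
    else:
        w_fp8 = to_fp8(w_maybe_fp8, scaling_granularity)
    # raw float8 data and scales go to the fused scaled gemm
    return torch._scaled_mm(
        x_fp8._data,
        w_fp8._data,
        scale_a=x_fp8._scale,
        scale_b=w_fp8._scale,
        out_dtype=torch.bfloat16,
    )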
Furthermore, to keep things readable / debuggable, it would be good to:
- be able to print tensors before/after quantization
- be able to associate tensors with their parent module, and with the specific gemm in fwd/bwd in that module
To do the above, we'll need to pass around metadata such as module FQNs.
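One possible way to get that association (a sketch, not what this PR implements): record each module's FQN from named_modules() and log from a forward pre-hook; the _debug_fqn attribute below is made up for the example:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(32, 64), nn.Linear(64, 16))

# stamp each linear with its FQN so later layers of the stack can reference it
for fqn, mod in model.named_modules():
    if isinstance(mod, nn.Linear):
        mod._debug_fqn = fqn  # hypothetical attribute, for logging only

def log_input_before_quant(mod, args):
    (x,) = args
    print(f"{mod._debug_fqn}: fwd gemm input, shape={tuple(x.shape)}, "
          f"absmax={x.abs().max().item():.4f}")

for mod in model.modules():
    if isinstance(mod, nn.Linear):
        mod.register_forward_pre_hook(log_input_before_quant)

model(torch.randn(8, 32))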
This PR implements Option 2, as IMO it is more readable/debuggable.
Test plan: