pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX

bring back torch.autograd.Function #316

Open vkuzo opened 4 months ago

vkuzo commented 4 months ago

Stack from ghstack (oldest at bottom):

Summary:

I want to plan for how we are going to add scaling granularities in the Python layer of the float8 code. Today we only have per-tensor scaling, which is transposable. For other types of scaling, such as rowwise, the scaling is not transposable, and the user needs to choose what to do between the fwd and the bwd (see the sketch after this list):

  a. keep the bf16 copy to be able to rescale across dim0 and dim1
  b. scale bf16 across dim0/dim1, keep that, then requantize along the other dim in the bwd (reduces memory usage, loses some precision)
  c. keep some of the gemms in bf16 to avoid the need to scale twice
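To make the transposability point concrete, here is a minimal standalone sketch (not library code; the constant and helper names are made up for illustration):

import torch

# E4M3 max representable value; used here only to size the scale
F8_E4M3_MAX = 448.0

def tensorwise_scale(t):
    # one scale for the whole tensor; t and t.t() share it, so a tensor
    # scaled this way can be transposed and reused as-is
    return F8_E4M3_MAX / t.abs().max().clamp(min=1e-12)

def rowwise_scale(t, dim):
    # one scale per slice along `dim`; transposing changes which logical
    # dim the scales were computed over, so the fwd cast cannot simply be
    # reused by a gemm that needs scaling along the other dim
    amax = t.abs().amax(dim=dim, keepdim=True)
    return F8_E4M3_MAX / amax.clamp(min=1e-12)

x = torch.randn(16, 32)
assert tensorwise_scale(x).numel() == 1              # shared across dim0 and dim1
assert rowwise_scale(x, dim=1).shape == (16, 1)      # per-row scales
assert rowwise_scale(x.t(), dim=1).shape == (32, 1)  # different scales after transpose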

The modeling logic in Float8Linear for a/b would look like:

def forward(self, x):
    if scaling_type == TENSORWISE:
        x_maybe_fp8 = to_fp8_tensorwise(x, ...)
    elif scaling_type == ROWWISE:
        x_maybe_fp8 = to_fp8_rowwise(x, dim=0, ...)
    # repeat for w

    y = float8_mm_op(x_maybe_fp8, w_maybe_fp8, ...)
    return y
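(Here scaling_type, TENSORWISE/ROWWISE and the to_fp8_* helpers are placeholders for whatever config enum and casting functions we land on, e.g. something along the lines of:)

import enum

# hypothetical config enum for the sketch above, not a committed API
class ScalingGranularity(enum.Enum):
    TENSORWISE = "tensorwise"
    ROWWISE = "rowwise"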

There are at least two choices I see for float8_mm_op:

# Option 1 (current code without this PR): use the torch.mm override
@implements([aten.mm.default, aten.matmul.default])
def float8_mm(aten_op, args, kwargs=None):
    ...

# Option 2 (this PR): use torch.autograd.Function
class float8_mm(torch.autograd.Function):
    ...

To support future scaling granularities, whichever choice we go with will have to do something like the following:

def float8_mm(x_maybe_fp8, w_maybe_fp8):
    if isinstance(x_maybe_fp8, Float8Tensor):
        x_fp8 = x_maybe_fp8
    else:
        x_fp8 = to_fp8(x_maybe_fp8, scaling_granularity, ...)
    # repeat for w
    # call torch._scaled_mm
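Folding that into Option 2, a minimal sketch of what the autograd.Function could look like (hypothetical: to_fp8 and scaled_mm stand in for the real casting helper and the torch._scaled_mm wrapper, and the bwd behavior is just one of the choices a/b/c above):

class float8_mm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_maybe_fp8, w_maybe_fp8, scaling_granularity):
        # cast any input that is not already a Float8Tensor
        if isinstance(x_maybe_fp8, Float8Tensor):
            x_fp8 = x_maybe_fp8
        else:
            x_fp8 = to_fp8(x_maybe_fp8, scaling_granularity)
        if isinstance(w_maybe_fp8, Float8Tensor):
            w_fp8 = w_maybe_fp8
        else:
            w_fp8 = to_fp8(w_maybe_fp8, scaling_granularity)
        ctx.save_for_backward(x_fp8, w_fp8)
        # keep the granularity around for the bwd rescale decision
        ctx.scaling_granularity = scaling_granularity
        return scaled_mm(x_fp8, w_fp8)

    @staticmethod
    def backward(ctx, grad_y):
        x_fp8, w_fp8 = ctx.saved_tensors
        # the bwd gemms can make their own scaling choice here (rescale
        # along the other dim, or stay in bf16), independently of the fwd
        grad_x = scaled_mm(grad_y, w_fp8.t())
        grad_w = scaled_mm(x_fp8.t(), grad_y)
        return grad_x, grad_w, None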

Furthermore, to keep things readable / debuggable, it would be good to:

  1. be able to print tensors before/after quantization
  2. be able to associate tensors to their parent module, and the specific gemm in fwd/bwd in that module

To do the above, we'll need to pass around metadata such as module FQNs.
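For example (again a hypothetical sketch, not an API in this PR), that metadata could be plumbed through to the cast and used for debug printing:

# hypothetical: attribute fp8 casts to a parent module + gemm for debugging
def to_fp8_with_debug(t, scaling_granularity, module_fqn, gemm_name, verbose=False):
    if verbose:
        print(f"{module_fqn} / {gemm_name}: pre-quant", t)
    t_fp8 = to_fp8(t, scaling_granularity)  # stand-in for the real cast
    if verbose:
        print(f"{module_fqn} / {gemm_name}: post-quant", t_fp8)
    return t_fp8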

This PR implements Option 2, as IMO it is more readable/debuggable.

Test plan:

// all green
./test/test_everything.sh