pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX
BSD 3-Clause "New" or "Revised" License

[RFC] Float8 Inference #314

Closed: drisspg closed this issue 3 months ago

drisspg commented 4 months ago

RFC: Float8 Inference

Objective

We want to provide an easy mechanism for using FP8 in inference that reduces memory usage and improves performance on hardware with native FP8 compute support. The API should require minimal model rewrites. It should also be configurable, offering multiple levels of scaling granularity, each with its own accuracy/performance trade-off. Finally, the solution should be composable with other inference components in the PyTorch ecosystem.

This solution is targeting server-side GPU inference. It is not currently focused on supporting edge or CPU inference.

Background

Float8 inference can reduce memory usage and improve computational efficiency. By using FP8 instead of higher-precision formats, we can achieve significant speedups and memory savings with minimal loss in accuracy. The memory saving is specific to float8 inference, as opposed to float8 training: during inference the weights are static, so there is no need to keep a higher-precision copy of them for weight updates, and they can be stored in FP8 directly.

Proposal

Float8InferenceLinear Module

We propose a new Float8InferenceLinear module that extends nn.Linear with Float8 quantization capabilities:

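from __future__ import annotations  # QuantConfig etc. are only referenced in annotations

from typing import Optional

import torch

# QuantConfig, ScaledMMConfig and ScalingGranularity are float8_experimental types
# (ScalingGranularity is part of the future proposal below).


class Float8InferenceLinear(torch.nn.Linear):
    def __init__(
        self,
        quant_config: QuantConfig,
        forward_config: ScaledMMConfig,
        scaling_granularity: Optional[ScalingGranularity],
        in_features: int,
        out_features: int,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__(in_features, out_features, bias, device, dtype)
        # ... rest of the implementation: cast self.weight to FP8 (E4M3 by default) ...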

This module handles the quantization of weights and activations based on the provided configuration. It landed in PR #287. It is designed to replace a pre-trained nn.Linear module in an existing model and statically convert its weight to FP8, using the E4M3 format by default.

It provides configuration options via the QuantConfig class to encapsulate various quantization settings:

from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional

import torch


# Defined before QuantConfig, which uses it in a dataclass field annotation.
class ActivationCasting(Enum):
    """Types of quantization to perform on the activations

    WEIGHT_ONLY: Only quantize the weight, no activation casting, weight will be dequantized in the forward pass
    STATIC: Activation is quantized during the forward pass with a static scale provided at model initialization
    DYNAMIC: Activation is quantized during forward pass with a dynamic scale calculated from the input activation
    """

    # TODO: A better name would be NONE, we should unify this with torchao
    WEIGHT_ONLY = auto()
    DYNAMIC = auto()
    STATIC = auto()


@dataclass(frozen=True)
class QuantConfig:
    activation_casting: ActivationCasting
    static_quantization_scale: Optional[torch.Tensor] = None

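The main configuration option is the ActivationCasting enum. It selects whether activations are left in high precision (WEIGHT_ONLY), cast with a scale computed on the fly (DYNAMIC), or cast with a precomputed scale (STATIC).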

Top-level API

We propose a top-level API for quantizing models:

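from __future__ import annotations  # QuantConfig etc. are only referenced in annotations

from typing import List, Optional

import torch.nn as nn


def quantize_to_float8(
    module: nn.Module,
    quant_config: QuantConfig,
    *,
    skip_fqn_list: Optional[List[str]] = None,
    use_fast_accum: bool = True,
    scaling_granularity: Optional[ScalingGranularity] = None,  # part of the future proposal
) -> Optional[nn.Module]:
    ...  # implementation omitted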

This function allows users to easily convert their models to use Float8 inference.
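The skip_fqn_list option implies a module-swap pass keyed on fully qualified names. Below is a minimal sketch of such a pass, assuming every nn.Linear is replaced unless its FQN is listed; swap_linear_modules and make_replacement are hypothetical names, not the float8_experimental implementation.

```python
from typing import Callable, List, Optional

import torch.nn as nn


def swap_linear_modules(
    module: nn.Module,
    make_replacement: Callable[[nn.Linear], nn.Module],
    skip_fqn_list: Optional[List[str]] = None,
) -> nn.Module:
    """Hypothetical sketch: replace every nn.Linear whose FQN is not in skip_fqn_list."""
    skip = set(skip_fqn_list or [])
    if isinstance(module, nn.Linear):
        # The root itself is a Linear: there is no parent to setattr on.
        return module if "" in skip else make_replacement(module)
    # Collect targets first so the module tree is not mutated while iterating over it.
    targets = [
        (fqn, child)
        for fqn, child in module.named_modules()
        if isinstance(child, nn.Linear) and fqn not in skip
    ]
    for fqn, child in targets:
        parent_fqn, _, attr = fqn.rpartition(".")
        parent = module.get_submodule(parent_fqn)  # "" resolves to `module` itself
        setattr(parent, attr, make_replacement(child))
    return module


model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Sequential(nn.Linear(16, 4)))
# nn.Identity stands in for building a Float8InferenceLinear from each Linear.
model = swap_linear_modules(model, lambda lin: nn.Identity(), skip_fqn_list=["2.0"])
```

Here "0" is swapped while "2.0" is left as a regular nn.Linear because its FQN is skipped.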

An example of using this on a Hugging Face model can be found in this PR in TorchAO.

Proposed Extensions

Scaling Granularity

Currently, we only support TensorWise scaling: we compute max(abs(tensor)) over the entire tensor and use that value to compute the Float8Tensor scale. However, because activations often contain outlier values, this can produce large quantization error. In addition, computing a global reduction across the entire activation tensor can be relatively slow.

Therefore, we want to add the option to specify different types of scaling granularities.

The scaling_granularity parameter determines how scales are computed, for example TensorWise (one scale for the whole tensor) or AxisWise (one scale per row or column).

We recently added AxisWise scaling support to _scaled_mm in this PyTorch PR: #128989. In addition, I have a working PR stack showing how AxisWise scaling can be implemented in Float8Experimental: https://github.com/pytorch-labs/float8_experimental/pull/305

We would like to continue generalizing the scaling granularity to finer-grained schemes such as GroupWise and BlockWise scaling.

Design Details

Tensor Subclass Usage

The implementation uses Float8Tensors to encapsulate the scaling and to dispatch to _scaled_mm instead of torch.mm. This is not the only possible implementation. Inference has no autograd constraint requiring backpropagated gradients to match the dtype of the forward tensor, so we are free to desugar the Float8Tensor into its constituents, store them on the module, and use them in the forward. Using the tensor subclass, however, lets us reuse similar components between training and inference, but it does have downsides: it is less familiar to users, and it complicates export.

Performance

Compile

As with the rest of this project, we rely heavily on the compile stack to generate efficient, fused casting code. Without compile we do see some performance gains on heavily compute-bound models, but in general torch.compile is required for competitive performance.

Export

Currently, it is not possible to run torch.export + AOTI with the publicly available export APIs. However, this PR demonstrates that it is possible: https://github.com/pytorch-labs/float8_experimental/pull/295. The export team plans, this half, to make exporting nn.Modules with subclass weights available in the public API.

Limitations and Future Work

Extend ScalingGranularity

Composition with other dtypes/techniques

Standardize on TorchAO APIs

Non-H100 GPU Support

Dynamic Shapes

Other Module Support

- While Linear weights account for the majority of model size and compute, other operations can still benefit from FP8 compute gains.

Examples

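# Example usage of the proposed API
import torch

# QuantConfig, ActivationCasting and quantize_to_float8 come from float8_experimental;
# ScalingGranularity is part of the future proposal above.

# MyLargeModel is a placeholder for any nn.Module; match the BF16/CUDA input below.
model = MyLargeModel().to(device="cuda", dtype=torch.bfloat16)
quant_config = QuantConfig(ActivationCasting.DYNAMIC)
quantized_model = quantize_to_float8(
    model,
    quant_config,
    scaling_granularity=ScalingGranularity.AxisWise,
)

quantized_model = torch.compile(quantized_model)
# Use the quantized model for inference
input_tensor = torch.randn(1, 1024, 1024, dtype=torch.bfloat16, device="cuda")
with torch.inference_mode():
    output = quantized_model(input_tensor)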

Open Questions

  1. Should we provide more granular control over which layers are quantized? This is possible today using FQNs, but I'm not sure whether TorchAO has ideas on the top-level UX.
  2. How can we best handle models with custom or non-standard linear layers?
  3. What additional tools or utilities might be needed to help users debug and optimize their quantized models?
  4. Quantization Error Reducing Techniques: Techniques like HQQ are utilized to reduce quantization error. It is unlikely that the existing _scaled_mm kernel can support this use case. Is that a problem?

Conclusion

This RFC proposes significant enhancements to Float8 inference in PyTorch, aiming to provide a more flexible, efficient, and user-friendly framework for quantization. By supporting various scaling granularities and quantization strategies, we can cater to a wide range of use cases and potentially unlock substantial performance improvements for many models.

Additional Details

The results below were collected with this script: https://gist.github.com/drisspg/d7ae2134fbb6ca369c4817853c3352fa

Results for batch_size=1, num_tokens=128:
+----------------------------+-------------+-------------------+----------------+
| Variant                    |   Time (μs) | Speedup vs BF16   |   SQNR vs BF16 |
+============================+=============+===================+================+
| BF16                       |      211.2  | 1.00x             |         inf    |
+----------------------------+-------------+-------------------+----------------+
| FP8_Dynamic_TensorWise     |      151.04 | 1.40x             |          23.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Static_TensorWise      |      138.7  | 1.52x             |          23.5  |
+----------------------------+-------------+-------------------+----------------+
| FP8_Weight_Only_TensorWise |      460.01 | 0.46x             |          26.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Dynamic_AxisWise       |      137.68 | 1.53x             |          23.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Static_AxisWise        |      131.39 | 1.61x             |          23.5  |
+----------------------------+-------------+-------------------+----------------+
| FP8_Weight_Only_AxisWise   |      459.72 | 0.46x             |          26.75 |
+----------------------------+-------------+-------------------+----------------+

Results for batch_size=1, num_tokens=1024:
+----------------------------+-------------+-------------------+----------------+
| Variant                    |   Time (μs) | Speedup vs BF16   |   SQNR vs BF16 |
+============================+=============+===================+================+
| BF16                       |      642.22 | 1.00x             |         inf    |
+----------------------------+-------------+-------------------+----------------+
| FP8_Dynamic_TensorWise     |      396.68 | 1.62x             |          23.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Static_TensorWise      |      364.04 | 1.76x             |          23.5  |
+----------------------------+-------------+-------------------+----------------+
| FP8_Weight_Only_TensorWise |      871.38 | 0.74x             |          26.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Dynamic_AxisWise       |      390.63 | 1.64x             |          23.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Static_AxisWise        |      369.72 | 1.74x             |          23.5  |
+----------------------------+-------------+-------------------+----------------+
| FP8_Weight_Only_AxisWise   |      868.9  | 0.74x             |          26.75 |
+----------------------------+-------------+-------------------+----------------+

Results for batch_size=32, num_tokens=128:
+----------------------------+-------------+-------------------+----------------+
| Variant                    |   Time (μs) | Speedup vs BF16   |   SQNR vs BF16 |
+============================+=============+===================+================+
| BF16                       |     2567.15 | 1.00x             |         inf    |
+----------------------------+-------------+-------------------+----------------+
| FP8_Dynamic_TensorWise     |     1535.65 | 1.67x             |          23.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Static_TensorWise      |     1405.36 | 1.83x             |          23.5  |
+----------------------------+-------------+-------------------+----------------+
| FP8_Weight_Only_TensorWise |     2783.9  | 0.92x             |          26.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Dynamic_AxisWise       |     1487.35 | 1.73x             |          23.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Static_AxisWise        |     1420.56 | 1.81x             |          23.5  |
+----------------------------+-------------+-------------------+----------------+
| FP8_Weight_Only_AxisWise   |     2786.66 | 0.92x             |          26.75 |
+----------------------------+-------------+-------------------+----------------+

Results for batch_size=32, num_tokens=1024:
+----------------------------+-------------+-------------------+----------------+
| Variant                    |   Time (μs) | Speedup vs BF16   |   SQNR vs BF16 |
+============================+=============+===================+================+
| BF16                       |     21087.9 | 1.00x             |         inf    |
+----------------------------+-------------+-------------------+----------------+
| FP8_Dynamic_TensorWise     |     12172.4 | 1.73x             |          23.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Static_TensorWise      |     11220.7 | 1.88x             |          23.5  |
+----------------------------+-------------+-------------------+----------------+
| FP8_Weight_Only_TensorWise |     21209.9 | 0.99x             |          26.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Dynamic_AxisWise       |     12393.6 | 1.70x             |          23.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Static_AxisWise        |     11853.9 | 1.78x             |          23.5  |
+----------------------------+-------------+-------------------+----------------+
| FP8_Weight_Only_AxisWise   |     21227.7 | 0.99x             |          26.75 |
+----------------------------+-------------+-------------------+----------------+

Results for batch_size=64, num_tokens=2048:
+----------------------------+-------------+-------------------+----------------+
| Variant                    |   Time (μs) | Speedup vs BF16   |   SQNR vs BF16 |
+============================+=============+===================+================+
| BF16                       |     86532.6 | 1.00x             |         inf    |
+----------------------------+-------------+-------------------+----------------+
| FP8_Dynamic_TensorWise     |     49520.9 | 1.75x             |          23.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Static_TensorWise      |     47816.8 | 1.81x             |          23.5  |
+----------------------------+-------------+-------------------+----------------+
| FP8_Weight_Only_TensorWise |     86674.2 | 1.00x             |          26.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Dynamic_AxisWise       |     68645.7 | 1.26x             |          23.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Static_AxisWise        |     54025.8 | 1.60x             |          23.5  |
+----------------------------+-------------+-------------------+----------------+
| FP8_Weight_Only_AxisWise   |     85562.3 | 1.01x             |          26.75 |
+----------------------------+-------------+-------------------+----------------+
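The SQNR column reports a signal-to-quantization-noise ratio in dB against the BF16 output. The benchmark script is not reproduced here, so the sketch below assumes the common definition SQNR = 20 log10(||ref|| / ||ref - approx||) (the same form as torchao's compute_error helper, to our knowledge). It also shows why the BF16 row reads inf.

```python
import math

import torch


def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> float:
    """Signal-to-quantization-noise ratio in dB: 20 * log10(||ref|| / ||ref - approx||)."""
    ref32, approx32 = ref.to(torch.float32), approx.to(torch.float32)
    signal = torch.linalg.vector_norm(ref32)
    noise = torch.linalg.vector_norm(ref32 - approx32)
    if noise.item() == 0.0:
        return math.inf  # identical outputs, e.g. the BF16 baseline compared with itself
    return (20 * torch.log10(signal / noise)).item()


ref = torch.tensor([3.0, 4.0])       # ||ref|| = 5
approx = torch.tensor([3.0, 4.05])   # ||ref - approx|| is about 0.05
print(sqnr_db(ref, approx))          # about 40 dB
print(sqnr_db(ref, ref))             # inf
```

Under this definition, 23.75 dB means the error norm is about 10^(-23.75/20) ≈ 6.5% of the output norm, and 26.75 dB about 4.6%.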
msaroufim commented 4 months ago

Thanks @drisspg! For all the configuration options, my expectation is that you follow the template set by Jerry, i.e. get things working with the quantize() API, and if that doesn't work, the two of you should talk more. My comments apply to everything below the Design Details section.

  1. Regarding subclasses: could you be more specific about the lack of familiarity and the challenges with export? Is this referring to the unwrap tensor subclass shenanigans? Lack of familiarity is more of an issue for external developers, not the core team.
  2. _scaled_mm only supports TensorWise and AxisWise scaling: which other, unsupported granularities did you have in mind?
  3. Non-H100 GPU support is fine: presumably people can still run the code without speedups, or will it crash?
  4. Other module support: when you say fused FP8 SDPA, what do you mean exactly? Fusing the dequant into the prologue, or did you have something else in mind? I'm a bit confused, since I thought FP8 didn't require upcasting to a higher bitwidth for matmuls.
  5. KV cache support: this is super compelling to me. We already have an FP8 Adam, and an FP8 KV cache just makes so much sense for long-sequence-length work.
  6. On the open questions: my expectation here is that you provide code as-is that's fast on the examples you try out. As people try larger models, I'd expect accuracy to become more of an issue. In that case, if the code is in pure Python, you can always point people to where they can write more accurate algorithms by copy-pasting your code, without having to worry about that now.
drisspg commented 4 months ago

@msaroufim Great feedback, thank you!

For the quantize_() API question: The main difference, from my understanding, is that you have a tensor subclass per strategy: https://github.com/pytorch/ao/blob/05038a1f613b7a1f6b0aec84252021e2be98adff/torchao/quantization/quant_api.py#L282C1-L285C72. Here, we are using one tensor subclass and defining extra module state and forward logic to do this. I think that if in the future we want to support the "full subclass" API, there are a few options:

  1. Write out each float8 variant and copy over the parts needed.
  2. Subclass the main Float8Tensor and override the mm logic.
  3. Design generic wrapper subclasses that can do weight-only, dynamic, and static-activation casting, with some interface that all subclasses implement, similar to the quantize/dequantize API.

I prefer option 3, but I think in the meantime it's okay for this solution to be distinct.

Okay, back to the comments 😉

  1. a) Regarding lack of familiarity, this stemmed more from a comment from @vkuzo that, in general, it's harder to interpret what an nn.Module is doing under the hood when it uses subclasses. I agree with this and was remarking that it may put off more advanced users who want to fully understand all the details.

     b) Yup, the export comment was about the Unwrap Subclass annoyance. I spoke with @tugsbayasgalan about this, and they have plans to encode this logic into the export flow, so it will no longer be the responsibility of the export caller.

  2. The next two granularities we want to support are GroupWise and BlockWise. GroupWise FP8 is essentially the MXFP8 format, the only difference being that the scale is an E8M0 type for MX as opposed to float32 here.

  3. You should be able to run the weight-only code on non-H100 GPUs fine; however, attempting to run the FP8 compute paths will error with an 'unsupported device' message.

  4. That was confusing; I meant more of a FlashAttention-like algorithm, as opposed to an unfused variant where it's easier to inject FP8 scaling and unscaling around the softmax.

  5. Agreed 😊

vkuzo commented 4 months ago

This is great!

On a high level, I'd also be interested to hear about how this would compose with distributed, and on the motivation to use a module (vs a subclass, or both a module and a subclass) as the user visible object.

A couple of technical comments:

def quantize_to_float8(

There are also use cases for a two-stage API, where the user first calibrates with sample data (for example, to obtain activation scales) and then runs inference using the calibration data. It might be good to align with how torchao is planning to organize this. I noticed the RFC defers this until later, but the code does mention supporting static scales for activations.

    quant_config: QuantConfig,

Eventually the quantization settings should be configurable per-module instead of being per-model, it's pretty common to change the settings based on FQN, weight shape, etc.
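Per-module configuration could take the form of a policy function evaluated for every FQN, as in this hypothetical sketch. The rules (skip an lm_head, skip dimensions not divisible by 16) are illustrative assumptions, and the returned strings simply mirror ActivationCasting member names to keep the example self-contained.

```python
from typing import Optional

import torch.nn as nn


def choose_activation_casting(fqn: str, module: nn.Module) -> Optional[str]:
    """Hypothetical per-module policy; returns None to leave a module unquantized."""
    if not isinstance(module, nn.Linear):
        return None
    if fqn.split(".")[-1] == "lm_head":
        return None  # example FQN rule: keep the output head in high precision
    out_features, in_features = module.weight.shape
    if in_features % 16 or out_features % 16:
        return None  # example shape rule: skip dims an FP8 GEMM may reject
    return "DYNAMIC"


model = nn.ModuleDict({
    "proj": nn.Linear(64, 128),
    "odd": nn.Linear(10, 20),
    "lm_head": nn.Linear(128, 64),
})
plan = {fqn: choose_activation_casting(fqn, m) for fqn, m in model.named_modules() if fqn}
```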

        forward_config: ScaledMMConfig,
        scaling_granularity: Optional[ScalingGranularity],

One could argue that these could live inside QuantConfig.

jerryzh168 commented 4 months ago

Thanks for the detailed note. My main question: for float8 inference, can the float8 tensor subclass be expressed as a primitive dtype (like torch.float8_e4m3fn) plus AffineQuantizedTensor, since the scale and granularity logic is already implemented there?

drisspg commented 4 months ago

@jerryzh168 I think yes, if you have no zero point for AQT (or at least if you don't do any actual addition when the zero point is zero, since native addition is not supported for float8_e4m3fn). However, I do think the constraints of the _scaled_mm kernel would leak pretty heavily into AQT if we tried to merge everything into one.
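For reference, the two scale conventions can be lined up. Based on our reading of the two codebases (an assumption, not something stated in this thread), torchao's affine quantization divides by a scale and adds a zero point, while float8_experimental multiplies by a scale derived from amax:

```latex
q_{\text{affine}} = \operatorname{clamp}\!\left(\operatorname{round}\!\left(\frac{x}{s}\right) + z,\; q_{\min},\; q_{\max}\right)

q_{\text{fp8}} = \operatorname{cast}_{\text{E4M3}}\!\left(\operatorname{clamp}\!\left(x \cdot s',\; -448,\; 448\right)\right),
\qquad s' = \frac{448}{\max\lvert x\rvert}
```

With z = 0, q_min = -448, q_max = 448, and round(·) replaced by round-to-nearest E4M3 casting, the two coincide (up to the order of clamping and rounding) when s = 1/s' = max|x| / 448. This is why a float8 variant of AQT needs either no zero point or one that is never added.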

jerryzh168 commented 4 months ago

@drisspg zero_point could be optional, I think (https://github.com/pytorch/ao/blob/05038a1f613b7a1f6b0aec84252021e2be98adff/torchao/quantization/quant_primitives.py#L146). What do you mean by the constraints of the scaled_mm kernel leaking into AQT? AQT also supports dispatching to different kernels based on different conditions.

gau-nernst commented 4 months ago

I'm also interested in FP8 matmul inference. A question. I see that triton matmul kernel can work with FP8, and torch.compile can also generate FP8 triton matmul. What is the performance difference between _scaled_mm() (which is backed by CuDNN or cutlass I assume) and triton FP8 matmul? I can see that the benefits of using triton FP8 matmul would be allowing more flexible scaling, especially for non-H100. Using existing AQT facilities in torchao, the addition of FP8 inference (using torch.mm + torch.compile -> triton) wouldn't be too challenging.

vkuzo commented 3 months ago

discussion moved to https://github.com/pytorch/ao/issues/574