pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License

[Feature Request] Fused fp8 matmul kernel (quant + dequant + matmul) #752

Open qingquansong opened 2 months ago

qingquansong commented 2 months ago

Hey team, AO provides awesome FP8 support with torch.compile, giving both speed and memory improvements. However, torch.compile is not always easy to apply to some models, such as the Hugging Face MoE implementation, which needs changes like those made in the GPT-fast code. Those changes can also cause memory issues during the prefill stage, since that style of implementation is tuned for fast decoding.

Currently, AO does not fuse the quant and dequant steps into the matmul (scaled_mm), so we have to rely on torch.compile to fuse them and get the speedup. Otherwise, performance suffers from the separate quant/dequant stages (see #685, #704 and here).

Transformer Engine has a custom kernel for this, so I feel a small missing piece that could be added is a fused version of this kernel (either in Triton or as a torch op with a CUDA implementation). That way we could get the speedup from AO fp8linear even without torch.compile. This would benefit both training and inference. Thank you!
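To make the quant + dequant + matmul sequence concrete, here is a minimal pure-Python reference of what such a kernel computes, assuming tensorwise scaling to float8 e4m3fn (largest finite value 448). It is a sketch of the math only, not AO's or Transformer Engine's implementation; all function names are hypothetical. It models one tensor-wide scale per input and keeps the matmul accumulation in full precision.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value (1.75 * 2**8)


def round_to_e4m3(x: float) -> float:
    """Round to the nearest float8_e4m3fn value, saturating at +/-448 (no NaN handling)."""
    if x == 0.0:
        return 0.0
    mag = min(abs(x), E4M3_MAX)
    # frexp gives mag = m * 2**e with 0.5 <= m < 1, so the binade exponent is e - 1.
    # -6 is the smallest normal exponent; below it values are subnormal with fixed spacing.
    exp = max(math.frexp(mag)[1] - 1, -6)
    step = 2.0 ** (exp - 3)  # 3 mantissa bits -> spacing of 2**(exp - 3)
    return math.copysign(min(round(mag / step) * step, E4M3_MAX), x)


def quantize_tensorwise(x: Matrix) -> Tuple[Matrix, float]:
    """Scale so the tensor's amax maps to E4M3_MAX, then round every element to e4m3."""
    amax = max(abs(v) for row in x for v in row)
    scale = E4M3_MAX / max(amax, 1e-12)  # guard against an all-zero tensor
    return [[round_to_e4m3(v * scale) for v in row] for row in x], scale


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def fp8_linear_reference(a: Matrix, b: Matrix) -> Matrix:
    """Unfused reference: quantize both inputs, matmul in high precision, then dequantize."""
    a_q, scale_a = quantize_tensorwise(a)
    b_q, scale_b = quantize_tensorwise(b)
    acc = matmul(a_q, b_q)  # fp8-valued operands, accumulation kept in full precision
    inv = 1.0 / (scale_a * scale_b)  # dequantization is one scalar multiply per output
    return [[v * inv for v in row] for row in acc]
```

For `a = [[1.0, 2.0], [3.0, 4.0]]` and `b = [[0.5, -1.0], [0.25, 2.0]]` the exact product is `[[1.0, 3.0], [2.5, 5.0]]`. The reference returns roughly `[[1.0, 3.0], [2.43, 5.14]]`, because `3 * 112 = 336` is not representable in e4m3 and rounds to 320. The final rescale is a single scalar multiply that a matmul can apply in its epilogue. The amax reduction and the cast are extra passes over the inputs, and without fusion each of them runs as a separate op.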

msaroufim commented 2 months ago

This is an interesting issue; here are some half-baked thoughts on the topic.

Today, AO adoption is contingent on the growth of compile. Right now we basically act as an add-on or expansion pack that gives people even more speedups or VRAM savings if they're using our compiler.

However, there are alternative deployment strategies we could explore:

  1. What if we code-generated the fused fp8 matmul kernel, inspected the code using TORCH_LOGS="output_code", and then just used that? Conceptually, that's my understanding of what AOTInductor is @desertfire. We might need to take care that the kernels generalize to a variety of inputs.
  2. Alternatively, we could just write the kernel by hand. This feels acceptable for narrow but important use cases like fp8 matmul @vkuzo
  3. Finally, we could explore a FlexAttention-like deployment strategy, where the fp8 matmul kernel is called as if it were an eager kernel but calls torch.compile under the hood. This wouldn't completely solve cold-start problems, but it would avoid requiring the entire model to be torch.compile-able. @drisspg might have more thoughts.
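The control flow behind option 3, and the "generalize to a variety of inputs" concern in option 1, can be sketched as follows. This is a pure-Python model under stated assumptions. `LazilyCompiledOp` and `fake_compile` are hypothetical names, and `fake_compile` stands in for a real compiler such as a torch.compile call. The caller sees an ordinary eager function. A specialized kernel is built lazily on the first call for each input shape and then cached.

```python
from typing import Callable, Dict, List, Tuple

Matrix = List[List[float]]
ShapeKey = Tuple[Tuple[int, int], ...]
Kernel = Callable[..., Matrix]


def shape_key(*mats: Matrix) -> ShapeKey:
    """Input signature a shape-specializing compiler might key its cache on."""
    return tuple((len(m), len(m[0])) for m in mats)


class LazilyCompiledOp:
    """Callable like an eager op; builds and caches a specialized kernel per input shape.

    `compile_fn` is a hypothetical stand-in for a real compiler (e.g. a torch.compile call).
    """

    def __init__(self, fn: Kernel, compile_fn: Callable[[Kernel, ShapeKey], Kernel]):
        self.fn = fn
        self.compile_fn = compile_fn
        self.cache: Dict[ShapeKey, Kernel] = {}
        self.compile_count = 0

    def __call__(self, *mats: Matrix) -> Matrix:
        key = shape_key(*mats)
        kernel = self.cache.get(key)
        if kernel is None:
            # Cold start: the first call with a new shape pays the compile cost.
            kernel = self.compile_fn(self.fn, key)
            self.cache[key] = kernel
            self.compile_count += 1
        # Warm path: later calls with the same shape reuse the cached kernel.
        return kernel(*mats)


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def fake_compile(fn: Kernel, key: ShapeKey) -> Kernel:
    """Hypothetical 'compiler' that returns the function unchanged."""
    return fn


mm = LazilyCompiledOp(matmul, fake_compile)
```

Calling `mm` twice with 2x2 inputs compiles once, and a 3x2 input triggers a second compile. That cost is the cold-start problem option 3 would not fully remove. A real version built on torch.compile would inherit that compiler's own caching and recompilation rules, which are more involved than this per-shape dictionary.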

Chatted with @vkuzo about this offline. His feedback was that 1 and 2 are the same. There might be issues scaling to more models with different activations and normalizations, but it's conceptually the simplest approach since it doesn't interact with other PyTorch subsystems. A longer-term approach could be an eager-like API that takes a normalization and an activation as input and code-generates a fused kernel. That is more flexible, but conceptually more complex.

vkuzo commented 1 month ago

Here is something we could do: provide fast, fused eager-mode bf16_to_float8 quantization kernels (via Triton), configurable and off by default. This would provide a speedup over the current eager-mode quantization kernels. However, it would not provide any fusion opportunities with surrounding ops, so it would still not reach optimal performance.

We have been avoiding hand-writing fused kernels for all the possible ops in (prev_op -> bf16_to_float8) because that approach is expensive to scale to a wide variety of models, use cases and config options. Perhaps just the quantize kernel could be a useful middle ground.
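Why a fused quantize kernel helps, and why it still leaves performance on the table, can be illustrated with a toy pure-Python model of memory traffic. It counts element reads and writes for an unfused op-by-op sequence (abs, max, multiply, clamp, cast) and for a two-pass fused version (amax reduction, then scale, clamp and cast in one pass). The exact op breakdown is an assumption, not a trace of torchao. `cast_to_fp8` is a hypothetical placeholder, since only the traffic matters here. The counts come from this model and are not measured speedups.

```python
from typing import List, Tuple

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


class Traffic:
    """Counts element reads/writes against 'memory' (lists standing in for GPU tensors)."""

    def __init__(self) -> None:
        self.reads = 0
        self.writes = 0

    def pass_over(self, n: int, writes_output: bool) -> None:
        self.reads += n
        if writes_output:
            self.writes += n


def cast_to_fp8(v: float) -> float:
    """Hypothetical placeholder for the float8 cast; value rounding is not modelled."""
    return v


def eager_quantize(x: List[float], t: Traffic) -> Tuple[List[float], float]:
    """Unfused sequence: every op reads a full buffer and most write a new one."""
    n = len(x)
    abs_x = [abs(v) for v in x]
    t.pass_over(n, writes_output=True)   # abs
    amax = max(abs_x)
    t.pass_over(n, writes_output=False)  # max reduction -> scalar
    scale = E4M3_MAX / max(amax, 1e-12)
    scaled = [v * scale for v in x]
    t.pass_over(n, writes_output=True)   # multiply by scale
    clamped = [min(max(v, -E4M3_MAX), E4M3_MAX) for v in scaled]
    t.pass_over(n, writes_output=True)   # clamp (saturate)
    out = [cast_to_fp8(v) for v in clamped]
    t.pass_over(n, writes_output=True)   # cast
    return out, scale


def fused_quantize(x: List[float], t: Traffic) -> Tuple[List[float], float]:
    """Two passes: an amax reduction, then scale + clamp + cast writing only the output."""
    n = len(x)
    amax = max(abs(v) for v in x)
    t.pass_over(n, writes_output=False)  # pass 1: amax reduction
    scale = E4M3_MAX / max(amax, 1e-12)
    out = [cast_to_fp8(min(max(v * scale, -E4M3_MAX), E4M3_MAX)) for v in x]
    t.pass_over(n, writes_output=True)   # pass 2: fused scale/clamp/cast
    return out, scale
```

In this model, the unfused path reads 5n and writes 4n elements, while the fused path reads 2n and writes n. Both paths produce the same result. Even the fused version still reads the bf16 input in a separate pass from the op that produced it. That remaining gap is the missing fusion with surrounding ops described above.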

Would love to hear thoughts on how useful ^ would be.

qingquansong commented 1 month ago

Hey @vkuzo, thank you very much for the suggestions! I think having such a kernel as a workaround would be useful. Do you happen to have a rough estimate of the speedup and memory improvement for 8B/70B Llama training with it? (Suppose we use one H100 node for 8B training with a 2k-4k context length.)