qingquansong opened this issue 2 months ago
This is an interesting issue; here are some half-baked thoughts on the topic.
Today, AO adoption is contingent on the growth of torch.compile adoption: right now we act as an add-on or expansion pack that gives people even more speedups or VRAM savings if they're already using our compiler.
However, there are alternative deployment strategies we could explore.
Chatted with @vkuzo about this offline. His feedback was that 1 and 2 are the same; there may be issues with how we scale to more models with different activations and normalizations, but it's conceptually the simplest option since it doesn't interact with other PyTorch subsystems. A longer-term approach would be an eager-like API that takes a normalization and an activation as input and code-generates a fused kernel; that's more flexible, but also conceptually more complex (a rough sketch of what such an API could look like is below).
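To make that concrete, here is a rough sketch of what such an eager-mode fusion API could look like. Everything here is hypothetical (the factory name, the returned module, the reference semantics); it's only meant to illustrate the shape of the idea, not an existing torchao API:

```python
# Hypothetical sketch only -- none of these names exist in torchao today.
# Idea: the user hands us their normalization and activation modules, and we
# return a module whose forward would run a code-generated fused
# (norm -> activation -> bf16_to_float8 quantize) kernel.
import torch
import torch.nn as nn


def make_fused_norm_act_quant(norm: nn.Module, act: nn.Module) -> nn.Module:
    """Hypothetical factory for a fused norm + activation + fp8-quantize module."""

    class FusedNormActQuant(nn.Module):
        def forward(self, x: torch.Tensor):
            # Reference (unfused) semantics shown for clarity; a real
            # implementation would emit a single fused kernel instead of
            # these separate eager ops.
            y = act(norm(x))
            fp8_max = torch.finfo(torch.float8_e4m3fn).max
            scale = fp8_max / y.abs().amax().float().clamp(min=1e-12)
            y_fp8 = (y * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
            return y_fp8, scale

    return FusedNormActQuant()
```

For a Llama-style block this might be invoked as `make_fused_norm_act_quant(nn.RMSNorm(hidden_dim), nn.SiLU())` (assuming a PyTorch version that ships `nn.RMSNorm`).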
Here is something we could do: provide fast fused eager-mode bf16_to_float8 quantization kernels (via Triton), configurable and off by default. This would provide a speedup over the current eager-mode quantization kernels, but it would not provide any fusion opportunities with surrounding ops and thus still would not reach optimal performance (a minimal sketch of such a kernel is included after this comment).
We have been avoiding hand-writing fused kernels for all the possible ops in (prev_op -> bf16_to_float8) because that approach is expensive to scale to a wide variety of models/use cases/config options, but perhaps just the quantize kernel could be a useful middle ground.
Would love to hear thoughts on how useful ^ would be.
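For reference, a minimal sketch of what such a standalone bf16 -> float8 quantize kernel could look like, assuming the per-tensor scale is precomputed on the host and that this Triton version maps `torch.float8_e4m3fn` to `tl.float8e4nv`. This is illustrative only, not torchao's actual kernel:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _bf16_to_fp8_kernel(x_ptr, out_ptr, scale_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
    scale = tl.load(scale_ptr)  # per-tensor scale, precomputed on the host
    # Scale, clamp to the e4m3 range (+-448), and cast down to float8.
    y = tl.minimum(tl.maximum(x * scale, -448.0), 448.0)
    tl.store(out_ptr + offs, y.to(tl.float8e4nv), mask=mask)


def bf16_to_float8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Quantize a bf16 tensor to float8_e4m3fn with a precomputed per-tensor scale."""
    out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    _bf16_to_fp8_kernel[grid](x, out, scale, n, BLOCK=1024)
    return out
```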
Hey @vkuzo, thank you very much for the suggestions! I think having such a kernel as a workaround would be useful. Do you happen to have a rough estimate of the speedup and memory improvement on 8B/70B Llama training with that supported? (Suppose we use one H100 node for 8B training with a 2k-4k context length.)
Hey team, AO provides awesome FP8 support with torch.compile to get speed and memory improvements. However, torch.compile is not always easy to apply to some models, such as the Hugging Face MoE implementation, which requires changes similar to those in the GPT-fast code (and those changes can also cause memory issues in the prefill stage, since that implementation is better suited to fast decoding).
Currently, AO does not directly offer fusion of the quant and dequant steps into the matmul (scaled_mm), so we have to rely on torch.compile to fuse them and gain the speedup (otherwise performance is limited by the quant/dequant stage; see #685, #704 and here).
Transformer Engine has a custom kernel that does this, so I feel a small missing piece that could be added is a fused version of this kernel (either Triton or torch ops with a CUDA implementation), so we can get the speedup from the AO fp8 linear even without using torch.compile. This would benefit both training and inference. Thank you!
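For context on where the unfused quantize step sits, here is a rough sketch of an eager-mode float8 linear forward built on `torch._scaled_mm`. The exact signature of `torch._scaled_mm` (a private API) has changed across PyTorch releases, and the scale shapes/conventions below are assumptions, so treat this purely as an illustration of the structure, not a drop-in implementation:

```python
import torch


def fp8_linear_eager(x_bf16: torch.Tensor, w_fp8: torch.Tensor, w_inv_scale: torch.Tensor):
    """Sketch of an eager fp8 matmul: standalone quantize + scaled matmul.

    Assumes w_fp8 is pre-quantized to float8_e4m3fn and w_inv_scale is its
    dequant scale (a one-element float32 tensor).
    """
    fp8_max = torch.finfo(torch.float8_e4m3fn).max

    # 1. Quantize the activation. Without torch.compile this runs as its own
    #    kernel(s) and cannot fuse with whatever produced x_bf16 -- this is
    #    the overhead discussed in this issue.
    x_scale = fp8_max / x_bf16.abs().amax().float().clamp(min=1e-12)
    x_fp8 = (x_bf16 * x_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)

    # 2. The scaled matmul. The dequant is applied inside the kernel via the
    #    scale arguments, so the remaining unfused cost is step 1.
    return torch._scaled_mm(
        x_fp8,
        w_fp8.t(),                     # second operand is expected column-major
        scale_a=x_scale.reciprocal(),  # dequant scale for the activation
        scale_b=w_inv_scale,           # dequant scale for the weight
        out_dtype=torch.bfloat16,
    )
```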