pytorch-labs / gpt-fast

Simple and efficient PyTorch-native transformer text generation in <1000 lines of Python.

Bandwidth achieved for INT8 is much smaller than FP16 #99

Open yafehlis opened 4 months ago

yafehlis commented 4 months ago

I ran CodeLlama 7B on a single AMD MI210 GPU: with FP16 the bandwidth achieved is 700 GB/s, but with INT8 it is only 197 GB/s. Why is the achieved bandwidth lower with INT8? @kit1980 @msaroufim @yifuwang @huntzhan Thanks, Yao Fehlis (AMD)

yafehlis commented 4 months ago

The above numbers were obtained with PyTorch 2.1 / ROCm 5.6. With PyTorch 2.2 / ROCm 5.7 the results are much better, but INT8 bandwidth is still lower than FP16. I see the same trend in the results published on your website. Do you know why?

HDCharles commented 4 months ago

The quantization overhead is to blame, at least for the numbers in the README. The matmul does the same amount of computation, but it also has to dequantize the int8 weight after loading it. You gain a lot of end-to-end performance from the ~2x faster weight loads and lose a bit of it to that overhead; in terms of reported memory bandwidth, though, you don't gain anything from loading half as much data twice as fast. It's unclear why it would be this bad on AMD GPUs; the int8 weight-only kernel is a bit finicky, so it may be handled differently on AMD.
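To make the metric concrete, here is a minimal sketch of how a gpt-fast-style "bandwidth achieved" number is typically derived, assuming it is simply (bytes of model weights streamed per token) × (tokens/s); the 50 tokens/s figure is purely illustrative, not a measurement:

```python
def bandwidth_achieved_gbs(model_bytes: float, tokens_per_sec: float) -> float:
    """Reported bandwidth = bytes streamed per generated token * tokens/s."""
    return model_bytes * tokens_per_sec / 1e9

fp16_bytes = 7e9 * 2  # ~7B params at 2 bytes each
int8_bytes = 7e9 * 1  # same params at 1 byte each

# Even if int8 decoding ran at exactly the same tokens/s as fp16,
# the reported bandwidth would already halve, because half the bytes move.
print(bandwidth_achieved_gbs(fp16_bytes, 50))  # ~700 GB/s
print(bandwidth_achieved_gbs(int8_bytes, 50))  # ~350 GB/s
```

So a lower "bandwidth achieved" for INT8 is expected from the metric itself; the question is only how much of the potential ~2x tokens/s gain survives the dequantization overhead.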

I'd also recommend not using fp16: the quantized model is set up to operate in bf16, and that may not play well with fp16.

Finally, setting torch._inductor.config.use_mixed_mm = True switches the kernel to a Triton one, which may perform differently on your hardware.
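For reference, a minimal sketch of where that flag goes; the Linear module below is just a placeholder for the quantized gpt-fast model (the flag itself only affects matmuls with int8 weights and floating-point activations):

```python
import torch
import torch._inductor.config as inductor_config

# Set before torch.compile traces the model: routes weight-only int8 matmuls
# through a Triton mixed-precision kernel instead of the default lowering.
# Performance may differ between ROCm and CUDA backends.
inductor_config.use_mixed_mm = True

# Placeholder module; in gpt-fast this would be the int8-quantized
# transformer loaded by generate.py. A plain bf16 Linear will not itself
# hit the mixed-mm path -- it only shows where the flag is set.
model = torch.nn.Linear(4096, 4096, bias=False, dtype=torch.bfloat16).cuda()
compiled = torch.compile(model)

x = torch.randn(1, 4096, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    out = compiled(x)
```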

Chillee commented 4 months ago

Yeah, generally speaking, when you go from int8 to int4 the "theoretical" speedup should be 2x. In practice, for Amdahl's-law-type reasons, you end up bottlenecked by other factors, so the actual speedup is less than 2x, and that shows up as a lower reported "memory bandwidth".
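A rough back-of-the-envelope illustration of that effect (all timings below are made up, not measured): if only the weight loads shrink while the rest of the decode step is fixed overhead, halving the bytes gives less than a 2x speedup, and the bytes-moved-per-second metric drops even though generation gets faster.

```python
# Hypothetical per-token timings in seconds, purely illustrative.
weight_load_fp16 = 0.018   # time to stream the fp16 weights
other_overhead   = 0.002   # attention, kernel launches, sampling, ...

t_fp16 = weight_load_fp16 + other_overhead               # 0.020 s/token
t_int8 = weight_load_fp16 / 2 + other_overhead + 0.001   # halved loads + dequant cost -> 0.012 s/token

speedup = t_fp16 / t_int8            # ~1.67x end-to-end, not 2x
bw_fp16 = 14e9 / t_fp16 / 1e9        # ~700 GB/s (14 GB of fp16 weights)
bw_int8 = 7e9 / t_int8 / 1e9         # ~583 GB/s (7 GB of int8 weights)
print(speedup, bw_fp16, bw_int8)
```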