aredden / flux-fp8-api

Flux diffusion model implementation using quantized fp8 matmul; the remaining layers use faster half-precision accumulate, which is ~2x faster on consumer devices.
Apache License 2.0

Where is the code about "remaining layers use faster half precision accumulate"? #10

Open goldhuang opened 1 week ago

goldhuang commented 1 week ago

> Flux diffusion model implementation using quantized fp8 matmul; the remaining layers use faster half-precision accumulate, which is ~2x faster on consumer devices.

Hello there! Thanks for sharing your quantization implementation of Flux! I have a question about "remaining layers use faster half precision accumulate". Could you point out the lines that enable the faster half-precision accumulate in the repo? Thanks in advance!

aredden commented 1 week ago

It's the CublasLinear layers. They come from a repo I made which allows matmuls to run with half-precision accumulate inside the matmul kernel, which doubles the TFLOPS on most consumer GPUs. The source is here: https://github.com/aredden/torch-cublas-hgemm. The CublasLinear replacements happen in the float8_quantize.py file; that's where it occurs.
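For anyone looking for the shape of that swap, here is a minimal sketch of replacing `nn.Linear` modules with `CublasLinear` ones. The `cublas_ops` import path and the constructor signature are assumptions based on the torch-cublas-hgemm repo, not a verbatim copy of what `float8_quantize.py` actually does.

```python
# Hedged sketch: recursively swap nn.Linear layers for CublasLinear ones so
# their matmuls run with fp16 accumulate. Import path / constructor signature
# are assumptions based on https://github.com/aredden/torch-cublas-hgemm.
import torch
import torch.nn as nn
from cublas_ops import CublasLinear  # assumed import path

def swap_linear_for_cublas(module: nn.Module) -> nn.Module:
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            new = CublasLinear(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                device=child.weight.device,
                dtype=torch.float16,  # fp16 weights -> fp16-accumulate GEMM
            )
            new.weight.data.copy_(child.weight.data.to(torch.float16))
            if child.bias is not None:
                new.bias.data.copy_(child.bias.data.to(torch.float16))
            setattr(module, name, new)
        else:
            swap_linear_for_cublas(child)
    return module
```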

goldhuang commented 1 week ago

@aredden Thanks for your detailed answer! I have two follow-up questions now.

  1. Why do you only replace the linear layers in the single/double blocks with fp8?
  2. Why does CublasLinear only support float16?

aredden commented 1 week ago
  1. You can optionally quantize the others by setting "quantize_flow_embedder_layers": true, but that pretty considerably reduces quality and doesn't free up much extra VRAM or add much it/s. The non-single-or-double-block layers only make up ~2% of the model's actual weights, but they have a considerable effect on quality.

  2. Well, if you check out the Ada whitepaper, you'll find that the top theoretical throughput for fp16 with fp32 accumulate is ~160 TFLOPS on a 4090, but ~330 TFLOPS for fp16 with fp16 accumulate. Unfortunately you cannot use fp16 accumulate with anything other than fp16 tensors, and bf16 cannot be used as the accumulation datatype, so the only way to reach those TFLOPS on consumer GPUs is via fp16. It's actually the same speed as fp8 matmul!
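As a rough illustration of those numbers, a small micro-benchmark along these lines (not from the repo) can measure the achieved TFLOPS of a plain fp16 matmul, which PyTorch accumulates in fp32 by default; an fp16-accumulate GEMM such as CublasLinear would be timed the same way against the ~330 TFLOPS ceiling.

```python
# Rough micro-benchmark sketch: achieved TFLOPS of an fp16 matmul on the GPU.
# Plain torch.matmul uses fp32 accumulate, so on a 4090 this should land
# somewhere under the ~160 TFLOPS theoretical figure mentioned above.
import torch

M = N = K = 8192
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)

for _ in range(10):  # warm-up
    a @ b
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
iters = 50
start.record()
for _ in range(iters):
    a @ b
end.record()
torch.cuda.synchronize()

seconds_per_gemm = start.elapsed_time(end) / 1000 / iters  # elapsed_time is in ms
tflops = 2 * M * N * K / seconds_per_gemm / 1e12  # 2*M*N*K FLOPs per GEMM
print(f"achieved ~{tflops:.0f} TFLOPS")
```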