pytorch / ao

PyTorch native quantization and sparsity for training and inference

high throughput inference #663

Open msaroufim opened 3 months ago

msaroufim commented 3 months ago

Was chatting with @Chillee about our plans in AO today and he mentioned we should be focusing on a few concrete problems like:

  1. Demonstrate compelling perf for fp8 gemm at a variety of batch sizes.
  2. Demonstrate compelling perf for weight only int8 gemm at a variety of batch sizes.
  3. Demonstrate compelling perf for weight only intX gemm at low batch sizes.
  4. Demonstrate compelling perf for weight intX, activation fp8 at a variety of batch sizes.

As a baseline, we could extend gpt-fast to work with bs=n without doing any KV cache management work and measure perf there. I'm copying the feedback as is; open to discussing more and adding more details as time progresses.

EDIT: gpt-fast already has a batched generation branch by Horace https://github.com/pytorch-labs/gpt-fast/tree/batched_generation
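For the batch-size sweeps in items 1-3, something much smaller than the full gpt-fast harness can already give a first signal. Below is a minimal sketch, assuming torchao's `quantize_` / `int8_weight_only` API and a CUDA device; the layer shape, batch sizes, and iteration count are placeholders:

```python
# Minimal batch-size sweep: weight-only int8 vs bf16 on a single linear layer.
# Not the gpt-fast harness; API names assume torchao's quantize_ / int8_weight_only.
import copy
import torch
from torch.utils.benchmark import Timer
from torchao.quantization import quantize_, int8_weight_only

def bench_us(model, x, iters=100):
    model(x)  # warmup (also triggers torch.compile on the first call)
    t = Timer(stmt="model(x)", globals={"model": model, "x": x})
    return t.timeit(iters).median * 1e6  # median wall time in microseconds

K, N = 4096, 4096  # placeholder layer shape
bf16 = torch.nn.Sequential(torch.nn.Linear(K, N, bias=False)).cuda().to(torch.bfloat16)
int8 = copy.deepcopy(bf16)
quantize_(int8, int8_weight_only())  # swaps the Linear weight for an int8 tensor subclass

bf16_c, int8_c = torch.compile(bf16), torch.compile(int8)
for bs in (1, 4, 16, 64, 256):
    x = torch.randn(bs, K, device="cuda", dtype=torch.bfloat16)
    print(f"bs={bs:4d}  bf16 {bench_us(bf16_c, x):8.1f}us  int8-wo {bench_us(int8_c, x):8.1f}us")
```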

msaroufim commented 3 months ago

@HDCharles on the int8 work, @vkuzo on fp8, @vayuda and @jerryzh168 on intx

jeromeku commented 3 months ago

@msaroufim

Would be interesting to bench against something like QoQ, which implements W4A8KV4 (int8 GEMM) using a nested quantization scheme and neat kernel-level optimizations.
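As a rough illustration of what a two-level ("nested") scheme looks like numerically: int4 weight groups with per-group scales that are themselves quantized to int8 against a coarser per-channel scale, so the main GEMM can stay in int8. This is not QoQ's exact recipe, just the general scales-of-scales idea; the shapes and group size below are arbitrary.

```python
# Two-level ("nested") weight quantization sketch: int4 values per group, with the
# per-group scales quantized to int8 against one fp scale per output channel.
# Illustration only; not the QoQ kernels or its exact recipe.
import torch

def nested_quant_w4(w: torch.Tensor, group_size: int = 128):
    out_ch, in_ch = w.shape
    wg = w.reshape(out_ch, in_ch // group_size, group_size)

    # level 1: symmetric int4 per group
    s1 = wg.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) / 7.0
    q_w = torch.clamp(torch.round(wg / s1), -8, 7).to(torch.int8)

    # level 2: quantize the per-group scales to int8 with a per-channel fp scale
    s2 = s1.amax(dim=1, keepdim=True).clamp_min(1e-8) / 127.0
    q_s1 = torch.clamp(torch.round(s1 / s2), 1, 127).to(torch.int8)

    return q_w, q_s1, s2

def dequant(q_w, q_s1, s2):
    w_hat = q_w.float() * q_s1.float() * s2  # undo both levels
    return w_hat.reshape(q_w.shape[0], -1)

w = torch.randn(256, 1024)
q_w, q_s1, s2 = nested_quant_w4(w)
err = (dequant(q_w, q_s1, s2) - w).pow(2).mean().sqrt()
print(f"reconstruction rmse: {err.item():.4f}")
```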

vkuzo commented 3 months ago

> Demonstrate compelling perf for fp8 gemm at a variety of batch sizes.

Note that I'm putting up a PR soon for a quick roofline estimator for float8 gemm + overhead specific to training, to see for which M, K, N float8 is faster than bfloat16. It would be easily extendable to inference at a later time.
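In the meantime, a back-of-envelope version of that roofline is easy to sketch: model each gemm as max(compute-bound time, memory-bound time) and charge the float8 path an extra activation cast. The peak TFLOPS and bandwidth numbers below are illustrative placeholders rather than real hardware specs, and other overheads (per-tensor scaling, etc.) are ignored.

```python
# Toy roofline: each gemm takes max(compute-bound, memory-bound) time, and the
# float8 path additionally pays to cast the bf16 activation to float8. Peak
# TFLOPS / GB/s below are illustrative placeholders, not real hardware specs.
def gemm_time_s(M, K, N, in_bytes, peak_tflops, peak_gbs=3000.0):
    flops = 2 * M * K * N
    mem = in_bytes * (M * K + K * N) + 2 * M * N  # output kept in bf16
    return max(flops / (peak_tflops * 1e12), mem / (peak_gbs * 1e9))

def cast_time_s(M, K, peak_gbs=3000.0):
    # read the bf16 activation (2 B/elem), write it back as float8 (1 B/elem)
    return (2 * M * K + 1 * M * K) / (peak_gbs * 1e9)

K = N = 8192  # placeholder weight shape
for M in (1, 16, 256, 4096, 16384):
    t_bf16 = gemm_time_s(M, K, N, in_bytes=2, peak_tflops=1000)
    t_fp8 = gemm_time_s(M, K, N, in_bytes=1, peak_tflops=2000) + cast_time_s(M, K)
    print(f"M={M:6d}  bf16 {t_bf16 * 1e6:8.1f}us  fp8 {t_fp8 * 1e6:8.1f}us  "
          f"est speedup {t_bf16 / t_fp8:.2f}x")
```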

> Demonstrate compelling perf for weight intX, activation fp8 at a variety of batch sizes.

While this is technically possible, I'm not sure I understand the value; I'd be interested to learn more.