turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs

Scaling inference throughput when increasing the batch size #450

Open lopuhin opened 1 month ago

lopuhin commented 1 month ago

When scaling the batch size from 1 to a small number, say 8, I'd naively expect generation performance to scale quite well, since we're still firmly in the memory-bound regime. But I observe two things I didn't expect:

I'm curious whether anyone else is observing similar behavior, and what is preventing better scaling. In practice, to implement in-flight batching I have to use multiple caches, so the multiple-cache performance is the more relevant one for my use case.

Below are results I obtained:

  1. Scaling with batched_inference.py (counting only generated tokens, though the time includes a bit of prefill; I also fixed the performance reporting at the end, uncommented the warmup and kept only 16 prompts):
  2. Scaling with multiple_caches.py, varying max_parallel_seqs (looking only at generation without prefill, guarded by torch.cuda.synchronize as in the timing sketch below; I also increased the number of prompts to 16):

Both scripts are here https://gist.github.com/lopuhin/c08fd724aa8ca71f37ecac219dc8a608

The sequences in multiple_caches.py are a bit shorter, but I also checked that the scaling of batched_inference.py stays the same with much shorter sequences.

All measurements were done with a llama3 exl2 4-bit quant on a 2080 Ti, CUDA 11.8, latest stable release of exllamav2. I also repeated a few measurements on a T4 GPU and got similar scaling.
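
For clarity, the generation-only timing in item 2 is guarded roughly like this; `decode_step` and the counters below are placeholders standing in for the actual script, not functions from exllamav2:

```python
import time
import torch

# Placeholders standing in for the real script's state, not the exllamav2 API.
num_steps = 128            # decoding steps to time
active_sequences = 8       # sequences generating in parallel

def decode_step():
    pass                   # placeholder for one forward/sample step over all active sequences

torch.cuda.synchronize()   # make sure prefill work has finished before starting the clock
start = time.perf_counter()

generated = 0
for _ in range(num_steps):
    decode_step()
    generated += active_sequences

torch.cuda.synchronize()   # flush queued GPU work before stopping the clock
elapsed = time.perf_counter() - start
print(f"{generated / elapsed:.1f} tokens/s")
```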

P.S. Thanks for a wonderful library!

turboderp commented 1 month ago

Yeah, batching isn't "free" the same way it is with FP16 weights. With FP16 weights you have 4x as much memory access to begin with, and you don't have to spend extra compute to dequantize weights. So quantized kernels, as a rule, are much closer to being compute bound, and as you double the batch size you also double the compute requirement while the weight reads stay roughly the same.
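
A rough back-of-envelope shows why (the numbers are purely illustrative, not measurements from the library):

```python
# Roofline-style sketch: arithmetic intensity of a [batch, k] x [k, n] projection
# with FP16 vs 4-bit weights. Illustrative only; it ignores dequantization cost,
# KV cache traffic and kernel efficiency.

def flops_per_byte(batch, k, n, bytes_per_weight):
    flops = 2 * batch * k * n                      # multiply-accumulates
    weight_bytes = k * n * bytes_per_weight        # weights are read once per step
    act_bytes = (batch * k + batch * n) * 2        # FP16 activations in/out
    return flops / (weight_bytes + act_bytes)

k = n = 4096                                       # roughly llama3-8B-sized projection
for batch in (1, 2, 4, 8, 16):
    print(f"batch {batch:2d}: "
          f"FP16 ~{flops_per_byte(batch, k, n, 2.0):5.1f} FLOP/byte, "
          f"4-bit ~{flops_per_byte(batch, k, n, 0.5):5.1f} FLOP/byte")

# Weight traffic barely changes with batch size, so intensity grows roughly
# linearly in batch, and the 4-bit case sits ~4x higher than FP16 at every
# batch size. Add the dequantization work on top and a quantized kernel hits
# the compute roof of the GPU at a much smaller batch than FP16 does.
```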

The Marlin kernel manages to combine enough low-level optimizations to stay memory-bound up to much higher batch sizes. They've reduced the computation per weight, jumped through a lot of hoops to hide the latency from that computation and also offloaded efficiently to tensor cores to make use of the extra compute they provide. Sadly their kernel is also very inflexible. The optimizations are CUDA specific (no support for ROCm) and Ampere+ only, and it would take a lot of work to adapt them to the permuted, mixed-bitrate, variable group size format of EXL2.

The multiple caches approach attempts to gain some of the benefits of batching while still doing attention on individual sequences. It's not a very efficient way to go about it, but it does have the benefit that you can add or remove sequences dynamically.
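
As a conceptual sketch of that pattern (placeholder types and calls, not the actual exllamav2 API): each sequence owns its own cache, attention runs per sequence, and sequences can join or leave between steps.

```python
from dataclasses import dataclass, field

@dataclass
class Sequence:
    ids: list                                     # token ids so far
    cache: dict = field(default_factory=dict)     # stand-in for a per-sequence KV cache
    finished: bool = False

EOS = 0

def decode_one(seq: Sequence) -> int:
    """Placeholder for one forward/sample step using only this sequence's cache."""
    return (len(seq.ids) * 7) % 11                # fake token id, just so the loop runs

active: list = []

def run_step(new_prompts):
    # Joining is cheap: a new sequence simply brings its own fresh cache.
    active.extend(Sequence(ids=list(p)) for p in new_prompts)
    for seq in active:
        tok = decode_one(seq)                     # attention never mixes caches
        seq.ids.append(tok)
        seq.finished = tok == EOS
    # Leaving is cheap too: dropping a finished sequence frees its cache at once.
    active[:] = [s for s in active if not s.finished]
```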

I'm working on a better solution for dynamic batching with paged attention which shouldn't have this overhead. There may also be improvements to batched performance but I have too many plates spinning to give you a timetable for those.