lopuhin opened this issue 6 months ago
Yeah, batching isn't "free" the same way it is with FP16 weights. With FP16 weights you have 4x as much memory access to begin with, and you don't have to spend extra compute to dequantize weights. So quantized kernels, as a rule, are much closer to being compute bound, and as you double the batch size you also double the compute requirement.
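As a rough back-of-envelope sketch of that point (illustrative numbers only; the dequantization cost per weight below is an assumed placeholder, real kernels differ): arithmetic intensity, i.e. FLOPs per byte of weight traffic, grows with batch size, and a 4-bit kernel starts out roughly 4x higher than FP16, so it reaches the compute roof at a much smaller batch.

```python
# Rough roofline-style estimate: FLOPs per byte of weight traffic for the big
# matmuls during decoding. Illustrative only; the dequantization cost per
# weight (4 FLOPs) is an assumed placeholder, not a measured number.

def arithmetic_intensity(batch_size, bytes_per_weight, dequant_flops_per_weight=0.0):
    # Each weight is read once per step and used in one multiply-add (2 FLOPs)
    # per sequence in the batch, plus any dequantization work.
    flops = 2.0 * batch_size + dequant_flops_per_weight
    return flops / bytes_per_weight

for bs in (1, 2, 4, 8, 16):
    fp16 = arithmetic_intensity(bs, bytes_per_weight=2.0)
    int4 = arithmetic_intensity(bs, bytes_per_weight=0.5, dequant_flops_per_weight=4.0)
    print(f"batch {bs:2d}: FP16 {fp16:5.1f} FLOP/byte, 4-bit {int4:5.1f} FLOP/byte")
```

Even at batch size 1 the 4-bit kernel has several times the arithmetic intensity of FP16, so doubling the batch pushes it toward the compute roof much sooner.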
The Marlin kernel manages to combine enough low-level optimizations to stay memory-bound up to much higher batch sizes. They've reduced the computation per weight, jumped through a lot of hoops to hide the latency of that computation, and offloaded work efficiently to tensor cores to make use of the extra compute they provide. Sadly, their kernel is also very inflexible. The optimizations are CUDA-specific (no ROCm support) and Ampere+ only, and it would take a lot of work to adapt them to the permuted, mixed-bitrate, variable-group-size format of EXL2.
The multiple-caches approach attempts to gain some of the benefits of batching while still doing attention on individual sequences. It's not a very efficient way to go about it, but it does have the benefit that you can add or remove sequences dynamically.
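To make the dynamic add/remove point concrete, here is a minimal conceptual sketch of that scheme, assuming hypothetical `model.new_cache`, `model.prefill` and `model.decode_one` helpers rather than exllamav2's actual API: every sequence owns its cache, so requests can join or leave between decode steps, but the model has to be invoked once per sequence, which is where the overhead comes from.

```python
# Conceptual sketch only, NOT exllamav2's real API: one KV cache per sequence,
# with sequences joining/leaving the pool between decode steps. All model.*
# calls below are hypothetical placeholders.
from dataclasses import dataclass
from typing import Dict


@dataclass
class Sequence:
    tokens: list   # prompt + generated tokens so far
    cache: object  # this sequence's own KV cache (hypothetical type)


active: Dict[int, Sequence] = {}


def add_request(model, seq_id: int, prompt_tokens: list) -> None:
    # A new request joins between steps without touching the other caches.
    cache = model.new_cache()                # hypothetical cache factory
    model.prefill(prompt_tokens, cache)      # prefill this cache independently
    active[seq_id] = Sequence(list(prompt_tokens), cache)


def decode_step(model) -> None:
    # One generation step across all active sequences. Attention runs against
    # each sequence's own cache, so no padding or cache reshuffling is needed
    # when the pool changes size -- but the model is invoked per sequence,
    # which is the overhead relative to true batched attention.
    for seq_id, seq in list(active.items()):
        token = model.decode_one(seq.tokens[-1], seq.cache)  # hypothetical call
        seq.tokens.append(token)
        if token == model.eos_token_id:      # hypothetical attribute
            del active[seq_id]               # finished sequence leaves the pool
```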
I'm working on a better solution for dynamic batching with paged attention, which shouldn't have this overhead. There may also be improvements to batched performance, but I have too many plates spinning to give you a timetable for those.
Hi, are there any updates? Thanks!
When scaling the batch size from 1 to a small number, say 8, I'd naively expect the generation performance to scale quite well, as we're still firmly in the memory-bound regime. But I observe two things which I didn't expect:
- scaling of batched inference is lower than what I'd naively expect, e.g. 2.38x when going from batch size 1 to 8 (edit: upon reading up more, e.g. here, I realize it's probably expected)
- scaling when using multiple caches, one per sequence, is worse still

I'm curious if anyone else is observing similar behavior, and what is preventing better scaling? In practice, to implement in-flight batching I have to use multiple caches, so their performance is more relevant to my use case.
Below are results I obtained:
Both scripts are here https://gist.github.com/lopuhin/c08fd724aa8ca71f37ecac219dc8a608
Sequences in multiple_caches.py are a bit shorter, but I also checked that scaling of batched_inference.py is the same with much shorter sequences.
All measurements were done with a Llama 3 EXL2 4-bit model on a 2080 Ti, CUDA 11.8, latest stable release of exllamav2. I also repeated a few measurements on a T4 GPU and got similar scaling.
P.S. Thanks for a wonderful library!