turboderp / exllama

A more memory-efficient rewrite of the HF transformers implementation of Llama for use with quantized weights.
MIT License

Latency grows substantially as batch size increases, even with small batch sizes #202

Open joehoover opened 1 year ago

joehoover commented 1 year ago

Thanks for the wonderful repo, @turboderp!

I'm benchmarking latency on an A100, and I've observed latency increasing substantially as I increase batch size, to a much larger degree than I'm used to (logs included below).

I'd love to know if I'm missing something or if this is expected!

Setup

I'm benchmarking with TheBloke's gptq_model-4bit-128g llama-2-13B-chat-GPTQ checkpoint.

I'm using test_benchmark_generation.py with some minimal modifications to run these benchmarks.

I'm instantiating the cache with the batch size, and I'm warming up with a batch of ids.
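Roughly, that setup looks like the following (a sketch of my lightly modified script; the exllama argument names I use here, e.g. batch_size on ExLlamaCache and preprocess_only on forward, reflect my reading of the code and may not be exact):

import torch
from model import ExLlama, ExLlamaCache, ExLlamaConfig

batch_size = 4
model_dir = "/path/to/Llama-2-13B-chat-GPTQ"  # TheBloke's 4-bit-128g checkpoint

config = ExLlamaConfig(f"{model_dir}/config.json")
config.model_path = f"{model_dir}/gptq_model-4bit-128g.safetensors"
model = ExLlama(config)

# The cache is sized for the whole batch rather than a single sequence
cache = ExLlamaCache(model, batch_size=batch_size)

# Warm up with a batch of ids so the first timed run isn't paying one-off setup costs,
# then reset the cache before benchmarking
warmup_ids = torch.randint(0, 31999, (batch_size, 4)).cuda()
model.forward(warmup_ids, cache, preprocess_only=True)
cache.current_seq_len = 0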

I'm generating tokens like:

# Random prompt ids, repeated across the batch so every sequence has the same length
ids = torch.randint(0, 31999, (1, max_seq_len - gen_tokens)).repeat(batch_size, 1).cuda()

...

for i in range(gen_tokens):
    # Greedy-pick the next token for every sequence in the batch
    logits = logits[:, -1, :]
    id_per_batch = torch.argmax(logits, dim=-1)
    assert id_per_batch.shape == (batch_size,), f"{id_per_batch.shape} != {(batch_size,)}"
    next_id_per_batch = id_per_batch.unsqueeze(-1)
    sequence = torch.cat((sequence, next_id_per_batch), dim=-1)
    # Feed just the new ids back through the model to get the next step's logits
    logits = next_logits(next_id_per_batch, lora)
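Here next_logits is essentially a thin wrapper around the model's forward pass against the shared cache, something like this (again paraphrasing the benchmark script, so the exact signature may differ):

def next_logits(input_ids, lora=None):
    # One decode step: the cache already holds all previous positions, so only the new
    # token ids are passed in; returns logits for the last position of each sequence.
    return model.forward(input_ids, cache, last_id_only=True, lora=lora)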

bs=1

 -- Generating 128 tokens, 1920 token prompt...
 ** Batch size: 1
         ** Latency: 46.10 tokens/second
         ** Throughput: 46.10 tokens/second
         ** Total time: 2.776369094848633
 -- Generating 128 tokens, 4 token prompt...
 ** Batch size: 1
         ** Latency: 54.13 tokens/second
         ** Throughput: 54.13 tokens/second
         ** Total time: 2.364816427230835

bs=2

-- Generating 128 tokens, 1920 token prompt...
 ** Batch size: 2
         ** Latency: 33.04 tokens/second
         ** Throughput: 66.08 tokens/second
         ** Total time: 3.87422776222229
 -- Generating 128 tokens, 4 token prompt...
 ** Batch size: 2
         ** Latency: 41.16 tokens/second
         ** Throughput: 82.31 tokens/second
         ** Total time: 3.1100616455078125

bs=4

 -- Generating 128 tokens, 1920 token prompt...
 ** Batch size: 4
         ** Latency: 23.13 tokens/second
         ** Throughput: 92.50 tokens/second
         ** Total time: 5.535064220428467
 -- Generating 128 tokens, 4 token prompt...
 ** Batch size: 4
         ** Latency: 28.07 tokens/second
         ** Throughput: 112.26 tokens/second
         ** Total time: 4.560673952102661
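Recomputing from the 4-token-prompt runs above, aggregate throughput scales well below linearly with batch size:

per_seq = {1: 54.13, 2: 41.16, 4: 28.07}  # batch size -> per-sequence tokens/second
base = per_seq[1]
for bs, speed in per_seq.items():
    throughput = bs * speed
    print(f"bs={bs}: {throughput:6.2f} tok/s aggregate, {throughput / (bs * base):.0%} of linear scaling")

# bs=1:  54.13 tok/s aggregate, 100% of linear scaling
# bs=2:  82.32 tok/s aggregate, 76% of linear scaling
# bs=4: 112.28 tok/s aggregate, 52% of linear scaling

So quadrupling the batch only roughly doubles aggregate throughput here.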
turboderp commented 1 year ago

The kernels are very specifically optimized for matrix-vector operations (batch size = 1). ExLlama also does well on matrix-matrix operations by reconstructing full-precision matrices on the fly and relying on cuBLAS. The in-between territory is problematic, but I guess the question is what sort of throughput you would expect?
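Conceptually the quantized matmul picks between two paths depending on how many input rows there are, something like the plain-PyTorch illustration below (just to show the idea, not the actual CUDA kernels; the names, block size and cutoff are made up):

import torch

def quant_matvec(x_vec, qweight_int8, scales, block=128):
    # Small-batch path: never materialize the full fp16 weight matrix. Dequantize one
    # block of output columns at a time, roughly what a fused quantized matvec kernel
    # does; memory-bound and tuned for a single input row.
    out = torch.empty(qweight_int8.shape[1], dtype=x_vec.dtype, device=x_vec.device)
    for c in range(0, qweight_int8.shape[1], block):
        w_block = qweight_int8[:, c:c + block].to(x_vec.dtype) * scales[c:c + block]
        out[c:c + block] = x_vec @ w_block
    return out

def quant_matmul(x, qweight_int8, scales):
    # x: (rows, in_features), qweight_int8: (in_features, out_features), scales: (out_features,)
    if x.shape[0] == 1:
        return quant_matvec(x[0], qweight_int8, scales).unsqueeze(0)
    # Large-batch path: reconstruct the full fp16 matrix on the fly and hand the GEMM to cuBLAS.
    return x @ (qweight_int8.to(x.dtype) * scales)

Batch sizes in between are the awkward middle ground: too many rows for the matvec-style path to stay efficient, too few for the reconstruct-and-GEMM path to amortize the dequantization well.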

fxmarty commented 1 year ago

The kernels from the NVIDIA folks at https://github.com/tlc-pack/cutlass_fpA_intB_gemm are probably interesting for the batched scenario.