bigscience-workshop / petals

🌸 Run LLMs at home, BitTorrent-style. Fine-tuning and inference up to 10x faster than offloading
https://petals.dev
MIT License

Optimize the Falcon block for inference #500

Closed by mryab 10 months ago

mryab commented 10 months ago

This PR attempts to optimize the inference of Falcon models in the single-token setup by removing most of the Python overhead and making several assumptions about the setup. Specifically:

  1. Layer normalization, QKV projection (with splitting), and rotary embeddings are executed through CUDA graphs, which removes most of the overhead from launching many small kernels (see the first sketch after this list).
  2. If no sin/cos tensors are cached by the rotary embedding layer, we cache them for 8192 tokens (INFERENCE_MAX_LENGTH) during the first forward pass (see the second sketch after this list). In general, it should be beneficial to always run a max-length sequence before serving a block, but that is a question for another PR.
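
For reference, here is a minimal sketch of the CUDA-graph capture pattern behind point 1, following PyTorch's documented `torch.cuda.graph` API. The `block` argument is a hypothetical stand-in for the fused layer-norm + QKV + rotary path, not the exact Petals code:

```python
import torch

def make_graphed_forward(block, example_input: torch.Tensor):
    # Warm up on a side stream so lazy kernel compilation and allocator
    # behavior settle before capture (required by the CUDA graphs API)
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        for _ in range(3):
            block(example_input)
    torch.cuda.current_stream().wait_stream(side_stream)

    # Capture one forward pass into a CUDA graph with static buffers
    graph = torch.cuda.CUDAGraph()
    static_input = example_input.clone()
    with torch.cuda.graph(graph):
        static_output = block(static_input)

    def graphed_forward(new_input: torch.Tensor) -> torch.Tensor:
        # Replaying the graph launches all captured kernels at once,
        # skipping the per-kernel Python and launch overhead
        static_input.copy_(new_input)
        graph.replay()
        return static_output

    return graphed_forward
```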
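
Point 2 amounts to building the sin/cos tables once, at maximum length, instead of growing them lazily. A hedged sketch of that caching logic (`RotaryCache` and its attribute names are illustrative, not the real Petals classes):

```python
import torch

INFERENCE_MAX_LENGTH = 8192  # cache length mentioned above

class RotaryCache:
    def __init__(self, head_dim: int, base: float = 10000.0):
        self.inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.cos_cached = None
        self.sin_cached = None

    def get(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        # Build the full-length cache once; later single-token steps
        # only slice it, never reallocate it
        if self.cos_cached is None:
            t = torch.arange(INFERENCE_MAX_LENGTH, device=device, dtype=torch.float32)
            freqs = torch.outer(t, self.inv_freq.to(device))
            emb = torch.cat((freqs, freqs), dim=-1)
            self.cos_cached = emb.cos().to(dtype)
            self.sin_cached = emb.sin().to(dtype)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
```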

The PR also adds a small test to ensure that the outputs of the block (without quantization) before and after the optimization indeed match.
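
For illustration, such a consistency test can be as simple as the following sketch (function and argument names are hypothetical; the actual test is in this PR's diff):

```python
import torch

@torch.inference_mode()
def test_block_outputs_match(reference_block, optimized_block, hidden_size: int):
    torch.manual_seed(0)
    x = torch.randn(1, 1, hidden_size, dtype=torch.bfloat16, device="cuda")
    ref_out = reference_block(x)
    opt_out = optimized_block(x)
    # CUDA graphs may reorder bfloat16 kernels slightly, so compare with
    # a small tolerance rather than bitwise equality
    assert torch.allclose(ref_out, opt_out, atol=1e-2, rtol=0)
```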

borzunov commented 10 months ago

Benchmarks: this PR gives +40% to inference speed

Model: Falcon-40B
GPU: NVIDIA RTX 6000 Ada Generation

main @ d40eb6c, 3 runs

Sep 04 11:52:03.072 [INFO] Inference throughput: 756.6 tokens/sec per block (1 tokens/batch, NVIDIA RTX 6000 Ada Generation GPU, bfloat16, quantized to nf4)
Sep 04 11:52:13.657 [INFO] Forward pass throughput: 61338.3 tokens/sec per block (1024 tokens/batch, NVIDIA RTX 6000 Ada Generation GPU, bfloat16, quantized to nf4)

Sep 04 11:53:08.637 [INFO] Inference throughput: 776.9 tokens/sec per block (1 tokens/batch, NVIDIA RTX 6000 Ada Generation GPU, bfloat16, quantized to nf4)
Sep 04 11:53:19.217 [INFO] Forward pass throughput: 61292.5 tokens/sec per block (1024 tokens/batch, NVIDIA RTX 6000 Ada Generation GPU, bfloat16, quantized to nf4)

Sep 04 11:54:06.825 [INFO] Inference throughput: 759.0 tokens/sec per block (1 tokens/batch, NVIDIA RTX 6000 Ada Generation GPU, bfloat16, quantized to nf4)
Sep 04 11:54:17.416 [INFO] Forward pass throughput: 61322.3 tokens/sec per block (1024 tokens/batch, NVIDIA RTX 6000 Ada Generation GPU, bfloat16, quantized to nf4)

optim_falcon @ 52baffb, 3 runs

Sep 04 11:48:32.613 [INFO] Inference throughput: 1044.6 tokens/sec per block (1 tokens/batch, NVIDIA RTX 6000 Ada Generation GPU, bfloat16, quantized to nf4)
Sep 04 11:48:43.189 [INFO] Forward pass throughput: 62396.0 tokens/sec per block (1024 tokens/batch, NVIDIA RTX 6000 Ada Generation GPU, bfloat16, quantized to nf4)

Sep 04 11:49:31.860 [INFO] Inference throughput: 1075.5 tokens/sec per block (1 tokens/batch, NVIDIA RTX 6000 Ada Generation GPU, bfloat16, quantized to nf4)
Sep 04 11:49:42.453 [INFO] Forward pass throughput: 61365.4 tokens/sec per block (1024 tokens/batch, NVIDIA RTX 6000 Ada Generation GPU, bfloat16, quantized to nf4)

Sep 04 11:50:28.453 [INFO] Inference throughput: 1068.0 tokens/sec per block (1 tokens/batch, NVIDIA RTX 6000 Ada Generation GPU, bfloat16, quantized to nf4)
Sep 04 11:50:39.046 [INFO] Forward pass throughput: 61758.3 tokens/sec per block (1024 tokens/batch, NVIDIA RTX 6000 Ada Generation GPU, bfloat16, quantized to nf4)
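
Averaging the three runs above: main yields (756.6 + 776.9 + 759.0) / 3 ≈ 764.2 tokens/sec, while optim_falcon yields (1044.6 + 1075.5 + 1068.0) / 3 ≈ 1062.7 tokens/sec, i.e. a 1062.7 / 764.2 ≈ 1.39x speedup in single-token inference, consistent with the +40% figure above. Forward-pass throughput at 1024 tokens/batch is essentially unchanged, as expected: the optimization targets per-step launch overhead, which is negligible for large batches.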