bigscience-workshop / petals

🌸 Run LLMs at home, BitTorrent-style. Fine-tuning and inference up to 10x faster than offloading
https://petals.dev
MIT License

Optimize LLaMA for inference #513

Closed mryab closed 10 months ago

mryab commented 1 year ago

Similarly to https://github.com/bigscience-workshop/petals/pull/500, this PR aims to speed up Llama models by making the following optimizations compared to the original Transformers implementation:

Additionally, this PR introduces a petals.utils.cuda_graphs.make_inference_graphed_callable function that converts any inference-mode callable into its CUDA graph version. It is meant as an alternative to torch.cuda.make_graphed_callables that does not attempt to build a graph for the backward pass: inference runs under inference_mode, so the original function fails (which is why the Falcon PR used custom graph tracing as well).
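
For reference, a minimal sketch of the capture-and-replay pattern such a helper relies on (not the exact implementation from this PR; the function name, warmup count, and output handling are illustrative assumptions):

```python
import torch


def graph_inference_callable(callable_fn, sample_args):
    """Capture one inference-mode call into a CUDA graph and return a replay wrapper.

    Illustrative sketch only: warm up on a side stream, capture the forward pass,
    then replay the recorded kernels with fresh inputs copied into the static tensors.
    """
    # Warmup on a side stream so one-time initialization (cuBLAS handles, autotuning)
    # is not recorded into the graph.
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream), torch.inference_mode():
        for _ in range(3):
            callable_fn(*sample_args)
    torch.cuda.current_stream().wait_stream(side_stream)

    # Capture: record the kernels launched by a single forward call.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph), torch.inference_mode():
        static_outputs = callable_fn(*sample_args)
    if isinstance(static_outputs, torch.Tensor):
        static_outputs = (static_outputs,)

    def replay(*args):
        # Copy new inputs into the captured (static) input tensors, then replay the graph.
        for static_arg, arg in zip(sample_args, args):
            static_arg.copy_(arg)
        graph.replay()
        return tuple(out.clone() for out in static_outputs)

    return replay
```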

borzunov commented 1 year ago

Benchmarks: this PR improves inference speed by ~44%

Model: Stable Beluga 2 (70B)
GPU: NVIDIA RTX 6000 Ada Generation

main @ a2484b3:

Sep 20 03:08:50.845 [INFO] Inference throughput: 750.6 tokens/sec per block (1 tokens/batch, NVIDIA RTX 6000 Ada Generation GPU, bfloat16, quantized to nf4)                      
Sep 20 03:09:04.064 [INFO] Forward pass throughput: 48486.8 tokens/sec per block (1024 tokens/batch, NVIDIA RTX 6000 Ada Generation GPU, bfloat16, quantized to nf4)

optimize_llama @ f332b0e7a2ef2635b683d6da18aa59184314c460:

Sep 20 03:10:13.415 [INFO] Inference throughput: 1078.9 tokens/sec per block (1 tokens/batch, NVIDIA RTX 6000 Ada Generation GPU, bfloat16, quantized to nf4)
Sep 20 03:10:26.583 [INFO] Forward pass throughput: 48003.5 tokens/sec per block (1024 tokens/batch, NVIDIA RTX 6000 Ada Generation GPU, bfloat16, quantized to nf4)

poedator commented 10 months ago

When testing with TinyLlama for some unrelated change, I caught the error "Caught too many indices for tensor of dimension 2". It happened on this line: cos = cos[:, :, kv_seq_len - q_len :] (https://github.com/bigscience-workshop/petals/pull/513/files#diff-492af4f870c9613ff6b5fce973ddd1d75bf135b30f40a7cb83f455c4f0e72ea6R87). Env: Transformers 4.35.2. For reference, see this test run: https://github.com/bigscience-workshop/petals/actions/runs/6950529337/job/18910867509?pr=545 (line 2755). @mryab ?
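
For context, the error itself indicates that the rotary cos/sin tensors are only 2-D under Transformers 4.35.2, while the code indexes them as if they had the older 4-D layout (presumably [1, 1, seq_len, head_dim] in the versions the PR was written against). A rank-agnostic slice along the sequence axis would sidestep this; the helper below is a hypothetical illustration, not the fix adopted in the PR:

```python
import torch


def slice_rotary_cache(cos: torch.Tensor, sin: torch.Tensor, kv_seq_len: int, q_len: int):
    """Hypothetical helper: keep only the positions of the current query chunk,
    regardless of how many dimensions the rotary cache has in the installed
    transformers version."""
    start = kv_seq_len - q_len
    if cos.dim() >= 3:
        # Older layout, e.g. [1, 1, seq_len, head_dim]: the sequence axis is dim -2.
        return cos[..., start:, :], sin[..., start:, :]
    # Newer 2-D layout [seq_len, head_dim]: cos[:, :, start:] raises
    # "too many indices for tensor of dimension 2"; slice the first axis instead.
    return cos[start:], sin[start:]
```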