This PR improves inference throughput by ~7% (measured on a single A100-80GB) by merging the query/key/value projections into a single large matrix multiplication. This reduces the overhead of launching several matmul kernels, which turns out to be substantial for single-sequence, single-token inference steps. The PR also adds a --throughput dry_run option to estimate throughput without starting a server.
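The merging idea can be illustrated with a minimal NumPy sketch. This is not the PR's actual code: the function names, the toy dimensions, and the use of NumPy instead of the model's real layers are all assumptions made for illustration. It only shows that concatenating the three projection weights along the output dimension lets one matmul replace three, with the result split back into query, key, and value.

```python
import numpy as np

def separate_qkv(x, w_q, w_k, w_v):
    # Baseline: three independent matmuls (three kernel launches on a GPU).
    return x @ w_q, x @ w_k, x @ w_v

def merge_qkv_weights(w_q, w_k, w_v):
    # Hypothetical helper: stack the projection weights along the output
    # dimension once, ahead of time, so inference needs a single matmul.
    return np.concatenate([w_q, w_k, w_v], axis=1)

def merged_qkv(x, w_qkv, q_size, kv_size):
    # One matmul, then split the output back into query, key, and value.
    qkv = x @ w_qkv
    q, k, v = np.split(qkv, [q_size, q_size + kv_size], axis=-1)
    return q, k, v

# Toy sizes (assumed); key/value may be narrower than query, as with
# grouped-query attention.
hidden, q_size, kv_size = 64, 64, 16
rng = np.random.default_rng(0)
x = rng.standard_normal((1, 1, hidden))  # single sequence, single token
w_q = rng.standard_normal((hidden, q_size))
w_k = rng.standard_normal((hidden, kv_size))
w_v = rng.standard_normal((hidden, kv_size))

w_qkv = merge_qkv_weights(w_q, w_k, w_v)
expected = separate_qkv(x, w_q, w_k, w_v)
actual = merged_qkv(x, w_qkv, q_size, kv_size)
outputs_match = all(np.allclose(a, e) for a, e in zip(actual, expected))
```

The two paths compute the same values up to floating-point rounding, so any speedup comes from issuing fewer, larger kernels rather than from doing less arithmetic.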
Sample results from running experiments with and without the optimization (the command in each case is CUDA_VISIBLE_DEVICES=0 python -m petals.cli.run_server petals-team/StableBeluga2 --throughput dry_run):
Current code (branch https://github.com/bigscience-workshop/petals/tree/no_qkv_merge):
Code from this PR: