IBM / text-generation-inference

IBM development fork of https://github.com/huggingface/text-generation-inference
Apache License 2.0

Fix logic for determining the number of cache blocks #98

Closed tdoublep closed 5 months ago

tdoublep commented 5 months ago

Motivation

When we deploy speculative decoding in production, we frequently see the servers running out of free blocks. We have determined that this is due to two issues:

  1. The constraint on SPECULATOR_MAX_BATCH_SIZE is not enough to avoid running into memory pressure due to speculation - we need to be able to ensure that we do not speculate on batches that may have a small "size" but a very large weight.
  2. The computation of the number of blocks is very wrong in most cases.

Modifications

  1. I have introduced an additional constraint: we only speculate on batches whose weight is at most 75% of the weight limit. This should ensure that we never speculate when we are close to the memory limits (see the first sketch after this list).
  2. I have written new code to calculate the number of KV cache blocks. This calculation uses the memory scaling coefficients that we learn at startup; in particular, it uses the learned coefficients to figure out what percentage of the memory capacity needs to be set aside for cache blocks (see the second sketch after this list).
  3. In the above calculation, I use the next-token coefficient rather than the prefill coefficient, since during the next-token phase the KV cache blocks typically comprise a relatively large percentage of the total memory consumption, and we need to be able to handle this worst case. However, this means that during prefill steps we may not have enough memory left over to store the auxiliary data structures needed for a forward pass. There isn't really a clean way to handle this short of rewriting the router logic to be block-aware, but what we can do is recommend that the user increase the batch safety margin to a level that ensures prefills will not run OOM. I've added a print statement to provide this guidance.
  4. I now load the speculator before learning the memory scaling model, since the speculator also needs to be taken into account when measuring the amount of free memory.
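
For illustration, here is a minimal sketch of the gating check described in item 1. SPECULATOR_MAX_BATCH_SIZE is the existing configuration limit mentioned above; its value here, the 75% constant's wiring, and the names `batch_weight`, `weight_limit`, and `SPECULATOR_WEIGHT_FRACTION` are assumptions for the sketch, not the actual identifiers in the server code.

```python
# Minimal sketch of the speculation gate (item 1).  Values and names other
# than SPECULATOR_MAX_BATCH_SIZE are illustrative placeholders.

SPECULATOR_MAX_BATCH_SIZE = 16       # placeholder value for the sketch
SPECULATOR_WEIGHT_FRACTION = 0.75    # new constraint: at most 75% of the weight limit


def should_speculate(batch_size: int, batch_weight: float, weight_limit: float) -> bool:
    """Speculate only if the batch is both small enough and light enough."""
    if batch_size > SPECULATOR_MAX_BATCH_SIZE:
        return False
    # Even a small batch can carry a large weight (e.g. a few very long
    # sequences), so also require headroom relative to the weight limit.
    return batch_weight <= SPECULATOR_WEIGHT_FRACTION * weight_limit
```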
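
And here is a rough sketch of the block-count calculation in items 2 and 3, under the assumption that the learned next-token coefficient can be read as a per-token memory cost in bytes. The function and argument names are hypothetical; the actual implementation derives these quantities from the scaling model learned at startup.

```python
# Rough sketch of the cache-block sizing idea (items 2-3).  All names and the
# exact form of the learned coefficient are assumptions for this example.

def estimate_num_cache_blocks(
    free_memory_bytes: int,     # free GPU memory after loading model + speculator
    next_token_coeff: float,    # learned per-token memory cost of a next-token step (bytes)
    kv_bytes_per_token: int,    # bytes of KV cache required per token
    block_size: int,            # tokens stored per KV cache block
) -> int:
    # Fraction of the per-token next-token cost that is KV cache; the rest is
    # reserved for activations and other forward-pass buffers.  Using the
    # next-token coefficient (rather than prefill) sizes the cache for the
    # phase where KV blocks dominate memory use.
    kv_fraction = min(1.0, kv_bytes_per_token / next_token_coeff)
    cache_budget_bytes = int(free_memory_bytes * kv_fraction)
    return cache_budget_bytes // (kv_bytes_per_token * block_size)
```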

Result

These changes, together with setting BATCH_SAFETY_MARGIN=35, seem to result in robust behaviour for both llama3-8b and granite-20b. We no longer need to manually set the number of KV cache blocks for the latter model.

Related Issues

n/a