Open Conless opened 4 months ago
A trivial way to solve this is to add a limitation `logprob_tokens < max_seqs` in the following function in `scheduler.py`:
```python
def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
    assert num_new_tokens != 0
    assert num_new_seqs != 0
    return (self.num_batched_tokens + num_new_tokens <= self.token_budget
            and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)
```
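For concreteness, here is a sketch of what that check could look like (the `num_new_logprob_tokens` argument is hypothetical, purely to illustrate the idea):

```python
def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int,
                 num_new_logprob_tokens: int = 0):
    assert num_new_tokens != 0
    assert num_new_seqs != 0
    # Extra constraint: cap the number of tokens requesting logprobs so
    # the sampler's per-token logprob tensors stay within the memory
    # that profiling accounted for.
    return (self.num_batched_tokens + num_new_tokens <= self.token_budget
            and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs
            and num_new_logprob_tokens < self.max_num_seqs)
```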
If this solution is acceptable, I can submit a pull request later.
Agree this is an issue that needs to be fixed.
I don't quite see how `log_prob_tokens < max_seqs` is the right solution though... isn't this a bit too coarse-grained?
@robertgshaw2-neuralmagic I agree with you. Another solution I came up with, which is more fine-grained, is to add a new argument `max_num_logprobs` to `EngineArgs` (defaulting to the value of `max_num_seqs`). However, I'm concerned that this argument might rarely be used. What do you think about it?
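Roughly what I have in mind (a sketch only; the real `EngineArgs` is a much larger dataclass with its own validation):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class EngineArgs:
    max_num_seqs: int = 256
    # New: cap on the number of tokens whose logprobs are computed in a
    # single batch. None means "fall back to max_num_seqs", which keeps
    # the current behavior as the default.
    max_num_logprobs: Optional[int] = None

    def __post_init__(self):
        if self.max_num_logprobs is None:
            self.max_num_logprobs = self.max_num_seqs
```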
I think we should have some user-controlled `max_num_logprobs` with a sensible default. Let me ask the rest of the group.
Then we will need to:

- update profiling logic to take this into account
- update scheduler logic to take this into account

@robertgshaw2-neuralmagic Thank you for considering this.
This modification should not be too difficult. Given the concise structure of the current code, it seems feasible to implement by adding the logic to the prompt generation in `profile_run` and to `can_schedule` in the scheduler.
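For the profiling side, roughly what I have in mind (a sketch only; the real `profile_run` builds its dummy requests differently):

```python
from vllm import SamplingParams

# Sketch: make the dummy profiling requests ask for prompt logprobs, so
# the warm-up forward pass allocates the same per-prompt-token tensors
# that a worst-case real request would trigger.
dummy_sampling_params = SamplingParams(
    top_p=0.99,
    prompt_logprobs=1,  # any non-None value enables prompt logprobs
)
```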
Just ran into this problem :(
@Conless would you be willing to tackle the support for this? I agree it would be a nice improvement - I have run into this when performing LLM evaluations on MMLU, which require a lot of logprobs.
@mgoin No problem, I'd be delighted to tackle it.
Had the same problem :( Are there any quick fixes to this?
🐛 Describe the bug
I encountered an unexpected `CUDA out of memory` error while adding a new feature for LoRA to vLLM. After experimenting with different settings, I discovered that the bug only appears when `prompt_logprobs` in `SamplingParams` is set to a non-zero value and a long prompt is used, as mentioned in #1532. I then tried to locate the bug and found the following (some unimportant tracebacks are omitted):

The bug is caused by calculations in `_get_bin_counts_and_mask` during the sampling phase. When `prompt_logprobs` is enabled, the log probabilities of all tokens in the prompt (up to 8192 for Llama 3, which I am using) are computed, leading to a memory usage of up to

$$\text{num\_tokens} \times \text{vocab\_size} \times 8\ \text{Bytes} = 8192 \times 128256 \times 8\ \text{Bytes} \approx 7.8\ \text{GiB}$$
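A quick back-of-the-envelope check (assuming 8-byte `int64` elements, as for the `torch.long` bin-count tensors, which is consistent with the 7.8 GiB I observed):

```python
# Back-of-the-envelope check of the tensor size computed above.
num_tokens = 8192      # Llama 3 context length used in my runs
vocab_size = 128256    # Llama 3 vocabulary size
bytes_per_elem = 8     # int64 (torch.long) elements

mem_gib = num_tokens * vocab_size * bytes_per_elem / 2**30
print(f"{mem_gib:.1f} GiB")  # prints 7.8 GiB
```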
However, this memory usage is not predicted in `profile_run()`, where the sampling parameters only account for the calculation of log probabilities for up to 256 tokens (the maximum batched sequence count).
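For reference, a minimal trigger looks roughly like this (the model name and prompt length are just from my setup, and this is a sketch rather than a tested script):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B")

# A prompt close to the 8192-token context limit.
long_prompt = "word " * 8000

# Requesting prompt logprobs makes the sampler materialize
# per-prompt-token tensors over the full vocabulary, memory that
# profile_run() never accounted for.
params = SamplingParams(prompt_logprobs=1, max_tokens=1)
outputs = llm.generate([long_prompt], sampling_params=params)
```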