cadedaniel opened 2 months ago
Hi @cadedaniel, I was wondering if anyone is working on this currently. If not, I would like to look into it. Please let me know.
That would be awesome. Let's chat more in the vLLM Slack.
Discussed the best approach to fix this with @sroy745 and @LiuXiaoxuanPKU:
Goal: decide whether to use batch expansion or chunked prefill.
Decision criterion: which one takes more time when there are bonus tokens.
How fast can we do batch expansion?
- We can hit 0.5 ms or less, because the batch expansion is small (in the worst case, only one extra entry per sequence).
vs. how fast can we do a chunked prefill forward pass?
- Expected to be slower because it runs without CUDA graphs.
- Expected to be slower because attention is slower in Triton (Lily says it uses the flash-attention kernel, so this is not a concern).
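To make the two options concrete, here is a toy sketch of how the draft model's inputs could be laid out for a batch in which some sequences received a bonus token. It is a hypothetical illustration on plain token lists, not vLLM's actual input preparation; `Row`, `batch_expansion_rows`, and `chunked_prefill_rows` are invented names.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Row:
    """One entry in the draft model's forward batch (hypothetical)."""
    seq_id: int
    token_ids: List[int]  # tokens whose KV this row computes


def batch_expansion_rows(last_tokens: Dict[int, List[int]],
                         had_bonus: Dict[int, bool]) -> List[Row]:
    """Every row is a 1-token decode row; a sequence with a bonus token
    contributes one extra row for its penultimate token."""
    rows: List[Row] = []
    for seq_id, toks in last_tokens.items():
        if had_bonus[seq_id]:
            rows.append(Row(seq_id, [toks[-2]]))  # backfill the penultimate token
        rows.append(Row(seq_id, [toks[-1]]))
    return rows


def chunked_prefill_rows(last_tokens: Dict[int, List[int]],
                         had_bonus: Dict[int, bool]) -> List[Row]:
    """A sequence with a bonus token becomes one 2-token prefill chunk;
    all other sequences stay 1-token decode rows."""
    return [Row(seq_id, toks[-2:] if had_bonus[seq_id] else toks[-1:])
            for seq_id, toks in last_tokens.items()]
```

With three sequences of which one had a bonus token, batch expansion yields four 1-token rows (at most one extra row per sequence), so the batch keeps a decode shape that CUDA graphs can serve; chunked prefill keeps three rows, one of which carries two tokens.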
How to measure "prefill" time in chunked prefill?
- Measure the forward pass time for small batch sizes with a varying number of bonus tokens (1..BS), using the chunked prefill kernel.
  - This is the best measurement because it is exactly what we will run,
  - but it takes time to set up properly.
  - This measures prefill computation + lack of CUDA graphs.
- Measure the forward pass time of a decode batch without CUDA graphs (no chunked prefill), for small batch sizes.
  - This ignores any extra overhead in the chunked prefill forward pass,
  - BUT it captures 80+% of the overhead, which we intuit is due to the lack of CUDA graphs.
  - This measures the lack of CUDA graphs.
Worker.execute_model
  - model_runner.execute_model
    # - prepare_inputs
    - fwd pass
    # - sampling
    - return
Setup: 68M model with TP=1. Measure the forward pass time with and without CUDA graphs.
JackFram/llama-68m
https://pytorch.org/docs/stable/generated/torch.cuda.Event.html
torch.cuda.Event(enable_timing=False, blocking=False, interprocess=False)
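import torch

def execute_model(self, inputs):  # model runner method; unrelated setup elided
    ...
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    # record() with no argument uses the current stream; it must be the stream the model runs on
    start_event.record()
    outputs = self.model.fwd_pass(inputs)  # placeholder name for the forward call
    end_event.record()  # same stream as the model
    end_event.synchronize()  # wait until the forward pass has finished on the GPU
    elapsed_ms = start_event.elapsed_time(end_event)  # forward pass time in ms; log or accumulate it
    sampled = self.sampler.sample(outputs)  # sampling is excluded from the timing
    return sampled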
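A single pair of events measures one forward pass, and the first passes often include one-time costs. A minimal sketch of a repeated measurement, assuming the forward pass can be wrapped in a zero-argument callable; `median_latency_ms` is a hypothetical helper. It uses host-side `time.perf_counter` plus a caller-supplied synchronization function instead of CUDA events, so it also counts host-side launch overhead, which is part of what CUDA graphs remove.

```python
import statistics
import time
from typing import Callable, Optional


def median_latency_ms(fn: Callable[[], object],
                      warmup: int = 3,
                      iters: int = 20,
                      sync: Optional[Callable[[], None]] = None) -> float:
    """Median wall-clock latency of fn() in milliseconds.

    `sync` should block until queued GPU work has finished (for example
    torch.cuda.synchronize); without it, only the asynchronous kernel
    launches would be timed, not the kernels themselves.
    """
    sync = sync or (lambda: None)
    for _ in range(warmup):  # exclude one-time costs (allocation, lazy init)
        fn()
    sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        sync()
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples)
```

The median is used rather than the mean so that an occasional slow iteration (for example a background process or a clock change) does not skew a sub-millisecond measurement.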
Hi @cadedaniel and @LiuXiaoxuanPKU
Here is a PR that I used for the measurements.
I ran the tests with JackFram/llama-68m with TP=1 on an A100. Without CUDA graphs, the decode time is ~0.89 ms and ~0.87 ms at batch sizes 5 and 10, respectively. This is greater than the ~0.5 ms expected for batch expansion. Given these numbers, should we go with batch expansion?
| Batch size | CUDA graph | Prefill time | Decode time |
|---|---|---|---|
| 5 | No | 1.16 ms | 0.89 ms |
| 5 | Yes | 1.04 ms | 0.23 ms |
| 10 | No | 1.1 ms | 0.87 ms |
| 10 | Yes | 1.0 ms | 0.22 ms |
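The decision rule from the discussion can be written out as a small calculation on the decode numbers above: treat a decode pass without CUDA graphs as the proxy for a chunked prefill pass and compare it with the ~0.5 ms batch expansion estimate. This is only arithmetic on the reported measurements; the dictionary layout and names are illustrative.

```python
# Decode forward-pass times in ms, copied from the results above
# (JackFram/llama-68m, TP=1, A100).
DECODE_MS = {
    5: {"no_cuda_graph": 0.89, "cuda_graph": 0.23},
    10: {"no_cuda_graph": 0.87, "cuda_graph": 0.22},
}
BATCH_EXPANSION_ESTIMATE_MS = 0.5  # expected batch expansion cost from the discussion


def compare(decode_ms, budget_ms):
    """For each batch size, return (CUDA graph savings in ms, cheaper option)."""
    out = {}
    for bs, t in decode_ms.items():
        savings = t["no_cuda_graph"] - t["cuda_graph"]
        # A decode pass without CUDA graphs stands in for a chunked prefill pass.
        cheaper = ("batch expansion" if budget_ms < t["no_cuda_graph"]
                   else "chunked prefill")
        out[bs] = (savings, cheaper)
    return out


for bs, (savings, cheaper) in compare(DECODE_MS, BATCH_EXPANSION_ESTIMATE_MS).items():
    print(f"batch size {bs}: CUDA graphs save {savings:.2f} ms; cheaper: {cheaper}")
```

Both batch sizes favor batch expansion; the difference between the two decode columns (about 0.65 ms) is the time lost by running without CUDA graphs.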
Sounds good to me!
This PR implements the logic for enabling bonus tokens. For this feature, the SpecDecodeWorker maintains state across forward passes of a sequence to track whether it was assigned a bonus token. If it was, the worker backfills the draft model's KV cache for the penultimate token in the next forward pass.
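A minimal sketch of the per-sequence decision described above, assuming the worker already knows whether the sequence received a bonus token in its previous step; `draft_tokens_to_process` is a hypothetical name, not a vLLM function.

```python
from typing import List


def draft_tokens_to_process(token_ids: List[int], had_bonus_token: bool) -> List[int]:
    """Tokens whose draft-model KV must be computed in this step.

    When the previous step accepted every proposal and appended a bonus
    token, the draft model never ran on the last accepted proposal (now
    the penultimate token), so its KV is backfilled together with the KV
    of the newest token. Otherwise only the newest token is needed.
    """
    return token_ids[-2:] if had_bonus_token else token_ids[-1:]
```

For example, a sequence ending in `[..., 12, 13]` where 13 was the bonus token feeds `[12, 13]` to the draft model; without a bonus token it feeds only `[13]`.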
In the current implementation, the SpecDecodeWorker maintains a list of the sequence_ids that were assigned bonus tokens in their last forward pass. If a sequence is not assigned a bonus token in its current pass, it is removed from the list (if present). However, if generation terminates for a sequence that is still in this list, its id is never removed, so over time the list accumulates sequence_ids that are no longer active. We therefore need a way to remove such sequence_ids from this list.
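The leak can be reproduced with a small stand-in for that state. This is a hypothetical sketch, not the SpecDecodeWorker code: `BonusTokenTracker`, `record_step`, and `forget` are invented names, and it uses a set rather than the list described above so that membership checks and removals are O(1).

```python
from typing import Iterable, Set


class BonusTokenTracker:
    """Hypothetical stand-in for the bonus-token state in SpecDecodeWorker."""

    def __init__(self) -> None:
        self.seq_ids_with_bonus: Set[int] = set()

    def record_step(self, seq_id: int, got_bonus_token: bool) -> None:
        # Membership is only updated when the sequence runs again.
        if got_bonus_token:
            self.seq_ids_with_bonus.add(seq_id)
        else:
            self.seq_ids_with_bonus.discard(seq_id)

    def forget(self, finished_seq_ids: Iterable[int]) -> None:
        # Without this call, ids of finished sequences stay in the set forever.
        self.seq_ids_with_bonus.difference_update(finished_seq_ids)


tracker = BonusTokenTracker()
tracker.record_step(1, True)
tracker.record_step(2, True)
tracker.record_step(2, False)   # sequence 2 keeps running, so it is cleaned up
# Sequence 1 finished right after receiving a bonus token and never runs again.
print(tracker.seq_ids_with_bonus)  # {1}: leaked entry
tracker.forget([1])
print(tracker.seq_ids_with_bonus)  # set()
```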
One way to implement this would be the following:
from abc import ABC
from typing import Callable, List


class ExecutorBase:
    ...

    def process_terminated_sequences(self, sequence_ids: List[int]) -> None:
        """Pass a list of sequence_ids for which generation has been
        stopped for processing by the Executor."""
        return self.driver_worker.process_terminated_sequences(sequence_ids)

    ...


class WorkerBase:
    ...

    def process_terminated_sequences(self, seq_ids: List[int]) -> None:
        """Pass a list of sequence_ids for which generation has been
        stopped for processing by the Worker."""
        ...


class SequenceGroupOutputProcessor(ABC):

    @staticmethod
    def create_output_processor(
        scheduler_config: SchedulerConfig,
        detokenizer: Detokenizer,
        scheduler: Scheduler,
        seq_counter: Counter,
        get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
        stop_checker: "StopChecker",
        executor: ExecutorBase,  # new: lets the processor report terminated sequences
    ):
        ...

    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput]) -> None:
        ...
@cadedaniel can you please take a look at this proposal and let me know if this would work?
The proposal looks good. To simplify the interaction between the scheduler and the worker, we should embed the finished seq ids in the ExecuteModelRequest. This is better than adding a new method because in the future the worker processes could run forever in a loop; it is also better than coupling the OutputProcessor with the worker, as the OutputProcessors will live in their own process in the near future.
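A minimal sketch of that suggestion, assuming the scheduler fills in the ids of sequences that finished since the previous step. `ExecuteModelRequestSketch`, `finished_seq_ids`, and `WorkerSketch` are hypothetical names for illustration; the field actually added to ExecuteModelRequest may be named and typed differently.

```python
from dataclasses import dataclass, field
from typing import List, Set


@dataclass
class ExecuteModelRequestSketch:
    """Simplified, hypothetical stand-in for ExecuteModelRequest."""
    seq_ids: List[int]
    finished_seq_ids: List[int] = field(default_factory=list)  # hypothetical field name


class WorkerSketch:
    def __init__(self) -> None:
        self.seq_ids_with_bonus: Set[int] = set()

    def execute_model(self, request: ExecuteModelRequestSketch) -> None:
        # Drop state for sequences the scheduler reports as finished before
        # doing any work, so no separate method call into the worker is needed.
        self.seq_ids_with_bonus.difference_update(request.finished_seq_ids)
        ...  # prepare inputs, propose, score, verify
```

Because the cleanup rides along with a request the worker already receives, a worker process running in a loop never needs a second entry point.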
By the way, the folks implementing Jamba support ran into the exact same issue. See the changes to ExecuteModelRequest in this PR: https://github.com/vllm-project/vllm/pull/4115/files.
Thanks for the pointer. Since that PR addresses the same problem, I will wait for it to be merged.
Proposal to improve performance
In https://github.com/vllm-project/vllm/pull/3951 we disable bonus tokens (the token sampled from the verifier model when all proposal tokens are accepted) because their KV is not generated in the draft model. We can fix this by "prefilling" the KV of bonus tokens in the draft model. Note that for proposal methods that do not require KV (e.g. prompt lookup), we can re-enable bonus tokens and get a speedup there.
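To illustrate why a KV-free proposer is unaffected, here is a minimal n-gram prompt lookup proposer. It is a simplified sketch of the general technique, not vLLM's implementation; `prompt_lookup_propose` and its parameters are hypothetical. It only reads the sequence's token ids, so a bonus token appended by the verifier is picked up automatically on the next step.

```python
from typing import List, Optional


def prompt_lookup_propose(token_ids: List[int], ngram: int, k: int) -> Optional[List[int]]:
    """Propose up to k tokens by finding the most recent earlier occurrence
    of the last `ngram` tokens and copying what followed it.

    Uses only token ids, so there is no draft KV cache to backfill.
    """
    if len(token_ids) <= ngram:
        return None
    suffix = token_ids[-ngram:]
    # Search right to left, excluding the suffix's own position.
    for start in range(len(token_ids) - ngram - 1, -1, -1):
        if token_ids[start:start + ngram] == suffix:
            proposal = token_ids[start + ngram:start + ngram + k]
            return proposal or None
    return None
```

For example, with `token_ids = [1, 2, 3, 4, 1, 2]`, `ngram=2`, and `k=2`, the suffix `[1, 2]` matches at the start of the sequence and the proposal is `[3, 4]`.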
The impact of this improvement depends on the speculation length K. For low K, e.g. 1, where the probability of accepting the single speculative token is high (roughly, how well the draft and target models agree on the sequence), the impact is large: accepting 1 token lets us emit 2 tokens (1 speculative and 1 bonus). With bonus tokens disabled, we can only emit 1 token (the accepted speculative one).
For higher K the impact is smaller, since the probability of accepting all speculative tokens decreases exponentially with K.
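This can be quantified under a common simplifying assumption (not something measured in this thread): each of the K proposals is accepted independently with probability alpha. A rejection emits one recovered token from the verifier, and full acceptance emits the bonus token only if it is enabled, so disabling it costs exactly alpha**K tokens per step. `expected_tokens_per_step` is a hypothetical helper and alpha = 0.8 is an illustrative value.

```python
def expected_tokens_per_step(alpha: float, k: int, bonus: bool) -> float:
    """Expected tokens emitted per verification step, assuming each of the k
    proposals is accepted independently with probability alpha."""
    accepted = sum(alpha ** i for i in range(1, k + 1))  # E[number of accepted proposals]
    p_all_accepted = alpha ** k
    # A rejection emits one recovered token; full acceptance emits the bonus token if enabled.
    extra = (1 - p_all_accepted) + (p_all_accepted if bonus else 0.0)
    return accepted + extra


for k in (1, 2, 4, 8):
    with_b = expected_tokens_per_step(0.8, k, bonus=True)
    without_b = expected_tokens_per_step(0.8, k, bonus=False)
    print(f"K={k}: {with_b:.2f} vs {without_b:.2f} tokens/step "
          f"({1 - without_b / with_b:.0%} fewer without bonus tokens)")
```

For K=1 and alpha=0.8 this gives 1.8 vs 1.0 tokens per step, matching the "2 tokens instead of 1" case above; as K grows, the lost alpha**K becomes a small fraction of the total.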
https://github.com/vllm-project/vllm/blob/323f27b9048713cdbab31995265975842a937167/vllm/model_executor/layers/rejection_sampler.py#L311-L315