vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Speculative decoding] [Performance]: Re-enable bonus tokens #4212

Open cadedaniel opened 2 months ago

cadedaniel commented 2 months ago

Proposal to improve performance

In https://github.com/vllm-project/vllm/pull/3951 we disable bonus tokens (token sampled from verifier model assuming all proposal tokens are accepted) because its KV is not generated for the draft model. We can fix this by "prefilling" the KV of bonus tokens in the draft model. Note that for proposal methods not requiring KV (e.g. prompt lookup), we can re-enable bonus tokens and get a speedup there.

The impact of this performance improvement depends on the speculation length. For low K, e.g. K=1, where the probability of accepting the single speculative token is high (roughly, how aligned the draft and target models are on the sequence), the impact is large: accepting that token lets us emit 2 tokens (1 speculative and 1 bonus), but with bonus tokens disabled we can emit only 1 token (the accepted speculative one).

For higher K the impact is smaller, since the likelihood of accepting all speculative tokens drops exponentially.

https://github.com/vllm-project/vllm/blob/323f27b9048713cdbab31995265975842a937167/vllm/model_executor/layers/rejection_sampler.py#L311-L315
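To illustrate the accounting above, here is a minimal sketch (plain Python, with a hypothetical acceptance rate; not vLLM code) of expected tokens emitted per step at K=1 with and without bonus tokens:

def tokens_emitted(accepted: bool, bonus_enabled: bool) -> int:
    # K = 1: one speculative token proposed per step.
    if accepted:
        # Proposal accepted: emit it, plus the bonus token if enabled.
        return 2 if bonus_enabled else 1
    # Proposal rejected: emit only the recovered token.
    return 1

# Hypothetical 80% acceptance rate for the single speculative token.
p_accept = 0.8
expected_with_bonus = p_accept * 2 + (1 - p_accept) * 1     # 1.8 tokens/step
expected_without_bonus = p_accept * 1 + (1 - p_accept) * 1  # 1.0 tokens/step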

sroy745 commented 2 weeks ago

Hi @cadedaniel, wondering if anyone is working on this currently? If not, I would like to look into it. Please let me know.

cadedaniel commented 2 weeks ago

that would be awesome. let's chat more in vllm slack.

cadedaniel commented 2 weeks ago

Discussed with @sroy745 and @LiuXiaoxuanPKU the best approach to fix this.

Goal: decide whether to use batch expansion or chunked prefill
Decision criterion: which approach takes more time when there are bonus tokens

How fast can we do batch expansion?
- We can hit <0.5 ms, because the expansion is small (at most 1 extra token per sequence)

How fast can we do the chunked-prefill fwd pass?
- Expected to be slower because there are no CUDA graphs
- Expected to be slower because attention is slower in Triton (Lily says it's flash, so not a concern)
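To make the two options concrete, here is a rough sketch of the draft-model batch shapes when some sequences received a bonus token in the previous step (illustrative only; the names and layout are assumptions, not vLLM code):

# Sequences whose last step produced a bonus token need that token's KV
# filled in the draft model before the next proposal round.
seq_ids = [0, 1, 2, 3]
got_bonus_last_step = {1, 3}

# Batch expansion: keep the 1-token decode entries and add one extra
# single-token entry per affected sequence (at most +1 per sequence).
num_entries = len(seq_ids) + sum(1 for s in seq_ids if s in got_bonus_last_step)

# Chunked prefill: affected sequences become 2-token entries (bonus token
# + current token) in one mixed prefill/decode batch, with no CUDA graph.
query_lens = [2 if s in got_bonus_last_step else 1 for s in seq_ids]

print(num_entries, query_lens)  # 6 [1, 2, 1, 2]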

How to measure "prefill" time in chunked prefill?
- Measure fwd pass time for small batch sizes, with varying numbers of bonus tokens (1..BS), using the chunked-prefill kernel
    - this is the best measurement because it's exactly what we will run
    - but it takes time to set up properly
    - this measures prefill computation + lack of CUDA graphs

- Measure fwd pass time of a decode batch, no CUDA graph (no chunked prefill), at small batch sizes
    - this ignores any overhead in the fwd pass of chunked prefill
    - BUT it captures 80+% of the overhead, which we intuit is due to the lack of CUDA graphs
    - this measures only the lack of CUDA graphs

Worker.execute_model
    - model_runner.execute_model
        # - prepare_inputs (excluded from timing)
        - fwd pass (timed)
        # - sampling (excluded from timing)
    - return

--- 68m model with TP1. Measure fwd pass time with and without CUDA graph.
JackFram/llama-68m

https://pytorch.org/docs/stable/generated/torch.cuda.Event.html

torch.cuda.Event(enable_timing=False, blocking=False, interprocess=False)

def execute_model(self, inputs):  # sketch of ModelRunner.execute_model
    ...
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Record on the same stream the model runs on.
    start_event.record()
    outputs = model.fwd_pass(inputs)
    end_event.record()

    # Wait for the end event, then read the elapsed time in milliseconds.
    end_event.synchronize()
    elapsed_ms = start_event.elapsed_time(end_event)

    sampled = sampler.sample(outputs)
    return sampled

sroy745 commented 2 weeks ago

Hi @cadedaniel and @LiuXiaoxuanPKU

Here is a PR that I used for some measurements.

I ran the tests with JackFram/llama-68m with TP 1 on an A100. Without CUDA graphs, the decode time is ~0.89 ms and ~0.87 ms at batch sizes 5 and 10 respectively. This is greater than the ~0.5 ms expected for batch expansion. Given these numbers, should we go with batch expansion then?

| Batch size | CUDA graph | Prefill time | Decode time |
|-----------:|------------|-------------:|------------:|
| 5          | No         | 1.16 ms      | 0.89 ms     |
| 5          | Yes        | 1.04 ms      | 0.23 ms     |
| 10         | No         | 1.10 ms      | 0.87 ms     |
| 10         | Yes        | 1.00 ms      | 0.22 ms     |

cadedaniel commented 2 weeks ago

Sounds good to me!

sroy745 commented 4 days ago

This PR implements the logic for enabling bonus tokens. For this feature, the SpecDecodeWorker maintains state across multiple forward passes of a sequence to determine whether it was assigned a bonus token in its last pass. If so, it backfills the KV cache for the penultimate token in the next forward pass. This state-tracking logic is implemented in the SpecDecodeWorker.

In the current implementation, the SpecDecodeWorker maintains a list of the sequence_ids that were assigned bonus tokens in their last forward pass. If a sequence is not assigned a bonus token in its current pass, it is removed from the list (if present). However, if generation terminates while a sequence is in this list, it is never removed. Over time we would therefore accumulate sequence_ids that are no longer active, so we need a way to remove such sequence_ids from the list.
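For illustration, here is a minimal sketch of that bookkeeping (hypothetical names, using a set for the ids; not the actual SpecDecodeWorker code):

class BonusTokenTracker:
    """Tracks which sequences were assigned a bonus token in their last
    forward pass, so their penultimate token's KV can be backfilled in the
    draft model on the next pass."""

    def __init__(self) -> None:
        self._seq_ids_with_bonus: set = set()

    def update(self, seq_id: int, got_bonus_token: bool) -> None:
        # Called for each active sequence after every forward pass.
        if got_bonus_token:
            self._seq_ids_with_bonus.add(seq_id)
        else:
            self._seq_ids_with_bonus.discard(seq_id)

    def remove_finished(self, finished_seq_ids) -> None:
        # Without a hook like this, ids of terminated sequences accumulate
        # forever; providing such a hook is what the proposal below is about.
        self._seq_ids_with_bonus.difference_update(finished_seq_ids)

    def needs_kv_backfill(self, seq_id: int) -> bool:
        return seq_id in self._seq_ids_with_bonus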

One way to implement this would be the following:

  1. Add a new method to ExecutorBase and WorkerBase that can be invoked to communicate, through the Executor to the Worker, that generation for a sequence has terminated.
  2. Pass a reference to the ExecutorBase to the SequenceGroupOutputProcessor. Whenever a sequence terminates, the SequenceGroupOutputProcessor invokes this method on the ExecutorBase to report that generation for the sequence has stopped.

class ExecutorBase:
    ...
    def process_terminated_sequences(self, sequence_ids: List[int]):
        """Pass a list of sequence_ids for which generation has been stopped,
        for processing by the Executor."""
        return self.driver_worker.process_terminated_sequences(sequence_ids)
    ...

class WorkerBase:
    ...
    def process_terminated_sequences(self, seq_ids: List[int]):
        """Pass a list of sequence_ids for which generation has been stopped,
        for processing by the Worker."""
        ...

class SequenceGroupOutputProcessor(ABC):

    @staticmethod
    def create_output_processor(
        scheduler_config: SchedulerConfig,
        detokenizer: Detokenizer,
        scheduler: Scheduler,
        seq_counter: Counter,
        get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
        stop_checker: "StopChecker",
        executor: ExecutorBase,
    ):
        ...

    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput]) -> None:
        ...
        # Invoke executor.process_terminated_sequences for seq_ids whose
        # generation has been stopped.
        ...

@cadedaniel can you please take a look at this proposal and let me know if it would work?

cadedaniel commented 3 days ago

The proposal looks good. To simplify the interaction between scheduler and worker, we should embed the finished seq ids in the ExecuteModelRequest. This is better than adding a new method because in the future the worker procs could run forever in a loop; it is also better than coupling the OutputProcessor with the worker as the OutputProcessors will live in their own process in the near future.

by the way, the folks implementing Jamba support ran into the exact same issue. See the changes to ExecuteModelRequest in this PR https://github.com/vllm-project/vllm/pull/4115/files.
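A minimal sketch of that approach, assuming the finished ids travel on the request object (the field name and class shapes here are simplified stand-ins; see the Jamba PR above for the actual ExecuteModelRequest change):

from dataclasses import dataclass, field
from typing import Any, List

@dataclass
class ExecuteModelRequest:  # simplified stand-in, not the real vLLM class
    seq_group_metadata_list: List[Any]
    # Ids of requests whose generation finished since the last step; filled
    # in by the scheduler, so no new executor/worker method is needed.
    finished_request_ids: List[str] = field(default_factory=list)

class SpecDecodeWorker:  # sketch of just the cleanup path
    def __init__(self) -> None:
        self._seq_ids_with_bonus: set = set()

    def execute_model(self, request: ExecuteModelRequest):
        # Drop bookkeeping for finished sequences before running the step.
        self._seq_ids_with_bonus.difference_update(request.finished_request_ids)
        ...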

sroy745 commented 3 days ago

Thanks for the pointer. Since that PR addresses the same problem, I will wait for it to be merged.