
[RFC]: Priority Scheduling #6077

Open · apatke opened 2 days ago

apatke commented 2 days ago

Motivation.

vLLM supports first-come-first-served scheduling based on the arrival time of requests.

A prioritization mechanism that gives certain requests higher preference in scheduling is useful because it enables:

  1. Co-location of batch and interactive requests: Batch and interactive requests can be served on the same vLLM instance. If interactive requests arrive while batch requests are executing, they preempt the batch requests for immediate execution. Once the interactive requests complete, the batch requests resume. If the KV cache for the batch requests is preserved, the disruption to overall throughput can be minimized.

  2. Fairness between multiple interactive requests: Recent papers (such as Andes, VTC, and QLM) have proposed mechanisms to maintain fairness between concurrently executing requests (e.g., by maintaining inter-token latency targets). Most of these mechanisms can be implemented by dynamically changing the priority of requests in vLLM's waiting and running queues.

Proposed Change.

vLLM already has a pluggable scheduling policy class that can be used to implement priority scheduling, so the overall change is relatively small. The major changes required are:

  1. Introduction of a priority field in the sequence group: This priority can be static or dynamic based on the specific scheduling policy.

  2. Waiting queue sorting based on priority: Currently, only the running queue is sorted according to the scheduling policy; the waiting queue is ordered FCFS. The waiting queue ordering can be made policy-dependent as well.

  3. Jointly sorting the waiting and running queues: While both the waiting and running queues can be sorted independently, priority inversions can still exist between the two. Therefore, they also need to be sorted jointly. Sorting the two queues jointly is not possible without forcefully preempting requests in the running queue and replacing them with requests from the waiting queue.

  4. Enabling forced preemption while preserving the KV cache: Recomputation after forced preemption repeats work that has already been done; preserving the KV cache avoids this repeated computation. The preservation can piggyback on the existing KV cache swapping implementation (_preempt_by_swap).

  5. Dynamically changing priorities: To maintain fairness between executing requests, their priorities can be adjusted dynamically to prevent starvation and maintain inter-token latency. For example, the Andes paper adjusts request priorities based on the estimated Quality of Experience gain (i.e., the most starved request, with the longest preempted time, gets higher priority). Other policies can similarly be implemented within the generic priority interface; a rough sketch of how these pieces could fit together is shown below.
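
A minimal, illustrative sketch of changes 1–5 under simplifying assumptions; `SequenceGroupStub`, `schedule_step`, and `age_priorities` are invented names and do not correspond to vLLM's actual scheduler interfaces:

```python
import time
from dataclasses import dataclass
from typing import List


@dataclass
class SequenceGroupStub:
    """Simplified stand-in for a vLLM sequence group."""
    request_id: str
    arrival_time: float
    priority: int = 0  # lower value = higher priority (change 1)


def sort_key(seq_group: SequenceGroupStub):
    # Priority first, arrival time as the FCFS tie-breaker (change 2).
    return (seq_group.priority, seq_group.arrival_time)


def schedule_step(waiting: List[SequenceGroupStub],
                  running: List[SequenceGroupStub],
                  max_running: int) -> None:
    """One scheduling iteration with joint sorting and forced preemption."""
    waiting.sort(key=sort_key)
    running.sort(key=sort_key)

    # Resolve priority inversions (changes 3 and 4): while the best waiting
    # request outranks the worst running request, force-preempt the latter.
    # In vLLM the victim's KV cache would be swapped out rather than
    # recomputed, reusing the existing _preempt_by_swap path.
    while waiting and running and sort_key(waiting[0]) < sort_key(running[-1]):
        waiting.append(running.pop())
    waiting.sort(key=sort_key)

    # Admit the highest-priority waiting requests while capacity remains.
    while waiting and len(running) < max_running:
        running.append(waiting.pop(0))


def age_priorities(waiting: List[SequenceGroupStub],
                   boost_after_s: float = 5.0) -> None:
    """Toy dynamic-priority update (change 5): boost requests that have been
    waiting too long, so Andes/VTC-style fairness policies can be expressed
    through the same priority field."""
    now = time.monotonic()
    for sg in waiting:
        if now - sg.arrival_time > boost_after_s and sg.priority > 0:
            sg.priority -= 1
```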

PR #5958 implements 1, 2, and 3. In this PR, priority scheduling can be disabled if it is not required or its overhead is unacceptable.

The PR also adds a priority field to the generate function of the LLM engine. In the future, this could be replaced by a more general scheduling metadata structure to allow for other policy-specific variables. Any suggestions or comments on this metadata structure would be helpful.
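
As one possible starting point for that discussion, a hypothetical shape for such a metadata structure could look like the following; all field names besides priority are purely illustrative:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SchedulingParams:
    """Hypothetical per-request scheduling metadata. Only `priority` is part
    of PR #5958; the other fields merely illustrate possible extensions."""
    priority: int = 0                   # static or dynamically updated priority
    deadline_s: Optional[float] = None  # for deadline-aware policies
    preemptible: bool = True            # opt a request out of forced preemption
```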

Feedback Period.

No response

CC List.

@njhill @saurabhjha1

Any Other Things.

No response

simon-mo commented 2 days ago

I'm curious to learn more about the details of the user-facing and configuration APIs.

apatke commented 2 days ago

Currently, we are thinking of passing priority as a single optional argument to the generate function of the LLM engine and specifying the policy (like fcfs, sp, etc.) in the scheduler config options. The priority variable could be replaced in the future with a more general data class if required to support other policies. Open to any other suggestions.
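
For concreteness, a sketch of how this could look from the caller's side; the exact argument and config option names are still open, so treat them as placeholders:

```python
from vllm import LLM, SamplingParams

# Placeholder option name for selecting the scheduling policy.
llm = LLM(model="facebook/opt-125m", scheduling_policy="priority")

params = SamplingParams(max_tokens=128)

# Placeholder semantics: lower value = higher priority, default 0.
interactive_out = llm.generate(["Hello!"], params, priority=0)
batch_out = llm.generate(["Summarize this document ..."], params, priority=10)
```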

w013nad commented 10 hours ago

Something I would like is the ability to set a lower maximum number of concurrent requests for batch requests than for API requests. I would like to maintain a high t/s for interactive users while still allowing a high number of users when necessary.
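
One way to approximate that today, before any scheduler-level support exists, is to cap batch concurrency on the client side in front of vLLM's OpenAI-compatible server; a minimal workaround sketch (not a vLLM feature):

```python
import asyncio

import openai  # OpenAI-compatible client pointed at the vLLM server

client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Allow only a few in-flight batch requests so interactive traffic keeps most
# of the server's capacity; interactive calls simply bypass the semaphore.
batch_slots = asyncio.Semaphore(4)


async def batch_completion(prompt: str) -> str:
    async with batch_slots:
        resp = await client.completions.create(
            model="facebook/opt-125m", prompt=prompt, max_tokens=256)
        return resp.choices[0].text
```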