vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[RFC]: Improve guided decoding (logit_processor) APIs and performance. #5423

Open rkooo567 opened 2 weeks ago

rkooo567 commented 2 weeks ago

Motivation.

Currently, the guided decoding & logit processor API is incomplete and has several issues. This RFC is intended to bring up the problems and proposed solutions. Some of the issues may have already been addressed and already have PRs out.

There are 3 major issues, covered in the sections below: the API, performance, and failure handling.

Proposed Change.

API

Guided decoding parameters are not supported with SamplingParams. This is addressed in https://github.com/vllm-project/vllm/pull/4130

Performance

Currently, logit processors are applied row by row, blocking (https://github.com/vllm-project/vllm/blob/246598a6b1e22616630b7f1bf11bd9bcb31dc860/vllm/model_executor/layers/logits_processor.py#L112). Instead, we can use parallel processing (e.g., Ray or a thread pool) to improve logit processing performance. We are using this mechanism internally at Anyscale. We'd like to support this feature in OSS and improve the logit processor API to support 1. async and 2. batching.

This requires the logit processor to be stateful (to use a tool like Ray or a thread pool), with an interface like:

class LogitPostProcessor:
    def initialize(self, logit_processor_config: LogitProcessorConfig):
        """Initialize the post processor. The post processor may have state
        such as a thread pool or Ray actors; it should be initialized here.
        """
        ...

    def prepare(
            self,
            seq_group_metadata_list: List[SequenceGroupMetadata]):
        """Asynchronously prepare logit masks."""
        ...

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Apply the prepared masks to the given logits."""
        ...

# For each model, we will have

def compute_logits(...):
    """Compute logits; the prepared logit masks are applied here."""
    ...

def prepare_logits(seq_group_metadata_list):
    """Kick off asynchronous logit mask preparation."""
    ...

prepare and apply assume 1:1 calls, i.e., once prepare is called, apply has to be called before prepare is called again. I think that is a safe assumption. Alternatively, we could make prepare return a class, but that would enlarge the interface surface, so I don't prefer that solution (but I am open to feedback!)

This is the example usage of the API

        # each model will have prepare_logits API
        self.model.prepare_logits(seq_group_metadata_list)
        hidden_states = model_executable(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            **multi_modal_kwargs,
        )
        # Compute the logits. logit processors are applied here.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

We are also considering upstreaming a Ray-based batch processing implementation with lm-format-enforcer.
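
For illustration, here is a minimal sketch of what a thread-pool-based implementation of this interface could look like. Everything in it (the class name, the vocab_size config field, the _build_mask helper) is an assumption for the sketch, not existing vLLM code:

import concurrent.futures
from typing import List

import torch


class ThreadPoolLogitPostProcessor:
    """Hypothetical LogitPostProcessor that prepares masks in a thread pool."""

    def initialize(self, logit_processor_config) -> None:
        # The pool is part of the processor's state: created once here,
        # not re-created for every request.
        self._vocab_size = logit_processor_config.vocab_size  # assumed field
        self._pool = concurrent.futures.ThreadPoolExecutor(max_workers=8)
        self._futures: List[concurrent.futures.Future] = []

    def prepare(self, seq_group_metadata_list) -> None:
        # Kick off mask construction asynchronously, one task per row.
        self._futures = [
            self._pool.submit(self._build_mask, metadata)
            for metadata in seq_group_metadata_list
        ]

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        # Block until all masks are ready, then apply them in one batched add.
        # Assumes one logits row per sequence group for simplicity.
        if not self._futures:
            return logits
        masks = torch.stack([future.result() for future in self._futures])
        self._futures = []
        return logits + masks.to(logits.device, logits.dtype)

    def _build_mask(self, seq_group_metadata) -> torch.Tensor:
        # Placeholder: a real implementation would consult e.g. an Outlines or
        # lm-format-enforcer FSM and set disallowed tokens to -inf.
        return torch.zeros(self._vocab_size)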

Failure Handling

When using a stateful logit processor, it is possible for requests to fail. For example, if we use Ray, Ray actors can die, or there could be a user schema issue that cannot be caught ahead of time.

When that happens, we should fail the seq_group immediately. We will introduce a new status "FINISHED_INTERNAL_ERROR = enum.auto()" in https://github.com/vllm-project/vllm/blob/246598a6b1e22616630b7f1bf11bd9bcb31dc860/vllm/sequence.py#L42. If any logit processor fails, we will mark the relevant seq_group as failed, and the request will be aborted.
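
As a rough sketch of that failure path (the status value follows this RFC; the stubbed SequenceStatus and the apply_with_failure_handling wrapper below are illustrative, not existing vLLM code):

import enum


class SequenceStatus(enum.Enum):
    # Existing statuses elided; the RFC proposes adding:
    FINISHED_INTERNAL_ERROR = enum.auto()


def apply_with_failure_handling(logit_processor, logits, seq_group):
    """Illustrative wrapper: fail the seq_group if its logit processor fails."""
    try:
        return logit_processor.apply(logits)
    except Exception:
        # Mark every sequence in the group as failed so the scheduler frees
        # its resources and the request is aborted.
        for seq in seq_group.get_seqs():
            seq.status = SequenceStatus.FINISHED_INTERNAL_ERROR
        raise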

Feedback Period.

No response

CC List.

cc @simon-mo @Yard1

Any Other Things.

No response

simon-mo commented 2 weeks ago

cc @njhill @br3no @mmoskal

br3no commented 2 weeks ago

I have a few questions:

It is not supported from SamplingParams

Can you elaborate on why you think placing the guided decoding parameters in the SamplingParams is a good idea? As I commented in #4130, I think they conceptually overlap with the logits processors implementing the guided decoding, which are already in the SamplingParams.

This requires logit processor to be

  • stateful (to use a tool like Ray or thread pool). ...

Do you maybe mean stateless? If not, what do you mean exactly?

Regarding the topic of statefulness: we probably don't want to limit ourselves to stateless logits processors. If we manage to make the API so that it is easy to implement stateful logits processors, we would already make things much better. E.g. I think that a very good thing to address would be to add infrastructure for pooling stateful objects and making it easy to define that one such object should not be shared across sequences and requests, or at least should be reset before being used.
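
For example, such pooling infrastructure might look roughly like the sketch below (ProcessorPool is purely illustrative, not existing vLLM code):

import queue
from typing import Callable, Generic, TypeVar

T = TypeVar("T")


class ProcessorPool(Generic[T]):
    """Illustrative pool of stateful logits-processor objects.

    Each object is used by at most one sequence at a time and is reset
    before being handed out again.
    """

    def __init__(self, factory: Callable[[], T], reset: Callable[[T], None]):
        self._factory = factory
        self._reset = reset
        self._idle: "queue.SimpleQueue[T]" = queue.SimpleQueue()

    def acquire(self) -> T:
        try:
            obj = self._idle.get_nowait()
        except queue.Empty:
            return self._factory()  # pool exhausted: build a fresh instance
        self._reset(obj)  # make sure no state leaks between sequences
        return obj

    def release(self, obj: T) -> None:
        self._idle.put(obj)

A guided-decoding request would then acquire a reset processor from such a pool instead of constructing a new one, and release it when the sequence finishes.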

Could you also please elaborate on the new LogitsPostProcessor API you propose? Is this the API to be implemented by logits processors? Or is this an API to be implemented by the models?

Are there maybe some type annotations missing for the return values of e.g. prepare? If this method does not return anything, this means the LogitsPostProcessor is stateful, right? Shouldn't we aim for a stateless design here, to make parallelization easier?

I might have misunderstood the proposal though. So, I'd be really happy if you could elaborate on it.

All in all, I would be very interested in improvements in this area, so I'm glad you're working on it!

rkooo567 commented 2 weeks ago

Can you elaborate on why you think placing the guided decoding parameters in the SamplingParams is a good idea? As I commented in https://github.com/vllm-project/vllm/pull/4130, I think they conceptually overlap with the logits processors implementing the guided decoding, which are already in the SamplingParams.

It's like moving the functionality to the core API. Right now, it is implemented like an add-on (it only works with the OpenAI server), and it doesn't work with tools like https://github.com/anyscale/ray-llm (because we directly use the core API). It requires code that breaks the abstraction barrier (i.e., creating a logit processor), and given that guided decoding is a core function, I feel like having the API in SamplingParams makes sense.

Do you maybe mean stateless? If not, what do you mean exactly?

To improve the time to prepare masks for json mode, we want to use parallel processing tools such as a thread pool or Ray. This requires the logit processor to be "stateful" because we don't want to recreate actors or thread pools every time a logit processor is requested (they should be created in __init__).

E.g. I think that a very good thing to address would be to add infrastructure for pooling stateful objects and making it easy to define that one such object should not be shared across sequences and requests, or at least should be reset before being used.

+1. I think it'd be an implementation of part 2.

Could you also please elaborate on the new LogitsPostProcessor API you propose? Is this the API to be implemented by logits processors? Or is this an API to be implemented by the models?

It will replace the _apply_logits_processors private API inside logits_processor.py. Right now, we apply logit masks row by row. Instead, we 1. find the relevant logit processor that was created, then 2. call logit_processor.prepare(seq_group_metadata_list) -> logit_processor.apply(logits).

Are there maybe some type annotations missing for the return values of e.g. prepare? If this method does not return anything, this means the LogitsPostProcessor is stateful, right? Shouldn't we aim for a stateless design here, to make parallelization easier?

You are right that prepare and apply are stateful. We could also make it work this way:

        masks = self.model.prepare_logits(seq_group_metadata_list)
        hidden_states = model_executable(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            **multi_modal_kwargs,
        )
        # Compute the logits. logit processors are applied here.
        logits = self.model.compute_logits(hidden_states, sampling_metadata, masks)

But I found it easier to just make it fully stateful.

Hope this clarifies the proposal a little bit!

simon-mo commented 2 weeks ago

We should make this work with the following RFCs

@NadavShmayo https://github.com/vllm-project/vllm/pull/4769
@mmoskal https://github.com/vllm-project/vllm/pull/4775
@mmoskal https://github.com/vllm-project/vllm/pull/2888
@maxdebayser @njhill https://github.com/vllm-project/vllm/pull/5329
@lynkz-matt-psaltis https://github.com/vllm-project/vllm/pull/5006

rkooo567 commented 2 weeks ago

My initial thoughts:

mmoskal commented 2 weeks ago

Some ideas:

With an additional post-sampling callback, this would subsume my SequenceController #4775 :

    def sampled(self, seq: 'Sequence', token_id: int,
                logprobs: Dict[int, 'Logprob']) -> Tuple[int, List[int], bool]:
        """
        Informs the controller a given token has been sampled.
        Returns the number of tokens to backtrack, the tokens to append,
        and whether to stop.
        """
        if token_id == seq.eos_token_id:
            return 0, [], True
        return 0, [token_id], False

rkooo567 commented 2 weeks ago

With an additional post-sampling callback, this would subsume my SequenceController https://github.com/vllm-project/vllm/pull/4775 :

I see. I found that API limited for our particular use case because, as you know, it is applied after sampling is done (whereas we want to apply the logit processor to the final logits). It'd be great if we can subsume it.

add some sort of free() API so resources can be freed

I am open to it, but right now there's no specific use case.

maybe initialize() can be async? the reason is that we don't want to start scheduling sequences while the processor is still initializing (in case it takes a few seconds)

How is this guaranteed now?

br3no commented 2 weeks ago

@rkooo567 thanks, let me see if I understand it:

The idea is that the logits processors will be asked to prepare their masks asynchronously and in the meantime the model is going to be run. Once both are ready, the logits are computed by having the model call apply.

This means that the whole process needs to guarantee that there is one logits processor instance per request per sequence. Correct?

The implementation will need to be very careful to avoid contention issues.


Regarding the combination of this with the other PRs: I'm still struggling a bit to understand what general design we need. Let me explain:

The logits processors are now applied in the models; so the general signature of the operation is

compute_logits(hidden_states: Tensor, ...) -> Tensor

We want to support ff-tokens or backtracking (e.g. #4775). These things happen a few layers above the model and don't fit this API above.

So we're talking about different things in different abstraction layers at the same time.

Am I the only one? Is the design clear to you folks? If so, I would appreciate it a lot if someone could describe which types of objects would play which role, and where.

mmoskal commented 2 weeks ago

@br3no One thing that took me a while to see is that there is only one LogitPostProcessor per LLMEngine - it handles logits for all sequences in the current batch.

There was some discussion of allowing a list of those, but IMHO it's easy to write a LogitPostProcessor that bundles an arbitrary number of LogitPostProcessors, so I think there's no need to have a list of post processors in vLLM.
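
For instance, a bundling post processor could be a thin wrapper along these lines (an illustrative sketch, not an existing class):

from typing import List

import torch


class BundledLogitPostProcessor:
    """Sketch of a LogitPostProcessor that fans out to several children."""

    def __init__(self, children: List):
        self._children = children

    def initialize(self, logit_processor_config) -> None:
        for child in self._children:
            child.initialize(logit_processor_config)

    def prepare(self, seq_group_metadata_list) -> None:
        for child in self._children:
            child.prepare(seq_group_metadata_list)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        # Children are applied in order; each sees the logits produced by
        # the previous one.
        for child in self._children:
            logits = child.apply(logits)
        return logits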

I'm the one asking for ff_tokens and backtracking; I think @rkooo567 is not doing this now.

njhill commented 2 weeks ago

@rkooo567 @simon-mo @mmoskal some additional thoughts after we talked offline yesterday:

It's a concern that the current support is kind of broken: it doesn't work for input batches or beam search due to the stateful/concurrency issue. So I wonder if we could prioritize some simpler immediate fixes for that, along with the egregious performance overhead in json mode from having to construct a new CFGGuide instance every time, i.e. before the more significant rework to introduce batched application and the prepare step... WDYT?

A couple of other thoughts about the proposed interface:

br3no commented 2 weeks ago

@mmoskal thanks for your answer! I also would like to support ff-tokens since I think this would contribute to alleviate the performance issues.

@njhill I'm not familiar with lm-format-enforcer, but of the Outlines processors, only the CFG one is still problematic; the others are now stateless. Should we concentrate on a "fix" for the output_format: json issue? This would involve an object pool for the CFGGuide for that particular use case. Or am I missing other aspects here?

rkooo567 commented 2 weeks ago

There was some discussion of allowing a list of those, but IMHO it's easy to write a LogitPostProcessor that bundles an arbitrary number of `LogitPostProcessors so I think there's no need to have a list of post processors in vLLM.

I also agree with that. I have the impression the current interface is a little over-designed, with some vague implementation in mind. For ff-tokens and backtracking, I would like to see the implementation first; otherwise it is very difficult to design the interface (that's why we punted). I think the interface I propose here is not going to get in the way of getting there (the logit processor API also doesn't feel like a very stable API yet, so we have time to iterate).

It's a concern that the current support is kind of broken, it doesn't work for input batches or beam search due to the stateful/concurrency thing. So I wonder if we could prioritize some simpler immediate fixes for that along with the egregious performance overhead with json mode due to having to construct a new CFGuide instance every time. i.e. before the more significant rework to introduce batched application and the prepare step... WDYT?

Does that mean supporting stateful logit processors first (i.e., merging the open PR)? I am okay with this.

Why would we need an initialize method, couldn't a regular constructor be used for this?

I think a regular constructor could work. The main reason was that we need to pass the decoding config to the logit processor, and since it is inside the model, the required change was big. I think a constructor makes more sense actually.

I'm not sure that it's a good idea to expose List[SequenceGroupMetadata] in this API ... I had assumed SequenceGroupMetadata is an internal datastructure that we want the freedom to change without breaking 3rd party LogitsProcessor impls. Probably should have some simpler dataclass or abstract class designed specifically for the API.

Yeah, it is a good point. For our internal impl, we just need seq_data, seq_ids, request_id, and sampling params.
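
Such a stable, API-facing type could be a small dataclass along these lines (a sketch only; the name and field types are guesses, not a settled design):

from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class LogitProcessorRequest:
    """Hypothetical view handed to LogitPostProcessor.prepare()."""
    request_id: str
    seq_ids: List[int]
    # Per-sequence token ids generated so far, keyed by seq_id.
    seq_data: Dict[int, List[int]]
    # The request's SamplingParams (or whichever subset the processor needs).
    sampling_params: Any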