vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[RFC]: Classifier-Free Guidance #5825

Open · Vermeille opened this issue 1 week ago

Vermeille commented 1 week ago

Motivation.

I am one of the authors of the paper Stay On Topic with Classifier-Free Guidance ( https://openreview.net/forum?id=RiM3cl9MdK&noteId=s1BXLL1YZD ), which was selected as an ICML'24 Spotlight paper. CFG is a sampling technique that lets LLMs follow the prompt more closely, at the cost of two forward passes per token and two KV caches. CFG brings non-trivial improvements across standard benchmarks.

I would be extremely interested in having CFG implemented in vLLM. If possible, I would like a bit of guidance on the vLLM codebase.

Proposed Change.

CFG contrasts the next-token logits between two different prompts: a "positive" prompt (prompt_a below) and a "negative" or "unconditional" prompt (prompt_b).

Here is the pseudo-algorithm:

while we sample:
    logits_a = log_softmax(model(prompt_a))
    logits_b = log_softmax(model(prompt_b))
    logits = logits_b + cfg_scale * (logits_a - logits_b)
    next_token = sample_from(logits)
    prompt_a.append(next_token)
    prompt_b.append(next_token)

As you can see, an efficient implementation needs two concurrent KV caches (see the sketch below). I tried looking at how Speculative Decoding is implemented, but it is quite a bit more complex than what CFG needs.
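
To make the two-cache structure concrete, here is a minimal sketch of the loop above written against Hugging Face transformers rather than vLLM internals, just to show what a vLLM implementation would have to manage. The model name, prompts, and cfg_scale value are placeholders, and sampling is plain categorical sampling with no temperature or top-p.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

# Placeholder prompts and scale; the negative prompt is typically an
# instruction-free / "unconditional" version of the positive prompt.
prompt_a = "You are a marine-biology expert. Describe how octopuses camouflage."
prompt_b = "Describe how octopuses camouflage."
cfg_scale = 1.5
max_new_tokens = 32

next_a = tokenizer(prompt_a, return_tensors="pt").input_ids
next_b = tokenizer(prompt_b, return_tensors="pt").input_ids
past_a, past_b = None, None  # two independent KV caches
generated = []

with torch.no_grad():
    for _ in range(max_new_tokens):
        out_a = model(next_a, past_key_values=past_a, use_cache=True)
        out_b = model(next_b, past_key_values=past_b, use_cache=True)
        past_a, past_b = out_a.past_key_values, out_b.past_key_values

        # Next-token log-probs from the last position of each sequence.
        logp_a = torch.log_softmax(out_a.logits[:, -1, :], dim=-1)
        logp_b = torch.log_softmax(out_b.logits[:, -1, :], dim=-1)

        # CFG: push the unconditional distribution toward the conditional one.
        logits = logp_b + cfg_scale * (logp_a - logp_b)
        next_token = torch.distributions.Categorical(logits=logits).sample().unsqueeze(-1)
        generated.append(next_token)

        # The same sampled token is fed to both sequences, so each KV cache
        # holds its own prompt plus the shared continuation.
        next_a = next_b = next_token

print(tokenizer.decode(torch.cat(generated, dim=-1)[0]))
```

The key requirement for vLLM is visible here: the two sequences must be decoded in lockstep and share every sampled token, which is why two concurrent KV caches are needed.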

Feedback Period.

No response

CC List.

No response

Any Other Things.

I am willing to implement it myself given enough guidance, as this looks like a non-trivial thing to implement. I think something similar to (or reusing bits of) Speculative Decoding might work, but that code is not simple.

Vermeille commented 4 days ago

Up