vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai

[Misc]: Cross-attention QKV computation is inefficient #7397

Open afeldman-nm opened 3 months ago

afeldman-nm commented 3 months ago

This issue is not in response to a performance regression.

The method of performing cross-attention QKV computations introduced in #4942 could be improved. Because this issue relates to cross-attention, it only impacts encoder/decoder models, not decoder-only models.

For context, QKVParallelLinear computes Q, K, and V from a single input, the previous decoder layer's hidden-state output. The problem is that cross-attention requires QKV to be computed from two inputs: Q must be computed from the previous decoder layer's hidden-state output, while K and V must be computed from the encoder's output hidden states. Additionally, K and V only need to be computed during prefill, since the encoder output does not change during decode and the cross-attention K/V are served from the KV cache.

The current, inefficient workaround for cross-attention is to construct a single QKVParallelLinear layer and apply it up to twice in a given run of the cross-attention forward() method: once to decoder_hidden_states to obtain Q, and (only during prefill) a second time to encoder_hidden_states to obtain K and V:

# (afeldman-nm 2024/07/22) TODO:
# Need a more efficient solution for q/k/v
qkv_dec, _ = self.qkv_proj(decoder_hidden_states)
q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size],
                        dim=-1)
if encoder_hidden_states is None:
    k = None
    v = None
else:
    qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
    _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
                            dim=-1)

Cost breakdown of the current method

During prefill, qkv_proj is applied twice: once to decoder_hidden_states, where the K and V outputs are computed and then discarded, and once to encoder_hidden_states, where the Q output is computed and then discarded.

During decode, qkv_proj is applied once per step to decoder_hidden_states; its K and V outputs are again computed and discarded, since the cross-attention K and V are read from the KV cache.
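
To make the waste concrete, here is a rough, illustrative FLOP count of the unused projection columns under the current workaround. All sizes and token counts below are placeholders chosen for the example, not taken from any particular model.

def linear_flops(tokens: int, in_features: int, out_features: int) -> int:
    # 2 * m * k * n multiply-accumulate FLOPs for a dense projection.
    return 2 * tokens * in_features * out_features

hidden_size = 4096   # illustrative
q_size = 4096        # num_heads * head_dim (illustrative)
kv_size = 1024       # num_kv_heads * head_dim, per K or V (illustrative)
dec_tokens = 256     # decoder prompt tokens in one prefill step (illustrative)
enc_tokens = 1500    # encoder output tokens (illustrative)

# Prefill: qkv_proj runs on both inputs; K/V from the decoder pass and
# Q from the encoder pass are discarded.
prefill_wasted = (linear_flops(dec_tokens, hidden_size, 2 * kv_size) +
                  linear_flops(enc_tokens, hidden_size, q_size))
prefill_total = (linear_flops(dec_tokens, hidden_size, q_size + 2 * kv_size) +
                 linear_flops(enc_tokens, hidden_size, q_size + 2 * kv_size))

# Decode: qkv_proj runs once per step on the decoder input; the K/V columns
# are computed and discarded because cross-attention K/V come from the cache.
decode_wasted = linear_flops(1, hidden_size, 2 * kv_size)
decode_total = linear_flops(1, hidden_size, q_size + 2 * kv_size)

print(f"prefill wasted fraction: {prefill_wasted / prefill_total:.2f}")
print(f"decode wasted fraction:  {decode_wasted / decode_total:.2f}")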

Proposed solution

What is needed is a modification of, or subclass of, QKVParallelLinear with the following properties: Q is computed only from decoder_hidden_states, K and V are computed only from encoder_hidden_states (and only during prefill), and no unused projection outputs are computed in either phase, ideally while preserving QKVParallelLinear's tensor-parallel sharding and checkpoint weight-loading behavior.
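
As a rough sketch of the direction (this is not the implementation in #7448; it uses plain torch.nn.Linear rather than vLLM's tensor-parallel linear layers, and the class and parameter names are made up for illustration), keeping separate Q and KV weights means each input only pays for the projections it actually needs:

import torch
import torch.nn as nn
from typing import Optional, Tuple

class CrossAttentionQKVProj(nn.Module):
    """Illustrative sketch only: separate Q and KV projections for
    cross-attention, so no unused outputs are ever computed."""

    def __init__(self, hidden_size: int, q_size: int, kv_size: int,
                 bias: bool = False):
        super().__init__()
        # Q is always computed from the decoder hidden states.
        self.q_proj = nn.Linear(hidden_size, q_size, bias=bias)
        # K/V are only computed from the encoder hidden states (prefill only).
        self.kv_proj = nn.Linear(hidden_size, 2 * kv_size, bias=bias)
        self.kv_size = kv_size

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[torch.Tensor]]:
        q = self.q_proj(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Decode phase: cross-attention K/V come from the KV cache.
            k = v = None
        else:
            kv = self.kv_proj(encoder_hidden_states)
            k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
        return q, k, v

A real version would also need QKVParallelLinear-style tensor-parallel sharding and a weight loader that maps the checkpoint's fused QKV weights onto the two sub-projections.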

afeldman-nm commented 3 months ago

FYI #7448 is addressing this issue

github-actions[bot] commented 1 week ago

This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!