
[Misc]: Cross-attention QKV computation is inefficient #7397

afeldman-nm commented 1 month ago

This issue is not in response to a performance regression.

The method of performing cross-attention QKV computations introduced in #4942 could be improved. Because this issue relates to cross-attention, it only impacts encoder/decoder models, not decoder-only models.

For context, QKVParallelLinear computes QKV from a single input, the previous decoder layer's hidden-state output. The problem is that cross-attention requires QKV to be computed from two inputs: Q must be computed from the previous decoder layer's hidden-state output, while K and V must be computed from the encoder's output hidden states. Additionally, the encoder output is only available during prefill; during decode, K and V must come from the cross-attention KV cache.

The current, inefficient workaround is to construct a single QKVParallelLinear layer and apply it up to twice in a given run of the cross-attention forward() method: once to decoder_hidden_states to obtain Q, and (only during prefill) a second time to encoder_hidden_states to obtain K and V:

```python
# (afeldman-nm 2024/07/22) TODO:
# Need a more efficient solution for q/k/v
qkv_dec, _ = self.qkv_proj(decoder_hidden_states)
q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size],
                        dim=-1)
if encoder_hidden_states is None:
    # Decode phase: K and V are read from the cross-attention KV cache.
    k = None
    v = None
else:
    # Prefill phase: compute K and V from the encoder output.
    qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
    _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
                            dim=-1)
```

Cost breakdown of the current method

During prefill, qkv_proj runs twice: the full QKV projection is applied to decoder_hidden_states even though only the Q slice of the output is kept, and applied again to encoder_hidden_states even though only the K and V slices are kept. Part of each GEMM's output is therefore computed and discarded.

During decode, qkv_proj runs once, on decoder_hidden_states, but only the Q slice of the output is used; the K and V slices are computed and immediately discarded, since cross-attention K and V are read from the KV cache.
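
As a rough illustration of the waste (illustrative sizes, not measured numbers), the fraction of each GEMM's output that is discarded follows directly from the projection widths:

```python
# Illustrative accounting for the current workaround. The fused
# projection weight has q_size + 2 * kv_size output columns.
q_size = 4096        # e.g. 32 heads * 128 head_size (assumed)
kv_size = 4096       # MHA case: num_kv_heads == num_heads (assumed)
total = q_size + 2 * kv_size

# Prefill: the decoder-side GEMM keeps only Q; the encoder-side GEMM
# keeps only K and V.
wasted_prefill_dec = 2 * kv_size / total   # 2/3 of output discarded
wasted_prefill_enc = q_size / total        # 1/3 of output discarded

# Decode: the single GEMM keeps only Q.
wasted_decode = 2 * kv_size / total        # 2/3 of output discarded

print(wasted_prefill_dec, wasted_prefill_enc, wasted_decode)
```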

Proposed solution

What is needed is a modification of, or subclass of, QKVParallelLinear that can compute Q from the decoder hidden states and K/V from the encoder hidden states without computing output slices that will be discarded.
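
Below is a minimal sketch of the idea in plain PyTorch, not vLLM's actual layer; the name CrossAttentionQKVLinear and the constructor parameters are illustrative. A real implementation would also need to preserve QKVParallelLinear's tensor-parallel sharding and its fused-checkpoint weight loading:

```python
import torch
import torch.nn as nn
from typing import Optional, Tuple


class CrossAttentionQKVLinear(nn.Module):
    """Sketch: Q from decoder hidden states, K/V from encoder hidden
    states, running only the weight slices that are actually needed."""

    def __init__(self, hidden_size: int, q_size: int, kv_size: int):
        super().__init__()
        # Separate projections so each input only pays for the outputs
        # it uses: Q for the decoder side, KV for the encoder side.
        self.q_proj = nn.Linear(hidden_size, q_size, bias=False)
        self.kv_proj = nn.Linear(hidden_size, 2 * kv_size, bias=False)
        self.kv_size = kv_size

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[torch.Tensor]]:
        # Q is needed in both prefill and decode.
        q = self.q_proj(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Decode phase: cross-attention K/V come from the KV cache.
            return q, None, None
        # Prefill phase: compute K and V from the encoder output once.
        kv = self.kv_proj(encoder_hidden_states)
        k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
        return q, k, v
```

With this split, prefill performs exactly one Q GEMM and one KV GEMM, and decode performs only the Q GEMM, eliminating the discarded output slices described above.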

afeldman-nm commented 1 month ago

FYI #7448 is addressing this issue