afeldman-nm opened 3 months ago
FYI #7448 is addressing this issue
This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!
This issue is not in response to a performance regression.
The method of performing cross-attention QKV computations introduced in #4942 could be improved. Because this issue relates to cross-attention, it only impacts encoder/decoder models, not decoder-only models.
For context, `QKVParallelLinear` computes QKV from the previous decoder layer's hidden state output, i.e. from only a single input. The problem is that cross-attention requires QKV to be computed from two inputs: Q must be computed from the previous decoder layer's hidden state output, and KV must be computed from the encoder's output hidden states. The current, inefficient workaround for cross-attention is to construct a single `QKVParallelLinear` layer and apply it up to two times in a given run of the cross-attention `forward()` method: once to `decoder_hidden_states` to obtain Q, and (only during prefill) a second time to `encoder_hidden_states` to obtain KV. A rough sketch of this pattern follows.
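A minimal sketch of the pattern described above, assuming plain PyTorch; an `nn.Linear` stands in for `QKVParallelLinear`, and the function and variable names are illustrative rather than vLLM's actual code:

```python
import torch
import torch.nn as nn

# Stand-in for QKVParallelLinear: one matmul produces concatenated [Q | K | V].
hidden_size = 1024
q_size = kv_size = hidden_size
qkv_proj = nn.Linear(hidden_size, q_size + 2 * kv_size, bias=False)

def cross_attn_qkv(decoder_hidden_states, encoder_hidden_states=None):
    # Pass 1: project the decoder hidden states; only Q is kept, while the
    # K/V slices are computed and then thrown away (wasted work).
    qkv_dec = qkv_proj(decoder_hidden_states)
    q, _, _ = qkv_dec.split([q_size, kv_size, kv_size], dim=-1)

    if encoder_hidden_states is None:
        # Decode phase: cross-attention K/V come from the KV cache.
        k = v = None
    else:
        # Pass 2 (prefill only): project the encoder hidden states; only K/V
        # are kept, while the Q slice is computed and then thrown away.
        qkv_enc = qkv_proj(encoder_hidden_states)
        _, k, v = qkv_enc.split([q_size, kv_size, kv_size], dim=-1)
    return q, k, v

# Prefill: both inputs are passed; decode: decoder hidden states only.
q, k, v = cross_attn_qkv(torch.randn(4, hidden_size), torch.randn(7, hidden_size))
q, k, v = cross_attn_qkv(torch.randn(1, hidden_size))
```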
Cost breakdown of the current method
- During prefill, `QKVParallelLinear` runs twice: the pass over `decoder_hidden_states` computes K and V that are immediately discarded, and the pass over `encoder_hidden_states` computes a Q that is immediately discarded.
- During decode, it runs once on `decoder_hidden_states`, again computing K and V that go unused (the cross-attention K/V are read from the KV cache instead).
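As a rough illustration of the overhead (assuming, purely for this example, a single hidden size $d$ so that $W_Q$ is $d \times d$ and $[W_K W_V]$ is $d \times 2d$, with $n_{dec}$ decoder tokens and $n_{enc}$ encoder tokens): prefill currently performs $(n_{dec} + n_{enc}) \cdot 3d^2$ multiply-accumulates where only $n_{dec} \cdot d^2 + n_{enc} \cdot 2d^2$ are needed, and each decode step performs $3d^2$ multiply-accumulates per token where only the $d^2$ for Q is needed.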
Proposed solution
What is needed is a modification of, or a subclass of, `QKVParallelLinear` with the following properties (see the sketch after this list):
- `forward()` takes a decoder hidden states argument and an optional encoder hidden states argument
- `forward()` always computes $(decoder\ hidden\ states) W_Q$
- `forward()` computes $(encoder\ hidden\ states) [W_K W_V]$ conditionally: only if the encoder hidden states are not `None`
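A minimal sketch of one shape this could take, under the assumption that the two weight blocks are held as separate projections; `CrossAttnQKVLinear` is a hypothetical name, and plain `nn.Linear` layers stand in for the tensor-parallel machinery of `QKVParallelLinear`:

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn

class CrossAttnQKVLinear(nn.Module):
    """Hypothetical two-input replacement for QKVParallelLinear in cross-attention."""

    def __init__(self, hidden_size: int, q_size: int, kv_size: int):
        super().__init__()
        self.q_proj = nn.Linear(hidden_size, q_size, bias=False)        # W_Q
        self.kv_proj = nn.Linear(hidden_size, 2 * kv_size, bias=False)  # [W_K W_V]
        self.kv_size = kv_size

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        # Always compute (decoder hidden states) W_Q.
        q = self.q_proj(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Decode: K/V are read from the cross-attention KV cache instead.
            return q, None, None
        # Prefill: compute (encoder hidden states) [W_K W_V] only when the
        # encoder hidden states are provided.
        k, v = self.kv_proj(encoder_hidden_states).split(self.kv_size, dim=-1)
        return q, k, v
```

A real implementation would still need the column-parallel sharding and weight-loading behavior of `QKVParallelLinear`; the sketch only illustrates the two-input `forward()` signature and the conditional KV projection.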