This issue is not in response to a performance regression.
The method of performing cross-attention QKV computations introduced in #4942 could be improved. Because this issue relates to cross-attention, it only impacts encoder/decoder models, not decoder-only models.
For context, `QKVParallelLinear` computes QKV from the previous decoder layer's hidden state output, i.e. from only a single input. The problem is that cross-attention requires QKV to be computed from two inputs: Q must be computed from the previous decoder layer's hidden state output, while KV must be computed from the encoder's output hidden states. Additionally:

* During the prefill phase, both Q and KV must be computed
* During the decode phase, only Q is computed, because the encoder sequence is static and therefore produces no new encoder KVs
The current, inefficient workaround for cross-attention is to construct a single `QKVParallelLinear` layer and apply it up to twice in a given run of the cross-attention `forward()` method: once to `decoder_hidden_states` to obtain Q, and (only during prefill) a second time to `encoder_hidden_states` to obtain KV:
```python
# (afeldman-nm 2024/07/22) TODO:
# Need a more efficient solution for q/k/v
qkv_dec, _ = self.qkv_proj(decoder_hidden_states)
q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size],
                        dim=-1)
if encoder_hidden_states is None:
    k = None
    v = None
else:
    qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
    _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
                            dim=-1)
```
Cost breakdown of the current method
During prefill:

* $(decoder\ hidden\ states) [W_Q W_K W_V]$ is computed in order to obtain Q, so 2 out of 3 GEMMs are unnecessary
* $(encoder\ hidden\ states) [W_Q W_K W_V]$ is computed in order to obtain KV, so 1 out of 3 GEMMs is unnecessary
* In total, half of the GEMMs are unnecessary (50% efficiency)

During decode:

* $(decoder\ hidden\ states) [W_Q W_K W_V]$ is computed in order to obtain Q, so 2 out of 3 GEMMs are unnecessary (33% efficiency)
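As a quick sanity check of these percentages, here is a back-of-the-envelope FLOP count for the workaround above. The shapes are illustrative assumptions (an MHA-style layer where `q_size == kv_size`); the prefill ratio is 50% regardless of the Q:K:V split, while the decode ratio depends on it:

```python
# Illustrative shapes only; the conclusions depend on the Q:K:V split, not the values.
num_tokens, hidden = 1024, 4096
q_size = kv_size = 4096          # assumes MHA, i.e. equal Q and KV widths


def gemm_flops(n: int, d: int, out: int) -> int:
    return 2 * n * d * out       # multiply-accumulate per output element


full_qkv = gemm_flops(num_tokens, hidden, q_size + 2 * kv_size)

# Prefill: qkv_proj is applied to both inputs, but only Q is kept from one
# application and only KV from the other.
prefill_done = 2 * full_qkv
prefill_needed = (gemm_flops(num_tokens, hidden, q_size)
                  + gemm_flops(num_tokens, hidden, 2 * kv_size))
print(prefill_needed / prefill_done)   # 0.5  -> 50% efficiency

# Decode: qkv_proj is applied once to the decoder hidden states, but only Q is kept.
decode_done = full_qkv
decode_needed = gemm_flops(num_tokens, hidden, q_size)
print(decode_needed / decode_done)     # ~0.33 -> 33% efficiency
```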
Proposed solution
What is needed is a modification of, or subclass of, `QKVParallelLinear` with the following properties:

* Exploits parallelism over multiple GPUs
* `forward()` takes a decoder hidden states argument and an optional encoder hidden states argument
* `forward()` always computes $(decoder\ hidden\ states) W_Q$
* `forward()` computes $(encoder\ hidden\ states) [W_K W_V]$ conditionally: only if the encoder hidden states are not `None`
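A minimal sketch of what such a layer's interface could look like is below. This is a single-GPU simplification using plain `nn.Linear`: the class name `CrossAttentionQKVLinear` is hypothetical, and a real vLLM implementation would instead build on the tensor-parallel linear layers (e.g. `ColumnParallelLinear`) to satisfy the multi-GPU requirement:

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn


class CrossAttentionQKVLinear(nn.Module):
    """Hypothetical sketch of the proposed cross-attention QKV layer
    (single-GPU; no tensor parallelism)."""

    def __init__(self, hidden_size: int, q_size: int, kv_size: int,
                 bias: bool = False):
        super().__init__()
        # Separate projections so each input only pays for the outputs it needs.
        self.q_proj = nn.Linear(hidden_size, q_size, bias=bias)
        self.kv_proj = nn.Linear(hidden_size, 2 * kv_size, bias=bias)
        self.kv_size = kv_size

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        # Always compute Q from the decoder hidden states (prefill and decode).
        q = self.q_proj(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Decode phase: the encoder sequence is static, so there are no new KVs.
            k = v = None
        else:
            # Prefill phase: compute KV from the encoder hidden states only.
            kv = self.kv_proj(encoder_hidden_states)
            k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
        return q, k, v
```

With separate Q and KV projections, prefill runs only the GEMMs whose outputs are actually used, and decode runs only the Q projection.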