kimborgen / falcon-llm

Apache License 2.0

Optimize AttentionRotary #12

Open kimborgen opened 1 year ago

kimborgen commented 1 year ago

Because the ALiBi functionality was removed from AttentionRotary in #11, there may be room for optimization.

QKV Linear Layer and split_heads

The QKV weights are instantiated in a single fused layer:

self.query_key_value = Linear(
    self.hidden_size,
    3 * self.hidden_size if not config.multi_query else (self.hidden_size + 2 * self.head_dim),
    bias=config.bias,
)

This layer is then called, and its output split into Q, K and V, on every forward pass:

fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size] (hidden_size + 2 x head_dim if multi_query)

# query: [batch_size, seq_length, num_heads, head_dim]; key/value: [batch_size, seq_length, num_kv, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
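The split itself can be close to free. Below is a minimal sketch for the non-multi-query case. It assumes the fused features are grouped per head as (q, k, v), so the output can be viewed as [batch, seq, num_heads, 3, head_dim]. `split_heads` and the sizes are hypothetical stand-ins for `_split_heads`, not the repo's code:

```python
import torch
import torch.nn as nn

batch_size, seq_length, num_heads, head_dim = 2, 5, 4, 8  # hypothetical sizes
hidden_size = num_heads * head_dim
query_key_value = nn.Linear(hidden_size, 3 * hidden_size, bias=False)

def split_heads(fused_qkv: torch.Tensor):
    """Hypothetical stand-in for _split_heads (non-multi-query case).

    Assumes the fused features are grouped per head as (q, k, v).
    """
    b, s, _ = fused_qkv.shape
    fused_qkv = fused_qkv.view(b, s, num_heads, 3, head_dim)
    # Basic indexing returns views into the same storage, so nothing is copied.
    return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]

hidden_states = torch.randn(batch_size, seq_length, hidden_size)
fused_qkv = query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
query_layer, key_layer, value_layer = split_heads(fused_qkv)
print(query_layer.shape)  # torch.Size([2, 5, 4, 8])
```

Under this layout the three outputs are views of `fused_qkv`, so the split should add little overhead compared with the matmul itself.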

Would it be more efficient to split this fused layer into separate Q, K and V Linear layers at model initialization, to avoid the extra split on every forward pass?
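A fused QKV layer computes exactly what three separate layers would compute if their weights were slices of the fused weight. Splitting at initialization therefore changes neither the parameter count nor the outputs. The practical difference is one larger matmul versus three smaller ones. Below is a sketch of such a split, assuming the same per-head (q, k, v) grouping as in the non-multi-query case. `unfuse_qkv` is a hypothetical helper, and the multi-query layout would need different slicing:

```python
import torch
import torch.nn as nn

def unfuse_qkv(fused: nn.Linear, num_heads: int, head_dim: int):
    """Hypothetical helper: build separate Q, K, V layers from a fused QKV layer.

    Assumes the fused output is grouped per head as (q, k, v), i.e. its weight
    can be viewed as [num_heads, 3, head_dim, in_features] (non-multi-query case).
    """
    hidden = fused.in_features
    has_bias = fused.bias is not None
    w = fused.weight.detach().view(num_heads, 3, head_dim, hidden)
    b = fused.bias.detach().view(num_heads, 3, head_dim) if has_bias else None
    layers = []
    for i in range(3):  # 0 = query, 1 = key, 2 = value
        layer = nn.Linear(hidden, num_heads * head_dim, bias=has_bias)
        with torch.no_grad():
            layer.weight.copy_(w[:, i].reshape(num_heads * head_dim, hidden))
            if has_bias:
                layer.bias.copy_(b[:, i].reshape(num_heads * head_dim))
        layers.append(layer)
    return layers

torch.manual_seed(0)
num_heads, head_dim = 4, 8  # hypothetical sizes
hidden_size = num_heads * head_dim
fused = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
q_proj, k_proj, v_proj = unfuse_qkv(fused, num_heads, head_dim)

x = torch.randn(2, 5, hidden_size)
reference = fused(x).view(2, 5, num_heads, 3, head_dim)  # fused projection + per-head split
q = q_proj(x).view(2, 5, num_heads, head_dim)
print(torch.allclose(q, reference[..., 0, :], atol=1e-5))  # True
```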


From ChatGPT: 
Advantages of fusing:
- Parameter Sharing: Using one linear layer might encourage parameter sharing and could lead to some regularizing effects.
- Memory Efficiency: It's efficient in memory since it utilizes in-place operations, which can be important if you're training large models.

Disadvantages of fusing (advantages of splitting):
- Computational Overhead: Splitting QKV in every forward pass can introduce a minor computational overhead.
- Customization: If you want different initializations or separate regularization for Q, K, V, having them in one layer might be cumbersome.
- Memory Usage: By splitting QKV into separate layers, there's potential for memory savings during training. This is because the gradients for each layer can be discarded after they're used, rather than holding onto the combined QKV gradients.
- Flexibility: With separated layers, it becomes easier to modify or replace individual Q, K, or V layers in the future without touching the others.
- Readability and Debugging: It may be easier for other developers to understand and debug your code when the layers are separate.

For LoRA, separate layers would allow more granular control. In the falcontune repo, the LoRA adapter is applied to the entire query_key_value layer, and having one adapter for the whole layer might not be efficient (New Issue).
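With separate layers, adapters can target only some of the projections. The sketch below uses a hand-written `LoRALinear` wrapper. It is hypothetical, not falcontune's or peft's code. The wrapper adds a trainable low-rank update (alpha / r) * B A x to a frozen layer, with B initialised to zero so training starts from the base model. Here only Q and V get adapters:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Hypothetical minimal LoRA wrapper: base(x) + (alpha / r) * B(A(x)), base frozen."""

    def __init__(self, base: nn.Linear, r: int = 4, alpha: float = 8.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # freeze the pretrained projection
        # A is small random, B is zero, so the wrapper initially equals the base layer.
        self.lora_a = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, r))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scaling

hidden_size = 32  # hypothetical size
q_proj = LoRALinear(nn.Linear(hidden_size, hidden_size, bias=False))
k_proj = nn.Linear(hidden_size, hidden_size, bias=False).requires_grad_(False)  # frozen, no adapter
v_proj = LoRALinear(nn.Linear(hidden_size, hidden_size, bias=False))

trainable = sum(
    p.numel() for m in (q_proj, k_proj, v_proj) for p in m.parameters() if p.requires_grad
)
print(trainable)  # 512: two adapters of 4*32 + 32*4 parameters each
```

A single adapter on a fused query_key_value layer, by contrast, updates Q, K and V together.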

Conclusion: leave it as it is for now, since the fused layer might be more efficient (?), but keep splitting open as an option for the future.
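The efficiency question can be measured rather than guessed. Below is a rough forward-only timing sketch with arbitrary, hypothetical sizes. It claims no expected numbers, because the result depends on hardware, dtype and shapes. `chunk` is only a cheap stand-in for the real split (both return views), so the output layout is ignored here:

```python
import time
import torch
import torch.nn as nn

def time_per_call(fn, x, iters=20, warmup=3):
    """Average wall-clock seconds per call of fn(x)."""
    for _ in range(warmup):
        fn(x)
    if x.is_cuda:
        torch.cuda.synchronize()  # CUDA kernels run asynchronously
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

device = "cuda" if torch.cuda.is_available() else "cpu"
hidden_size = 512  # hypothetical size; real Falcon models are much larger
fused = nn.Linear(hidden_size, 3 * hidden_size, bias=False).to(device)
q_proj, k_proj, v_proj = (nn.Linear(hidden_size, hidden_size, bias=False).to(device) for _ in range(3))
x = torch.randn(2, 64, hidden_size, device=device)

with torch.no_grad():
    t_fused = time_per_call(lambda t: fused(t).chunk(3, dim=-1), x)
    t_split = time_per_call(lambda t: (q_proj(t), k_proj(t), v_proj(t)), x)
print(f"fused: {t_fused * 1e3:.3f} ms/call, separate: {t_split * 1e3:.3f} ms/call")
```

Running it at the model's real sizes and dtype, and also timing a backward pass, would give a more relevant answer.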

RotaryEmbeddings reshaping

The forward pass reshapes q, k and v to apply RotaryEmbeddings, and then reshapes them back for F.scaled_dot_product_attention. It might be simpler to handle this reshaping inside RotaryEmbeddings and avoid the unnecessary round trip.

# Reshape Q, K, V for compatibility with RotaryEmbeddings
batch_size, q_length, _, _ = query_layer.shape
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
key_layer = key_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)

# Apply Rotary embeddings
query_layer, key_layer = self.rotary(query_layer, key_layer)

# If there is a cache from previous decoding steps (layer_past), prepend the cached keys and values.
if layer_past is not None:
    past_key, past_value = layer_past
    key_layer = torch.cat((past_key, key_layer), dim=1)
    value_layer = torch.cat((past_value, value_layer), dim=1)

# Cache mechanism for Transformer decoding
present = (key_layer, value_layer) if use_cache else None

# Reshape for scaled dot-product attention
query_layer = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
key_layer = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
value_layer = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
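One way to drop the flatten-and-restore round trip is to keep everything in [batch, heads, seq, head_dim] layout. A rotary implementation whose cos/sin tables have shape [seq, head_dim] broadcasts over any leading dimensions, so no reshape is needed.

The sketch below rests on several assumptions: the rotate-half RoPE variant, the helper names, and starting the rotary positions at the cache length are all assumptions, not the repo's `RotaryEmbedding`. It also stores the cache as [batch, num_kv, seq, head_dim], so any code that reads `layer_past` would have to change as well:

```python
import torch

def rotate_half(x):
    # Split the last dim into halves (x1, x2) and return (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotary_cos_sin(seq_len, head_dim, offset=0, base=10000.0, device=None):
    # cos/sin tables of shape [seq_len, head_dim] for positions offset .. offset + seq_len - 1.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    positions = torch.arange(offset, offset + seq_len, device=device).float()
    freqs = torch.outer(positions, inv_freq)  # [seq_len, head_dim / 2]
    emb = torch.cat((freqs, freqs), dim=-1)   # [seq_len, head_dim]
    return emb.cos(), emb.sin()

def apply_rotary(q, k, cos, sin):
    # cos/sin broadcast over all leading dims, so any [..., seq, head_dim] layout works.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

def prepare_qkv_4d(query_layer, key_layer, value_layer, layer_past=None, use_cache=False):
    """Hypothetical rewrite of the fragment above that never merges batch and heads.

    In:  [batch, seq, heads, head_dim] tensors, as returned by _split_heads.
    Out: [batch, heads, seq, head_dim] tensors for F.scaled_dot_product_attention.
    The cache (layer_past / present) is kept as [batch, num_kv, seq, head_dim].
    """
    query_layer = query_layer.transpose(1, 2)
    key_layer = key_layer.transpose(1, 2)
    value_layer = value_layer.transpose(1, 2)

    # Assumption: new tokens are rotated at their absolute positions (after the cache).
    past_len = 0 if layer_past is None else layer_past[0].shape[2]
    _, _, q_len, head_dim = query_layer.shape
    cos, sin = rotary_cos_sin(q_len, head_dim, offset=past_len, device=query_layer.device)
    cos, sin = cos.to(query_layer.dtype), sin.to(query_layer.dtype)
    query_layer, key_layer = apply_rotary(query_layer, key_layer, cos, sin)

    if layer_past is not None:
        past_key, past_value = layer_past
        # The sequence axis is dim=2 here (it was dim=1 in the flattened layout).
        key_layer = torch.cat((past_key, key_layer), dim=2)
        value_layer = torch.cat((past_value, value_layer), dim=2)

    present = (key_layer, value_layer) if use_cache else None
    return query_layer, key_layer, value_layer, present

q = torch.randn(2, 3, 4, 8)   # [batch, seq, num_heads, head_dim]
kv = torch.randn(2, 3, 1, 8)  # multi-query: a single K/V head
q4, k4, v4, present = prepare_qkv_4d(q, kv, kv.clone(), use_cache=True)
print(q4.shape, present[0].shape)  # torch.Size([2, 4, 3, 8]) torch.Size([2, 1, 3, 8])
```

Given the same rotary, this produces the same values as the flattened path. Whether it is measurably faster is worth benchmarking rather than assuming.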