lucidrains / rotary-embedding-torch

Implementation of Rotary Embeddings, from the RoFormer paper, in PyTorch
MIT License

Slower than absolute positional embeddings? #36

Open umarbutler opened 1 month ago

umarbutler commented 1 month ago

Hi @lucidrains, thanks for creating this wonderful package as well as x-transformers. I wanted to understand why rotary embeddings seem to be slower for me than absolute positional embeddings. I'm working with a BERT-like model and have benchmarked absolute positional embeddings against rotary embeddings on a batch of 64 sequences of exactly 512 tokens each, and I found the absolute positional embeddings to be faster. Using a line profiler, I can see that most of the time (>50%) is spent on the line (offset + seq_len) <= self.cached_freqs_seq_len.item() in RotaryEmbedding.forward() (https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py#L290).

In terms of how I am using it, see the snippet below (with certain irrelevant code omitted):

# Copyright 2024 Umar Butler. All rights reserved.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
from typing import Optional, Tuple

import torch
from torch import nn
from rotary_embedding_torch import RotaryEmbedding


class SelfAttention(nn.Module):
    def __init__(self, config):
        ...
        self.rotary_emb = RotaryEmbedding(
            dim = self.attention_head_size // 2,
            freqs_for = 'lang',
        )
        ...

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
        rotator = self.rotary_emb.rotate_queries_or_keys

        # Project the hidden states into query, key and value vectors.
        query = self.transpose_for_scores(self.query(hidden_states))
        key = self.transpose_for_scores(self.key(hidden_states))
        value = self.transpose_for_scores(self.value(hidden_states))

        # Rotate the query and key vectors.
        query = rotator(query)
        key = rotator(key)

        # Use the appropriate accelerated attention implementation.
        attention, attention_probs = self.attend(
            query = query,
            key = key,
            value = value,
            attention_mask = attention_mask,
            head_mask = head_mask,
            dropout_prob = self.dropout_prob if self.training else 0.0,
            output_attentions = output_attentions,
        )

        attention = attention.view(attention.size()[:-2] + (self.all_head_size,))

        return (attention,) if not output_attentions else (attention, attention_probs)

Would you happen to have any idea as to what could be causing this? Is this expected behaviour? Could it be that 512 tokens is not enough to realise the benefits of rotary embeddings? I do intend to train with longer sequences than that, but I want to be sure rotary embeddings will be more performant before I do so.
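
For reference, a minimal way to time just the rotation call on its own looks something like the sketch below (the head count and head dimension here are illustrative, not necessarily the exact model config):

import time

import torch
from rotary_embedding_torch import RotaryEmbedding

# Illustrative shapes: a batch of 64 sequences of 512 tokens,
# 12 heads with a head dimension of 64.
batch, heads, seq_len, head_dim = 64, 12, 512, 64

rotary = RotaryEmbedding(dim = head_dim // 2, freqs_for = 'lang')
q = torch.randn(batch, heads, seq_len, head_dim)

# Warm up so the cached frequencies are built before timing.
for _ in range(3):
    rotary.rotate_queries_or_keys(q)

start = time.perf_counter()
for _ in range(100):
    rotary.rotate_queries_or_keys(q)
print(f'{(time.perf_counter() - start) / 100 * 1_000:.3f} ms per call')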

umarbutler commented 1 month ago

After having another look, it seems like most of the time is actually spent in the 'Apply rotary embeddings without modifying t in place' block, particularly in rotate_half().

umarbutler commented 1 month ago

I have profiled my code properly with Austin and it does seem like the majority of the time is spent in rotate_half(), particularly the line x = torch.stack((-x2, x1), dim = -1).
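
For anyone following along, rotate_half() does roughly the following (a simplified paraphrase rather than the exact library source), so that stack materialises a brand-new tensor for every query and key on every forward pass:

import torch
from einops import rearrange

def rotate_half(x):
    # Pair up adjacent feature dimensions: (..., d) -> (..., d / 2, 2).
    x = rearrange(x, '... (d r) -> ... d r', r = 2)
    x1, x2 = x.unbind(dim = -1)
    # This stack allocates a fresh tensor each call, which is where
    # the profiler shows the time going.
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d r -> ... (d r)')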

umarbutler commented 1 month ago

In the end, I managed to significantly reduce the time spent on rotary embeddings by caching the cos and sin values, cutting out code that was unnecessary for my specific use case (encoder-only), using torch.jit.script to speed up rotate_half(), and compiling with PyTorch Inductor. Even so, rotation still takes a bit more time than absolute positional embeddings, but when you think of it as enabling you to train on shorter sequences and then run inference on longer ones, it makes much more sense in terms of compute cost.
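
In rough terms, the caching looks something like the sketch below. It is simplified (it rotates the full head dimension and uses a half-split rather than interleaved layout, and the cache size is assumed up front), so it is not the exact implementation or a drop-in replacement for the library:

import torch
from torch import nn

@torch.jit.script
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the features in half and swap the halves with a sign flip.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim = -1)

class CachedRotary(nn.Module):
    def __init__(self, dim: int, max_seq_len: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).float()
        freqs = torch.einsum('i,j->ij', t, inv_freq)   # (seq, dim / 2)
        freqs = torch.cat((freqs, freqs), dim = -1)    # (seq, dim)
        # Precompute cos and sin once for the maximum sequence length.
        self.register_buffer('cos', freqs.cos(), persistent = False)
        self.register_buffer('sin', freqs.sin(), persistent = False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, heads, seq, dim); rotate with the precomputed tables.
        cos = self.cos[: x.shape[-2]]
        sin = self.sin[: x.shape[-2]]
        return x * cos + rotate_half(x) * sin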

VarunGumma commented 3 weeks ago

@umarbutler can you share your implementation?