vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: DynamicNTKScalingRotaryEmbedding implementation is different from Transformers #3488

Open killawhale2 opened 5 months ago

killawhale2 commented 5 months ago

Your current environment

The output of `python collect_env.py`

🐛 Describe the bug

The vLLM implementation of DynamicNTKScalingRotaryEmbedding differs from the Transformers implementation in a way that degrades performance for inputs that fit within the original context length.

Specifically, the Transformers implementation initializes the RoPE base with the original max_position_embeddings and re-computes the base only for inputs that exceed the original max_position_embeddings (hence the name "dynamic"), as detailed below.

# taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L158
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def forward(self, x, position_ids):
        # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: this may break with compilation

        cos, sin = super().forward(x, position_ids)
        return cos, sin
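
To isolate the base computation in the snippet above, here is a minimal stateless sketch in plain Python. The function name `dynamic_ntk_base` and the example values (base 10000, scaling factor 2, original length 4096, head dimension 128) are assumptions chosen for illustration; they are not part of Transformers. It only mirrors the branch on seq_len and leaves out the inv_freq buffer handling.

```python
def dynamic_ntk_base(base, scaling_factor, seq_len, max_position_embeddings, dim):
    """Return the RoPE base for a given sequence length (hypothetical helper).

    Mirrors the branch in the quoted forward(): the base is only rescaled
    when seq_len exceeds the original max_position_embeddings.
    """
    if seq_len <= max_position_embeddings:
        # Short inputs keep the original base, so RoPE is unchanged.
        return base
    scale = (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
    return base * scale ** (dim / (dim - 2))


# Illustrative values only (assumed, not taken from any specific model config).
short_base = dynamic_ntk_base(10000.0, 2.0, 2048, 4096, 128)
long_base = dynamic_ntk_base(10000.0, 2.0, 8192, 4096, 128)
print(short_base, long_base)
```

With these assumed values, the short input keeps the original base, while the 8192-token input gets a larger one.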

In contrast, the vLLM implementation computes the base once, using max_len = max_position_embeddings * scaling_factor, and does not re-compute it based on the length of the input, as detailed below. This means that for inputs shorter than the original context length, the outputs with and without RoPE scaling will differ, sometimes by large amounts.

# taken from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py#L219
    def _compute_cos_sin_cache(self) -> torch.Tensor:
        # NOTE(woosuk): self.max_position_embeddings is the original
        # maximum length before applying the rope scaling.
        # Thus, the maximum length after applying the rope scaling is
        # self.max_position_embeddings * self.scaling_factor.
        max_len = self.max_position_embeddings * self.scaling_factor
        base = self.base * (
            (self.scaling_factor * max_len / self.max_position_embeddings) -
            (self.scaling_factor - 1))**(self.rotary_dim /
                                         (self.rotary_dim - 2))
        inv_freq = self._compute_inv_freq(base)
        t = torch.arange(max_len, dtype=torch.float)

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache
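
To make the mismatch concrete, the following sketch compares the inverse frequencies from the unscaled base with those from the fixed base computed as in the vLLM snippet above. The helper names `inv_freq` and `vllm_fixed_base`, and the example values (base 10000, scaling factor 2, original length 4096, rotary dimension 128, position 1000), are assumptions for illustration and are not vLLM APIs.

```python
def inv_freq(base, dim):
    """Inverse frequencies base ** (-2i / dim) for i in [0, dim / 2)."""
    return [base ** (-(2 * i) / dim) for i in range(dim // 2)]


def vllm_fixed_base(base, scaling_factor, max_position_embeddings, rotary_dim):
    """Base as computed once in the quoted _compute_cos_sin_cache (hypothetical helper)."""
    max_len = max_position_embeddings * scaling_factor
    scale = (scaling_factor * max_len / max_position_embeddings) - (scaling_factor - 1)
    return base * scale ** (rotary_dim / (rotary_dim - 2))


# Illustrative values only (assumed, not taken from any specific model config).
base, scaling_factor, max_pos, dim = 10000.0, 2.0, 4096, 128
fixed = vllm_fixed_base(base, scaling_factor, max_pos, dim)

orig_freqs = inv_freq(base, dim)
fixed_freqs = inv_freq(fixed, dim)

# Rotation angles at a position well inside the original context length.
position = 1000
orig_angles = [position * f for f in orig_freqs]
fixed_angles = [position * f for f in fixed_freqs]

max_angle_gap = max(abs(a - b) for a, b in zip(orig_angles, fixed_angles))
lowest_ratio = fixed_freqs[-1] / orig_freqs[-1]
print(max_angle_gap, lowest_ratio)
```

Under these assumptions the highest frequency is unchanged, but every other frequency is lowered, with the lowest one scaled by about 1/3, even though position 1000 fits within the original 4096-token window. Transformers would apply the unscaled frequencies in that case.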

I was wondering whether this was an intentional design choice to avoid re-computing the cosine and sine cache. If not, I would love to see the implementation fixed to match the Transformers one, since enabling RoPE scaling only to find that performance degrades for inputs that fit into the original context length is surprising.

aporia3517 commented 5 months ago

+1 @WoosukKwon Can you take a look at this issue?

Junphy-Jan commented 4 months ago

Same problem. Using vLLM with NTK scaling gives worse results than Transformers.

NathanYanJing commented 2 months ago

+1 Is this because of a version error?

Missmiaom commented 1 month ago

+1

asimj1342 commented 1 month ago

+1

zxexz commented 1 month ago

+1