VITA-Group / Ms-PoE

"Found in the Middle: How Language Models Use Long Contexts Better via Plug-and-Play Positional Encoding" Zhenyu Zhang, Runjin Chen, Shiwei Liu, Zhewei Yao, Olatunji Ruwase, Beidi Chen, Xiaoxia Wu, Zhangyang Wang.
MIT License

The error in the implementation of the MsPoELlamaRotaryEmbedding #5

Open HaozheZhao opened 2 months ago

HaozheZhao commented 2 months ago

The following is your implementation of MsPoELlamaRotaryEmbedding.forward:

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:,:seq_len].to(dtype=x.dtype),
            self.sin_cached[:,:seq_len].to(dtype=x.dtype),
        )

However, given that x has shape [bs, num_attention_heads, seq_len, head_size], shouldn't the correct implementation be:

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

This would also be consistent with the original implementation of the rotary position embedding (RoPE) in LLaMA.
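
For illustration only, here is a minimal, self-contained sketch (not code from this repository) of how the two slicing variants behave on a hypothetical cache of shape [num_heads, 1, max_seq_len, head_dim]. Whether this layout matches Ms-PoE depends on what _set_cos_sin_cache actually stores in cos_cached / sin_cached:

    import torch

    # Hypothetical cache layout, for illustration only.
    num_heads, max_seq_len, head_dim = 32, 4096, 128
    seq_len = 1024
    cos_cached = torch.zeros(num_heads, 1, max_seq_len, head_dim)

    # Current forward(): slices dimension 1 (the singleton axis in this layout),
    # leaving the sequence dimension untruncated.
    a = cos_cached[:, :seq_len]
    print(a.shape)  # torch.Size([32, 1, 4096, 128])

    # Proposed variant: slices dimension 2, truncating the sequence axis.
    b = cos_cached[:, :, :seq_len, ...]
    print(b.shape)  # torch.Size([32, 1, 1024, 128])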