juncongmoo / pyllama

LLaMA: Open and Efficient Foundation Language Models
GNU General Public License v3.0

about rotary embedding in llama #83

Closed irasin closed 10 months ago

irasin commented 1 year ago

I have a question about the rotate_half function used for rotary embedding.

Why use

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

instead of using

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    return torch.stack([-x2, x1], dim=-1).reshape(x.shape)

To apply rotary embedding, we need to transform q from [q0, q1, q2, q3, ...] to [-q1, q0, -q3, q2, ...], right?
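For concreteness, here is a small numeric check (my own sketch, not code from this repo) showing what each variant actually produces on a toy vector:

import torch

q = torch.arange(6.)  # [q0, q1, q2, q3, q4, q5]

# Interleaved variant: pairs (q0, q1), (q2, q3), (q4, q5)
x1, x2 = q[0::2], q[1::2]
print(torch.stack([-x2, x1], dim=-1).reshape(q.shape))
# tensor([-1.,  0., -3.,  2., -5.,  4.])  i.e. [-q1, q0, -q3, q2, -q5, q4]

# Half-split variant quoted above: pairs (q0, q3), (q1, q4), (q2, q5)
h = q.shape[-1] // 2
print(torch.cat((-q[h:], q[:h]), dim=-1))
# tensor([-3., -4., -5.,  0.,  1.,  2.])  i.e. [-q3, -q4, -q5, q0, q1, q2]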

lwang2070 commented 10 months ago

I had the same problem until I found this issue on huggingface. Hope it helps :)

irasin commented 10 months ago

Thanks, @lwang2070, I've figured out how the calculation works.
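For anyone else landing here, a minimal sketch of the equivalence (my own illustration, not code from this repo; the frequency layouts and the permutation helper below are assumptions based on how, as I understand it, the LLaMA weights are reordered when converted for the half-split convention). The two rotate_half variants pair the hidden dims differently, so each needs a matching cos/sin layout; when q and k are permuted consistently, the attention scores come out identical:

import torch

torch.manual_seed(0)
dim, seq = 8, 4
q = torch.randn(seq, dim)
k = torch.randn(seq, dim)

# Rotation frequencies, one per 2-D pair.
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))  # (dim/2,)
freqs = torch.outer(torch.arange(seq).float(), inv_freq)             # (seq, dim/2)

def rope_interleaved(x):
    # Original RoPE: pairs (2i, 2i+1), cos/sin interleaved as [f0, f0, f1, f1, ...].
    cos = freqs.repeat_interleave(2, dim=-1).cos()
    sin = freqs.repeat_interleave(2, dim=-1).sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    rot = torch.stack([-x2, x1], dim=-1).reshape(x.shape)
    return x * cos + rot * sin

def rope_half_split(x):
    # LLaMA/HF-style: pairs (i, i + dim/2), cos/sin concatenated as [f0, f1, ..., f0, f1, ...].
    cos = torch.cat([freqs, freqs], dim=-1).cos()
    sin = torch.cat([freqs, freqs], dim=-1).sin()
    h = x.shape[-1] // 2
    rot = torch.cat((-x[..., h:], x[..., :h]), dim=-1)
    return x * cos + rot * sin

def to_half_split_layout(x):
    # Permutation that sends interleaved pair (2i, 2i+1) to positions (i, i + dim/2).
    # In the converted checkpoints this reordering is folded into the Wq/Wk weights.
    return torch.cat([x[..., 0::2], x[..., 1::2]], dim=-1)

# Both q and k get the same permutation, so the dot products (attention scores) are unchanged.
s_interleaved = rope_interleaved(q) @ rope_interleaved(k).T
s_half_split = rope_half_split(to_half_split_layout(q)) @ rope_half_split(to_half_split_layout(k)).T
print(torch.allclose(s_interleaved, s_half_split, atol=1e-5))  # True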