Closed: irasin closed this issue 10 months ago.
I have a question about the `rotate_half` function used for rotary embedding.
Why use

```python
import torch

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
```
instead of using
```python
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    return torch.stack([-x2, x1], dim=-1).reshape(x.shape)
```
To apply rotary embedding, we need to transform q from [q0, q1, q2, q3, ...] to [-q1, q0, -q3, q2, ...], right?
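A quick way to see the difference is to feed both versions a tiny vector. In this sketch the second version is renamed `rotate_interleaved` (a hypothetical name, only so both can coexist) and q is assumed to be `[q0, q1, q2, q3] = [1, 2, 3, 4]`.

```python
import torch

def rotate_half(x):
    # Half-split version: pairs dim i with dim i + d/2
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def rotate_interleaved(x):
    # Interleaved version (second snippet, renamed): pairs dim 2i with dim 2i + 1
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    return torch.stack([-x2, x1], dim=-1).reshape(x.shape)

# q = [q0, q1, q2, q3] = [1, 2, 3, 4]
q = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(rotate_half(q).tolist())         # [-3.0, -4.0, 1.0, 2.0], i.e. [-q2, -q3, q0, q1]
print(rotate_interleaved(q).tolist())  # [-2.0, 1.0, -4.0, 3.0], i.e. [-q1, q0, -q3, q2]
```

So the expected `[-q1, q0, -q3, q2, ...]` pattern comes only from the interleaved version; the half-split version rotates the pairs `(q_i, q_{i+d/2})` instead.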
I had the same problem until I found this issue on huggingface. Hope it helps :)
Thanks, @lwang2070, I've figured out the calculation mechanism.
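For anyone landing here later, a minimal sketch of the mechanism as I understand it (not taken from the linked discussion, which isn't reproduced here). Neither function is rotary embedding on its own: RoPE is `x * cos + rotate(x) * sin`, and the layout of `cos`/`sin` must match how the rotation pairs dimensions. The half-split version needs each frequency laid out as `[f0..f_{d/2-1}, f0..f_{d/2-1}]`; the interleaved one needs `[f0, f0, f1, f1, ...]`. The two are then the same rotation up to a fixed permutation of the head dimensions, and permuting q and k the same way leaves `q·k` unchanged. The helpers `rope_freqs`, `rope_half`, `rope_interleaved`, the permutation `perm`, and `base=10000.0` are assumptions chosen for illustration; `q` and `k` stand for a single attention head of shape `(seq_len, dim)`.

```python
import torch

def rotate_half(x):
    # Half-split pairing: (i, i + d/2)
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def rotate_interleaved(x):
    # Interleaved pairing: (2i, 2i + 1)
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    return torch.stack([-x2, x1], dim=-1).reshape(x.shape)

def rope_freqs(seq_len, dim, base=10000.0):
    # Hypothetical helper: angle for (position m, pair i) = m * base^(-2i/dim)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
    pos = torch.arange(seq_len, dtype=torch.float64)
    return torch.outer(pos, inv_freq)  # shape (seq_len, dim // 2)

def rope_half(x, freqs):
    # Frequencies concatenated with themselves to match rotate_half's pairing
    emb = torch.cat((freqs, freqs), dim=-1)
    return x * emb.cos() + rotate_half(x) * emb.sin()

def rope_interleaved(x, freqs):
    # Each frequency repeated twice to match rotate_interleaved's pairing
    emb = torch.repeat_interleave(freqs, 2, dim=-1)
    return x * emb.cos() + rotate_interleaved(x) * emb.sin()

torch.manual_seed(0)
seq_len, dim = 6, 8
q = torch.randn(seq_len, dim, dtype=torch.float64)
k = torch.randn(seq_len, dim, dtype=torch.float64)
freqs = rope_freqs(seq_len, dim)

# perm = [0, 2, 4, ..., 1, 3, 5, ...]: moves interleaved pairs (2i, 2i+1) to (i, i+d/2)
perm = torch.cat((torch.arange(0, dim, 2), torch.arange(1, dim, 2)))

scores_interleaved = rope_interleaved(q, freqs) @ rope_interleaved(k, freqs).T
scores_half = rope_half(q[:, perm], freqs) @ rope_half(k[:, perm], freqs).T
print(torch.allclose(scores_interleaved, scores_half))  # True
```

The practical consequence is that the rotation function and the `cos`/`sin` layout have to change together: swapping in the interleaved `rotate_half` while keeping the concatenated `cos`/`sin` would pair dimensions with the wrong frequencies. My understanding (not verified here) is that this permutation is also why weights trained in one layout can be used with the other by permuting the output rows of the query/key projections per head.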