lucidrains / PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM
MIT License
7.67k stars 668 forks

Possible incorrect creation of Rotary Embeddings #56

Closed AndyBarcia closed 10 months ago

AndyBarcia commented 10 months ago

Disclaimer: I don't know how this codebase works. I was just trying to implement Rotary Embeddings on my own for a personal project, and I was using the class defined in palm.py as a starting point.

The thing is, I'm not sure whether the current implementation of Rotary Embeddings is correct. Specifically, I don't think the following lines are correct:

x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)

For Rotary Embeddings we want to swap each pair of adjacent elements and negate the first element of each swapped pair (i.e., turn [1,2,3,4,5,6] into [-2,1,-4,3,-6,5]). The code above instead swaps the two halves of the tensor and negates the half that ends up first (i.e., it turns [1,2,3,4,5,6] into [-4,-5,-6,1,2,3]).
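To make the two behaviours concrete, here is a minimal sketch in plain Python lists rather than tensors. The name `rotate_half` is assumed for the snippet quoted above, and `rotate_interleaved` is a hypothetical name for the adjacent-pair variant; neither is meant as the repository's exact code.

```python
def rotate_half(x):
    """Half-split convention: [x1 | x2] -> [-x2 | x1].

    Mirrors x.chunk(2, dim=-1) followed by torch.cat((-x2, x1), dim=-1),
    but on a plain list (an assumption made to keep this dependency-free).
    """
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    return [-v for v in x2] + x1


def rotate_interleaved(x):
    """Adjacent-pair convention: each consecutive pair (a, b) -> (-b, a)."""
    out = []
    for a, b in zip(x[0::2], x[1::2]):
        out += [-b, a]
    return out


x = [1, 2, 3, 4, 5, 6]
print(rotate_half(x))         # [-4, -5, -6, 1, 2, 3]
print(rotate_interleaved(x))  # [-2, 1, -4, 3, -6, 5]
```

The two functions pair up different dimensions: the first pairs index i with index i + d/2, the second pairs index 2i with index 2i + 1.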

Is the code incorrect or is there something I'm missing?

lucidrains commented 10 months ago

it is correct because the rotary embedding was concatenated the same way: https://github.com/lucidrains/PaLM-rlhf-pytorch/blob/main/palm_rlhf_pytorch/palm.py#L83
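A rough way to check this argument numerically is sketched below. It assumes the angle table is built by concatenating the per-frequency angles with themselves (the list equivalent of `torch.cat((freqs, freqs), dim=-1)`), and that the embedding is applied as `x * cos + rotate_half(x) * sin`. The helper names, the base of 10000, and the test vectors are illustrative assumptions, not taken from the repository. Under those assumptions, dimension i is paired with dimension i + d/2 and both share one angle, so each pair undergoes a proper 2D rotation and the query-key dot product depends only on the relative position.

```python
import math


def rotate_half(x):
    # Half-split convention: [x1 | x2] -> [-x2 | x1].
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def rotary_angles(pos, dim, base=10000.0):
    # One angle per frequency, then concatenated with itself so that
    # index i and index i + dim/2 share the same angle (assumed layout).
    inv_freq = [base ** (-(2 * i) / dim) for i in range(dim // 2)]
    freqs = [pos * f for f in inv_freq]
    return freqs + freqs


def apply_rotary(x, pos):
    # x * cos(angle) + rotate_half(x) * sin(angle), element-wise.
    angles = rotary_angles(pos, len(x))
    rotated = rotate_half(x)
    return [xi * math.cos(a) + ri * math.sin(a)
            for xi, ri, a in zip(x, rotated, angles)]


def dot(a, b):
    return sum(u * v for u, v in zip(a, b))


# Arbitrary illustrative vectors.
q = [0.3, -1.2, 0.7, 2.0, -0.5, 1.1]
k = [1.5, 0.4, -0.9, 0.2, 0.8, -1.3]

# Same relative offset (3) at two different absolute positions.
score_a = dot(apply_rotary(q, 5), apply_rotary(k, 2))
score_b = dot(apply_rotary(q, 13), apply_rotary(k, 10))
print(abs(score_a - score_b) < 1e-9)  # True
```

The adjacent-pair convention would need an interleaved angle table instead (each angle repeated twice in a row); mixing one convention's rotation with the other's angle layout is what would actually break the embedding.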