xiangking / ark-nlp

A private nlp coding package, which quickly implements the SOTA solutions.
Apache License 2.0
311 stars 64 forks source link

RoPE实现细节 #47

Closed Forest-Scorpio closed 2 years ago

Forest-Scorpio commented 2 years ago
# RoPE编码
if self.RoPE:
    pos = SinusoidalPositionEmbedding(self.head_size, 'zero')(inputs)
    # cos_pos = pos[..., 1::2].repeat(1, 1, 2)
    # sin_pos = pos[..., ::2].repeat(1, 1, 2)
    cos_pos = pos[..., 1::2].repeat_interleave(2, dim=-1)  # 修改后
    sin_pos = pos[..., ::2].repeat_interleave(2, dim=-1)  # 修改后

大佬你好,你的RoPE在实现上是不是有点问题,按照苏神的博客应该是上面修改后的代码吧

jimme0421 commented 2 years ago
    if self.RoPE:
        pos = SinusoidalPositionEmbedding(self.head_size, 'zero')(inputs)
        cos_pos = pos[..., 1::2].repeat(1, 1, 2)
        sin_pos = pos[..., ::2].repeat(1, 1, 2)
        qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3)

你可以看下qw2的实现,负项和正项是分开的。和博客中的公式(13)只是顺序不一样,但整体的结果是一样。

Forest-Scorpio commented 2 years ago
# RoPE编码
if self.RoPE:
    pos = SinusoidalPositionEmbedding(self.head_size, 'zero')(inputs)
    cos_pos = pos[..., 1::2].repeat(1, 1, 2)
    sin_pos = pos[..., ::2].repeat(1, 1, 2)
    qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3)
    qw2 = torch.reshape(qw2, qw.shape)
    qw = qw * cos_pos + qw2 * sin_pos

reshape之后不是变成了一负一正交替吗,如果把负项和正项分开的话,公式(13)里面左边的qw是不是也要把奇项和偶项分开才能保证各个位置对齐,最后内积的整体结果不变。

jimme0421 commented 2 years ago

经过初步测试,确实存在你说的问题。

感谢你指出的问题,我们会在统一测试后进行修改。

jimme0421 commented 2 years ago

我们会在下个版本会修复这个bug,并在commit再次表示感谢