JunnYu / FLASHQuad_pytorch

FLASHQuad_pytorch
MIT License
65 stars 10 forks source link

关于rope的实现 #3

Closed Doraemonzzz closed 2 years ago

Doraemonzzz commented 2 years ago

非常感谢你的实现,有个问题想咨询下, 在gau.py文件中,rope的实现如下:

def rope(x, dim):
    """RoPE position embedding."""
    shape = x.shape
    if isinstance(dim, int):
        dim = [dim]
    spatial_shape = [shape[i] for i in dim]
    total_len = 1
    for i in spatial_shape:
        total_len *= i
    position = torch.reshape(
        torch.arange(total_len, dtype=x.dtype,
                     device=x.device), spatial_shape
    )
    for i in range(dim[-1] + 1, len(shape) - 1, 1):
        position = position.unsqueeze(-1)
    half_size = shape[-1] // 2
    freq_seq = -torch.arange(half_size, dtype=x.dtype, device=x.device) / float(
        half_size
    )
    inv_freq = 10000 ** freq_seq
    sinusoid = torch.einsum("...,d->...d", position, inv_freq)
    sin = sinusoid.sin()
    cos = sinusoid.cos()
    x1, x2 = torch.chunk(x, 2, dim=-1)

    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

而在https://github.com/JunnYu/RoFormer_pytorch/blob/roformer_v2/src/roformer/modeling_roformer.py rope的实现如下:

    def apply_rotary(x, sinusoidal_pos):
        sin, cos = sinusoidal_pos
        x1, x2 = x[..., 0::2], x[..., 1::2]
        # 如果是旋转query key的话,下面这个直接cat就行,因为要进行矩阵乘法,最终会在这个维度求和。(只要保持query和key的最后一个dim的每一个位置对应上就可以)
        # torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
        # 如果是旋转value的话,下面这个stack后再flatten才可以,因为训练好的模型最后一个dim是两两之间交替的。
        return torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).flatten(-2, -1)

主要的区别是x1, x2的定义,前者是按照前一半和后一半划分,后者是按奇数项和偶数项划分,前者实现的并不是rope,我测试过两种实现,效果出入较大,不知道是不是我理解有误。

JunnYu commented 2 years ago
Doraemonzzz commented 2 years ago

测试的任务是roberta预训练,使用了两种编码方式后同期loss差1左右;感觉应该没有太大区别,但是测试下来确实不太一样。

JunnYu commented 2 years ago

那我就不清楚了,可以去问问苏神,这个问题

Doraemonzzz commented 2 years ago

ok,谢谢解答