microsoft / torchscale

Foundation Architecture for (M)LLMs
https://aka.ms/GeneralAI

RetNet: relative position #49

Closed fkodom closed 1 year ago

fkodom commented 1 year ago

I believe there is a difference between the relative position embedding implemented here and what is described in the paper. The issue I see is in theta_shift and rotate_every_two:

def rotate_every_two(x):
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')

# ...

def theta_shift(x, sin, cos):
    return (x * cos) + (rotate_every_two(x) * sin)

You can see here that theta_shift is applied to q and k, which have input shape (bsz, self.num_heads, tgt_len, self.key_dim) (after transpose).

class MultiScaleRetention(nn.Module):
    # ...
    def forward(
        self,
        x,
        rel_pos,
        chunkwise_recurrent=False,
        incremental_state=None
    ):
        bsz, tgt_len, _ = x.size()
        (sin, cos), inner_mask = rel_pos

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        g = self.g_proj(x)

        k *= self.scaling
        q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
        k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)

        qr = theta_shift(q, sin, cos)
        kr = theta_shift(k, sin, cos)

Why does rotate_every_two shuffle elements along the key_dim axis? This is not what is described in the paper (Equations 3 and 4).

[Screenshot: Equations 3 and 4 from the RetNet paper]

Relative position embedding should depend only on the sequence position (m, n) and theta parameters. For that reason, I wonder if rotate_every_two is a bug?

fkodom commented 1 year ago

Update: I see now that rotate_every_two is effectively multiplying by i (if we view an embedding vector of length d as a complex-valued vector of length d // 2, where odd-numbered indices correspond to the imaginary components).
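
A quick sanity check of that reading (not part of torchscale; it just pairs even-indexed elements as real parts with odd-indexed elements as imaginary parts):

import torch

def rotate_every_two(x):
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)

# View (even, odd) element pairs as (real, imaginary) components.
x = torch.randn(2, 4, 8, 16)
x_c = torch.complex(x[..., ::2], x[..., 1::2])

r = rotate_every_two(x)
r_c = torch.complex(r[..., ::2], r[..., 1::2])

# rotate_every_two maps (a, b) -> (-b, a), i.e. a + ib -> -b + ia = i * (a + ib)
assert torch.allclose(r_c.real, (1j * x_c).real)
assert torch.allclose(r_c.imag, (1j * x_c).imag)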

Still, this does not seem equivalent to what was described in the paper. The same theta_shift operation is applied to both q and k, whereas the paper only performs conjugation on k.

Then, I suppose that theta_shift is applying Euler's formula:

# Euler's formula
e ** (i * x) = cos(x) + i * sin(x)

We can view theta_shift as multiplying complex-valued q with the complex exponential e ** (i * theta)

def theta_shift(x, sin, cos):
    return (x * cos) + (rotate_every_two(x) * sin)

which effectively increments n -> n + 1 in the exponential e ** (i * n * x).
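
A small numeric check that theta_shift multiplies the complex view by e ** (i * theta) (not from the repo; it assumes each (even, odd) pair shares a single angle, with sin and cos expanded to match, and uses a compact restatement of the functions above):

import torch

def rotate_every_two(x):
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def theta_shift(x, sin, cos):
    return (x * cos) + (rotate_every_two(x) * sin)

bsz, heads, tgt_len, key_dim = 2, 4, 8, 16
x = torch.randn(bsz, heads, tgt_len, key_dim)
theta = torch.rand(key_dim // 2)                # one angle per pair (assumed layout)
sin = torch.sin(theta).repeat_interleave(2)     # [s0, s0, s1, s1, ...]
cos = torch.cos(theta).repeat_interleave(2)

shifted = theta_shift(x, sin, cos)
x_c = torch.complex(x[..., ::2], x[..., 1::2])
expected = x_c * torch.exp(1j * theta)          # multiply by e^{i * theta}
shifted_c = torch.complex(shifted[..., ::2], shifted[..., 1::2])
assert torch.allclose(shifted_c.real, expected.real, atol=1e-5)
assert torch.allclose(shifted_c.imag, expected.imag, atol=1e-5)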

To my current understanding, this is correct when applied to q, and almost correct when applied to the conjugate of k. We should take the complex conjugate of k after applying theta_shift.

def complex_conjugate(x):
    # Very similar to `rotate_every_two` from earlier
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((x1, -x2), dim=-1)
    return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')

class MultiScaleRetention(nn.Module):
    # ...
    def forward(
        self,
        x,
        rel_pos,
        chunkwise_recurrent=False,
        incremental_state=None
    ):
        # ...
        qr = theta_shift(q, sin, cos)
        kr = complex_conjugate(theta_shift(k, sin, cos))

Finally, it seems that retention does not view q or k as complex-valued vectors -- just regular, real-valued embeddings. That explains why methods like MultiScaleRetention.parallel_forward don't account for complex values. (TBH, I'm still a little unclear on why that is, but at least it makes the code match the math.)
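
One way to see why the real-valued code still matches the complex math (a hypothetical check, not from the repo): the ordinary real dot product of two vectors equals the real part of the complex inner product of their length-d//2 complex views, so the conjugate of k shows up automatically.

import torch

def rotate_every_two(x):
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def theta_shift(x, sin, cos):
    return (x * cos) + (rotate_every_two(x) * sin)

d, n, m = 16, 7, 3
theta = torch.rand(d // 2)                      # one angle per pair (assumed layout)
q, k = torch.randn(d), torch.randn(d)

def shift(x, pos):
    ang = (pos * theta).repeat_interleave(2)
    return theta_shift(x, torch.sin(ang), torch.cos(ang))

score_real = shift(q, n) @ shift(k, m)          # what the real-valued code computes

q_c = torch.complex(q[::2], q[1::2])
k_c = torch.complex(k[::2], k[1::2])
# Re sum (q e^{in*theta}) * conj(k e^{im*theta}) = Re sum q conj(k) e^{i(n-m)*theta}
score_cplx = (q_c * torch.exp(1j * n * theta) * (k_c * torch.exp(1j * m * theta)).conj()).sum().real

assert torch.allclose(score_real, score_cplx, atol=1e-5)

So applying theta_shift to both q and k and taking the real dot product already yields the real part of the paper's complex-valued score, which would explain why no explicit conjugation is needed.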

sunyt32 commented 1 year ago

$\mathbb{R}^d$ and $\mathbb{C}^{d/2}$ are isomorphic; we use the real-valued view for simplicity. complex_conjugate is not necessary, since $(k e^{-im\theta})^* = k^* e^{im\theta}$.

fkodom commented 1 year ago

Just realized that I can put LaTeX into these comments. Definitely would have made my original question cleaner.

I see that $\mathbb{C}^{d/2}$ and $\mathbb{R}^d$ are easily interchangeable. Sounds like isomorphism is the technically correct term. 😅

> complex_conjugate is not necessary, since $(k e^{-im\theta})^* = k^* e^{im\theta}$

In that case, shouldn't the expression be $\left(K_m e^{-im\theta}\right)^{\dagger} = K_m^* e^{im\theta}$ ? Just doesn't feel correct to treat $K_m$ as both a real- and complex-valued Tensor within the same expression.

But I guess it's not important. Whether the frequency is positive/negative, it's still a waveform with the same magnitude of frequency. And that probably doesn't affect anything noticeably.

@sunyt32 Thanks for the response! 🙏

bin123apple commented 11 months ago

@fkodom I think this may be due to $(e^{im\theta})^T = e^{-im\theta}$. Starting from a simple $2 \times 2$ example:

$$(e^{im\theta})^T = \begin{pmatrix} a & -b \\ b & a \end{pmatrix}^T = \begin{pmatrix} a & b \\ -b & a \end{pmatrix} = e^{-im\theta}$$

For the $d \times d$ case, based on relative position embedding papers such as RoFormer, this conclusion should also hold. So:

$$o(n) = \sum Q_n(\gamma e^{i\theta})^{n-m}{K_m}^T v_m = \sum (\gamma^n Q_n e^{in\theta})(\gamma^{-m} e^{-im\theta} {K_m}^T) v_m$$

$$= \sum (\gamma^n Q_n e^{in\theta})(\gamma^{-m}(e^{im\theta})^T {K_m}^T) v_m = \sum (\gamma^n Q_n e^{in\theta})(\gamma^{-m}(K_m e^{im\theta})^T) v_m$$

I think this conclusion is correct and corresponds to the code, but it is obviously not the same as Eq. (3) in the paper. Then I checked Eq. (3) again, and I think the final form of Eq. (3) should perhaps be $\sum (\gamma^n Q_n e^{in\theta})(\gamma^{-m}(K_m)^T e^{-im\theta}) v_m$ instead of $\sum (\gamma^n Q_n e^{in\theta})(\gamma^{-m}(K_m e^{-im\theta}))^T v_m$, because $e^{-im\theta}(K_m)^T$ obviously satisfies the commutative law? It would be great if the author @sunyt32 could explain whether my understanding is correct or point out where I'm wrong. And by the way, this work is outstanding!
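
A small numeric illustration of that identity (hypothetical, not from the repo): with $2 \times 2$ rotation blocks, rotating both $Q$ and $K$ leaves only the relative factor.

import torch

def R(t):
    # 2x2 rotation matrix, the real-valued counterpart of e^{i t}
    c, s = torch.cos(t), torch.sin(t)
    return torch.stack((torch.stack((c, -s)), torch.stack((s, c))))

theta = torch.tensor(0.3)
n, m = 7, 3
Q = torch.randn(1, 2)
K = torch.randn(1, 2)

# (Q R_n)(K R_m)^T = Q R_n R_m^T K^T = Q R_{n-m} K^T, since R(t)^T = R(-t)
lhs = (Q @ R(n * theta)) @ (K @ R(m * theta)).T
rhs = Q @ R((n - m) * theta) @ K.T
assert torch.allclose(lhs, rhs, atol=1e-6)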

sunyt32 commented 11 months ago

@bin123apple You are right under the $2\times 2$ real-number view, which is the same as the implementation in RoFormer. Besides, for the complex view, there is also an implementation in LLaMA, where RoPE is applied by casting q and k to complex tensors.

In a nutshell, if you treat $Q, K$ as real matrices, you can follow RoFormer. If you treat them as complex matrices, you can follow LLaMA.
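
For reference, a tiny sketch (not taken from either codebase) of the same rotation done both ways, assuming consecutive (even, odd) elements form one complex pair:

import torch

x = torch.randn(8)                      # one head dimension, even length
theta = torch.rand(4)                   # one angle per (even, odd) pair
pos = 5

# Real view (RoFormer-style): rotate each 2D pair with cos/sin
cos, sin = torch.cos(pos * theta), torch.sin(pos * theta)
x1, x2 = x[::2], x[1::2]
rot_real = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1).flatten()

# Complex view (LLaMA-style): cast to complex and multiply by e^{i * pos * theta}
x_c = torch.view_as_complex(x.reshape(-1, 2))
rot_cplx = torch.view_as_real(x_c * torch.exp(1j * pos * theta)).flatten()

assert torch.allclose(rot_real, rot_cplx, atol=1e-5)

Both produce the same rotated vector, which is the point: the two views are interchangeable.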