corl-team / rebased

Official implementation of the paper "Linear Transformers with Learnable Kernel Functions are Better In-Context Models"
Apache License 2.0

Lack of clarity about sim function vs feature map for paper/code #1

Open deklanw opened 8 months ago

deklanw commented 8 months ago

Hi, I read your paper and found the following confusing. When you describe the ablations that culminate in ReBased, the list starts with

x^2 – substituting the original kernel function with a simple element-wise squaring operation, ϕ(x) = x^2.

but this doesn't seem to be what happens in your code. See these lines:

https://github.com/corl-team/rebased/blob/7a085b40e0a2c615f35ee43afd19f856df37819e/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py#L68-L69

My understanding of Linear Attention is the following. We need two functions: a similarity function (called sim or s), which takes two vectors and returns a scalar, and a feature map (typically called phi), which takes a single vector and returns another vector (possibly of a different dimension). Ignoring the normalization by 1/sqrt(d) for simplicity, Linear Attention requires that

s(q, k) = dot(phi(q), phi(k))

Those lines of code I linked to correspond to defining

s(q, k) = dot(q, k)**2

The feature map phi which corresponds to this similarity function is not elementwise squaring. I.e., phi(x) = x**2 is not the corresponding feature map for that similarity function. The correct corresponding feature map is phi(x) = flatten(outer(x, x))
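For concreteness, here is a quick sketch of this point (toy shapes, ignoring the 1/sqrt(d) normalization and causal masking; this is just an illustration, not the repo's code): with phi(x) = flatten(outer(x, x)), the factorized linear-attention path reproduces the parallel path built from dot(q, k)**2.

import numpy as np

rng = np.random.default_rng(0)
n, d = 5, 4                                  # sequence length, head dimension
Q, K, V = rng.normal(size=(3, n, d))

phi = lambda x: np.outer(x, x).flatten()     # maps a d-vector to a d**2-vector
phi_Q = np.stack([phi(row) for row in Q])    # (n, d**2)
phi_K = np.stack([phi(row) for row in K])    # (n, d**2)

parallel = (Q @ K.T) ** 2 @ V                # s(q, k) = dot(q, k)**2, O(n^2 * d)
factorized = phi_Q @ (phi_K.T @ V)           # via phi, O(n * d^2)
assert np.allclose(parallel, factorized)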

One could say similar things about the other variants in the ablations, including ReBased.

Am I missing something?

elephantmipt commented 8 months ago

Hi, thank you for your interest, and sorry for the confusion: phi(x) = x^2 is a simplified notation. I agree that in the case of sim(q, k) = (q^T k)^2 we have phi(q) != q^2. However, we can still factorise the similarity.

Check out the tests against the reference implementation:

https://github.com/corl-team/rebased/blob/fc11fa14b28b1a0948a03d69675ac0163b6d75d2/flash_linear_attention/fla/layers/rebased_fast.py#L66-L70

https://github.com/corl-team/rebased/blob/fc11fa14b28b1a0948a03d69675ac0163b6d75d2/flash_linear_attention/fla/layers/rebased_fast.py#L161

deklanw commented 8 months ago

Still confused.

These lines you link to

https://github.com/corl-team/rebased/blob/fc11fa14b28b1a0948a03d69675ac0163b6d75d2/flash_linear_attention/fla/layers/rebased_fast.py#L66-L70

correspond to computing the outer product of a vector with itself and then flattening it.

To clarify, it seems like you agree that the similarity function is sim(q, k) = dot(q, k)**2, but do you agree with this:

The feature map phi which corresponds to this similarity function is not elementwise squaring. I.e., phi(x) = x**2 is not the corresponding feature map for that similarity function. The correct corresponding feature map is phi(x) = flatten(outer(x, x))

Agree?

kefirski commented 8 months ago

Yes, you are correct. I guess the source of the confusion is the mix-up between the parallel and linear computation modes.

As stated above, sim(q, k) = (dot(q, k))^2. Now, suppose we want to evaluate a model with this similarity in parallel mode (evaluate the similarity between Q and K first, then multiply the similarity scores with V). In that case, you can first evaluate matmul(Q, K^T) and then square the scores; the resulting matrix matches sim applied to each pair of q and k. This is what happens in the parallel kernel:

https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py#L69

If we want to factorize the multiplication of Q and K, things become slightly more involved: we need some phi(x), applied to the rows of Q and K, such that matmul(phi(Q), phi(K)^T) equals the squared scores from the parallel mode. With such a factorization we can multiply phi(K)^T with V first and then multiply the result by phi(Q).

Finding this factorization is easy, since (q^T k)^2 = (q_1 * k_1 + q_2 * k_2 + ... + q_n * k_n)^2. Expanding the square gives the sum over all pairs (i, j) of q_i * q_j * k_i * k_j, which is exactly dot(flatten(outer(q, q)), flatten(outer(k, k))); from this we get the evaluation order used in the reference implementation.

https://github.com/corl-team/rebased/blob/fc11fa14b28b1a0948a03d69675ac0163b6d75d2/flash_linear_attention/fla/layers/rebased_fast.py#L66

Here is a simple snippet that could help you.

import torch

def x_2(x: torch.Tensor):
    # 2nd-order terms: outer product of x with itself over the last dim,
    # flattened, i.e. phi(x) = flatten(outer(x, x))
    x2 = (x.unsqueeze(-1) * x.unsqueeze(-2)
          ).flatten(start_dim=-2)
    return x2  # simple case without the normalization of attention scores

if __name__ == "__main__":
    torch.manual_seed(5)

    q = torch.randn(2, 3, 6)
    k = torch.randn(2, 3, 6)
    v = torch.randn(2, 3, 6)

    # reshape to (batch, heads=1, seq, dim)
    q = q.view(2, 3, 1, -1).transpose(1, 2)
    k = k.view(2, 3, 1, -1).transpose(1, 2)
    v = v.view(2, 3, 1, -1).transpose(1, 2)

    # parallel mode: square the scores, then multiply with V
    qk = torch.einsum("bhqd,bhkd->bhqk", q, k)
    parallel_res = torch.einsum("bhqk,bhkd->bhqd", qk ** 2, v)
    q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)

    # factorized (linear) mode: contract phi(k) with v first, then with phi(q)
    linear_factorized_res = ((x_2(q) * (x_2(k) * v).sum(2, True)).sum(-1))
    print(torch.max(parallel_res - linear_factorized_res))  # tensor(1.1444e-05)

I have omitted the normalization of attention scores and the handling of causal masking, but the general idea should be clear.
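If it helps, here is one way to sketch the causal version of the factorized path, using a cumulative sum over the key positions (again just an illustration, not the Triton kernel from the repo):

import torch

def x_2(x: torch.Tensor):
    # phi(x) = flatten(outer(x, x)) over the last dimension
    return (x.unsqueeze(-1) * x.unsqueeze(-2)).flatten(start_dim=-2)

torch.manual_seed(5)
q = torch.randn(2, 1, 3, 6)  # (batch, heads, seq, dim)
k = torch.randn(2, 1, 3, 6)
v = torch.randn(2, 1, 3, 6)

# parallel causal reference: square the scores and mask out future positions
scores = torch.einsum("bhqd,bhkd->bhqk", q, k) ** 2
parallel = torch.einsum("bhqk,bhkd->bhqd", scores.tril(), v)

# factorized causal path: running sum of the phi(k) v^T outer products
kv = torch.einsum("bhnm,bhnd->bhnmd", x_2(k), v).cumsum(dim=2)
factorized = torch.einsum("bhnm,bhnmd->bhnd", x_2(q), kv)

print(torch.max((parallel - factorized).abs()))  # ~0 up to float error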

deklanw commented 8 months ago

Yes, I agree. The verification code can be even simpler:

import numpy as np

q = np.random.normal(0, 1, 64)
k = np.random.normal(0, 1, 64)

self_outer_flatten = lambda x: np.outer(x, x).flatten()

# dot(q, k)**2 equals dot(phi(q), phi(k)) with phi(x) = flatten(outer(x, x))
assert np.allclose(np.dot(q, k)**2, np.dot(self_outer_flatten(q), self_outer_flatten(k)))

Good. We're on the same page there.

Returning to my original post and the title of this issue: your paper seems to be inconsistent with your code.

x^2 – substituting the original kernel function with a simple element-wise squaring operation, ϕ(x) = x^2.

Your paper seems to me to be conflating a similarity function (also called a kernel function) with a feature map (phi). There is no element-wise squaring anywhere in the code. Your paper doesn't use the term "feature map", "feature function", or "feature representation" anywhere.

yaraksen commented 8 months ago

We note that the paper presents a one-dimensional scenario for simplicity, and we agree that the notation for the kernel function and the feature map there can be confusing. We will update it to make it more accurate. Thank you for pointing that out!