deklanw opened this issue 8 months ago
Hi, thank you for your interest. Sorry for the confusion: phi = x^2 is a simplified notation on our side. In the case of sim(q, k) = (q^T k)^2, we indeed have phi(q) != q^2. However, we can still factorize this similarity as a dot product of feature maps.
Check out the tests with the reference implementation:
Still confused.
These lines you link to correspond to computing a self outer product and flattening.
To clarify, it seems like you agree that the similarity function is sim(q, k) = dot(q, k)**2, but do you agree with the following?
The feature map phi which corresponds to this similarity function is not elementwise squaring. I.e., phi(x) = x**2 is not the corresponding feature map for that similarity function. The correct corresponding feature map is phi(x) = flatten(outer(x, x))
Agree?
Yes, you are correct. I guess the source of your confusion is the mix of the parallel and linear computation modes.
As we stated above, sim(q, k) = (dot(q, k))^2. Now, suppose we want to evaluate a model with this similarity in parallel mode (compute the similarity between Q and K first, then multiply the similarity scores with V). In this case, you can first evaluate matmul(Q, K^T) and then square the scores. The resulting matrix matches sim applied to each pair of q and k. In this mode, the kernel is evaluated directly.
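A minimal NumPy sketch of the parallel mode just described, assuming a single head, unbatched Q, K, V of shape (L, d), and no normalization or causal mask. The helper name `parallel_attention` is hypothetical. It checks that squaring matmul(Q, K^T) entry-wise gives the same scores as applying sim to every (q, k) pair.

```python
import numpy as np

def sim(q, k):
    """Similarity from the thread: sim(q, k) = dot(q, k)**2."""
    return np.dot(q, k) ** 2

def parallel_attention(Q, K, V):
    """Hypothetical helper: quadratic ("parallel") mode, no normalization or mask."""
    scores = (Q @ K.T) ** 2   # (L, L): square each entry of matmul(Q, K^T)
    return scores @ V         # (L, d_v)

rng = np.random.default_rng(0)
L, d = 5, 4
Q, K, V = (rng.normal(size=(L, d)) for _ in range(3))

# The same scores, computed pair by pair with sim().
pairwise = np.array([[sim(Q[i], K[j]) for j in range(L)] for i in range(L)])
assert np.allclose((Q @ K.T) ** 2, pairwise)

out = parallel_attention(Q, K, V)
print(out.shape)  # (5, 4)
```

The L x L score matrix is what makes this mode quadratic in sequence length.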
If we want to factorize the multiplication of Q and K, things become slightly more sophisticated, since we have to evaluate some phi(x) on the vectors from Q and K such that matmul(phi(Q), phi(K)^T) equals the result of the parallel mode. If we factorize the multiplication this way, we can multiply phi(K)^T with V first and then multiply the result with phi(Q).
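The reordering described above can be sketched in NumPy, assuming phi(x) = flatten(outer(x, x)) as the feature map (the one agreed on in this thread) and the same unbatched, unnormalized, non-causal setup. The names `phi` and `linear_attention` are hypothetical helpers, not functions from the reference implementation.

```python
import numpy as np

def phi(X):
    """Feature map for sim(q, k) = dot(q, k)**2: flatten(outer(x, x)) per row.
    Maps (L, d) -> (L, d*d)."""
    return np.einsum("li,lj->lij", X, X).reshape(X.shape[0], -1)

def linear_attention(Q, K, V):
    """Hypothetical helper: factorized ("linear") mode, no normalization or mask."""
    KV = phi(K).T @ V   # (d*d, d_v): summarizes all keys/values once
    return phi(Q) @ KV  # (L, d_v)

rng = np.random.default_rng(0)
L, d = 7, 4
Q, K, V = (rng.normal(size=(L, d)) for _ in range(3))

parallel = ((Q @ K.T) ** 2) @ V   # parallel mode: square the scores, then multiply by V
linear = linear_attention(Q, K, V)
assert np.allclose(parallel, linear)
```

The factorized order costs roughly L·d²·d_v multiply-adds instead of L²·d, so it pays off once the sequence is long relative to d².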
This factorization is easy to obtain, since (qk^T)^2 = (q_1 * k_1 + q_2 * k_2 + ... + q_n * k_n)^2. After some algebra, expanding this square gives the evaluation order used in the reference implementation.
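The expansion step written out as a worked equation, treating q and k as n-dimensional column vectors (so (qk^T)^2 in the row-vector notation above is the same scalar as (q^T k)^2 here). Every cross term q_i q_j k_i k_j survives, which is why the feature map needs all n^2 products x_i x_j, while element-wise squaring keeps only the i = j terms:

```latex
(q^\top k)^2
  = \Big(\sum_{i=1}^{n} q_i k_i\Big)^2
  = \sum_{i=1}^{n}\sum_{j=1}^{n} (q_i q_j)(k_i k_j)
  = \big\langle \operatorname{vec}(q q^\top),\ \operatorname{vec}(k k^\top) \big\rangle,
\qquad \phi(x) = \operatorname{vec}(x x^\top) \in \mathbb{R}^{n^2}

\text{whereas}\quad
\big\langle q \odot q,\ k \odot k \big\rangle = \sum_{i=1}^{n} q_i^2 k_i^2
  \quad (\text{only the } i = j \text{ terms})
```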
Here is a simple snippet that could help you.
import torch

def x_2(x: torch.Tensor) -> torch.Tensor:
    # 2nd-order terms, i.e. flatten(outer(x, x)) over the last dim; einops equivalent:
    # rearrange(x[..., :, None] * x[..., None, :], '... m n -> ... (m n)')
    x2 = (x.unsqueeze(-1) * x.unsqueeze(-2)).flatten(start_dim=-2)
    return x2  # simple case without the normalization of attention scores

if __name__ == "__main__":
    torch.manual_seed(5)
    q = torch.randn(2, 3, 6)  # (batch, seq_len, dim)
    k = torch.randn(2, 3, 6)
    v = torch.randn(2, 3, 6)
    # Add a single head: (batch, heads=1, seq_len, dim)
    q = q.view(2, 3, 1, -1).transpose(1, 2)
    k = k.view(2, 3, 1, -1).transpose(1, 2)
    v = v.view(2, 3, 1, -1).transpose(1, 2)
    # Parallel mode: scores = (q k^T)^2, then multiply by v
    qk = torch.einsum("bhqd,bhkd->bhqk", q, k)
    parallel_res = torch.einsum("bhqk,bhkd->bhqd", qk ** 2, v)
    # Linear mode: sum over keys of x_2(k) * v first, then contract with x_2(q)
    q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
    linear_factorized_res = (x_2(q) * (x_2(k) * v).sum(2, True)).sum(-1)
print(torch.max(parallel_res - linear_factorized_res)) # tensor(1.1444e-05)
I have omitted the normalization of attention scores and the extra bookkeeping for causal masking, but the general idea should be clear.
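The normalization and causal masking mentioned above can be sketched with running (prefix) sums. This assumes the normalization is division by the sum of similarity scores over visible keys, as in standard linear attention; that is an assumption, not necessarily what the reference implementation does. The `eps` guard and all function names are hypothetical.

```python
import numpy as np

def phi(X):
    """flatten(outer(x, x)) for each row of X: (L, d) -> (L, d*d)."""
    return np.einsum("li,lj->lij", X, X).reshape(X.shape[0], -1)

def causal_linear_attention(Q, K, V, eps=1e-6):
    """Hypothetical sketch: causal, normalized attention with sim(q, k) = dot(q, k)**2,
    using running sums instead of an L x L score matrix."""
    fQ, fK = phi(Q), phi(K)
    S = np.zeros((fK.shape[1], V.shape[1]))  # running sum of phi(k_j) v_j^T for j <= t
    z = np.zeros(fK.shape[1])                # running sum of phi(k_j) for the normalizer
    out = np.empty((Q.shape[0], V.shape[1]))
    for t in range(Q.shape[0]):
        S += np.outer(fK[t], V[t])
        z += fK[t]
        out[t] = (fQ[t] @ S) / (fQ[t] @ z + eps)
    return out

def causal_quadratic_reference(Q, K, V, eps=1e-6):
    """Same computation in parallel mode: mask future keys, then normalize each row."""
    scores = np.tril((Q @ K.T) ** 2)  # zero out keys j > i
    return (scores @ V) / (scores.sum(axis=1, keepdims=True) + eps)

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(6, 4)) for _ in range(3))
out = causal_linear_attention(Q, K, V)
assert np.allclose(out, causal_quadratic_reference(Q, K, V))
```

Since (q·k)^2 >= 0, every score is non-negative, so dividing by the score sum yields convex weights over the visible values.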
Yes, I agree. The verification code can be even simpler:
import numpy as np
q = np.random.normal(0, 1, 64)
k = np.random.normal(0, 1, 64)
self_outer_flatten = lambda x: np.outer(x, x).flatten()
assert np.allclose(np.dot(q, k)**2, np.dot(self_outer_flatten(q), self_outer_flatten(k)))
Good. We're on the same page there.
Returning to my original post and the title of this issue: your paper seems to be inconsistent with your code.
x^2 – substituting the original kernel function with a simple element-wise squaring operation, ϕ(x) = x^2.
Your paper seems to me to be conflating a similarity function (also called a kernel function) with a feature map (phi). There is no element-wise squaring of q or k anywhere in the code. Your paper doesn't use the term "feature map", "feature function", or "feature representation" anywhere.
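To make the distinction concrete, a small NumPy check with hand-picked vectors (q = [1, 2] and k = [3, -1] are arbitrary illustrative values). It compares the similarity the code computes, dot(q, k)**2, with what element-wise squaring as a feature map would produce, and with the flatten(outer(x, x)) feature map.

```python
import numpy as np

q = np.array([1.0, 2.0])
k = np.array([3.0, -1.0])

target = np.dot(q, k) ** 2            # similarity the code computes: (3 - 2)**2
elementwise = np.dot(q ** 2, k ** 2)  # what phi(x) = x**2 would give: 1*9 + 4*1
outer_flat = np.dot(np.outer(q, q).ravel(), np.outer(k, k).ravel())

assert np.isclose(outer_flat, target)  # flatten(outer(x, x)) reproduces dot(q, k)**2
print(target, elementwise)             # 1.0 13.0
```

The element-wise map drops the cross terms q_i q_j k_i k_j (i != j), so it defines a different similarity, not a simplified notation for the same one.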
We provided a one-dimensional scenario in the paper for simplicity, but we agree that the notation for the kernel function and the feature map there can be confusing. We will update it to be more accurate. Thank you for pointing that out!
Hi, I read your paper and found the following confusing. When you describe your ablations, which culminate in ReBased, the description starts with
but this doesn't seem to be what happens in your code. See these lines:
https://github.com/corl-team/rebased/blob/7a085b40e0a2c615f35ee43afd19f856df37819e/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py#L68-L69
My understanding of Linear Attention is the following. We need two functions: a similarity function (called sim or s), which takes two vectors and returns a scalar, and a feature map (typically called phi), which takes a single vector and returns another vector (possibly of a different dimension). Ignoring normalization by 1/sqrt(d) for simplicity, Linear Attention requires that s(q, k) = dot(phi(q), phi(k)).
Those lines of code I linked to correspond to defining s(q, k) = dot(q, k)**2.
The feature map phi which corresponds to this similarity function is not element-wise squaring. I.e., phi(x) = x**2 is not the corresponding feature map for that similarity function. The correct corresponding feature map is phi(x) = flatten(outer(x, x)).
One could say similar things about the other variants in the ablations, including ReBased.
Am I missing something?