Dear @hp-l33,
What fantastic news! You implemented exactly what we wanted!
I will experiment with this more efficient implementation over the next few days and will update the repo to support it natively.
Dear @hp-l33,
I have tried this Triton implementation, but there seem to be some issues: I encounter NaN values when running the LinFusion demo. Specifically, I use the following code to check whether the outputs of the torch and Triton implementations match, and there are some gaps:
```python
import torch
from einops import rearrange

# Torch reference: kv = (K^T V) / n, normalizer z = Q @ mean(K)^T + eps.
z = query @ key.mean(dim=-2, keepdim=True).transpose(-2, -1) + 1e-6
kv = (key.transpose(-2, -1) * (sequence_length**-0.5)) @ (
    value * (sequence_length**-0.5)
)
hidden_states_ = query @ kv / z
# Triton implementation on the same inputs.
query, key, value = map(lambda x: rearrange(x, '(b h) l d -> b h l d', h=self.heads), [query, key, value])
hidden_states = linear_attention(query, key, value)
hidden_states = rearrange(hidden_states, 'b h l d -> (b h) l d')
print(torch.max(torch.abs(hidden_states - hidden_states_)))
```
Does the demo run successfully on your end? If so, could you please share your environment with me so that I can try to reproduce the correct results? :)
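For the comparison itself, bit-exact agreement between torch and Triton kernels is not expected, so a tolerance-based check may help separate real bugs (such as the NaN above) from ordinary float noise. A minimal sketch using standard PyTorch; the tolerances are illustrative, not values prescribed by LinFusion:

```python
import torch

def check_outputs(ref: torch.Tensor, out: torch.Tensor) -> None:
    # NaNs fail immediately; otherwise compare with loose tolerances
    # (illustrative values, tune for your dtype).
    assert not torch.isnan(out).any(), "Triton output contains NaN"
    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
```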
Hi, I tested with a demo yesterday and got the expected results. However, after the latest update today (which added autotune support), a gap has appeared between the torch and Triton forward passes. I will roll back to the previous version and fix the bug as soon as possible.
My torch version is 2.4 and my Triton version is 3.0; the versions of the other required libraries are consistent with LinFusion.
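A quick sanity check of the environment along those lines (exact patch levels are not specified above, so only the major.minor prefix is checked) could be:

```python
import torch
import triton

# Versions reported above: torch 2.4, Triton 3.0.
assert torch.__version__.startswith("2.4"), torch.__version__
assert triton.__version__.startswith("3.0"), triton.__version__
```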
Dear @hp-l33,
Thanks for the prompt reply! Indeed, the problem has been solved. I have updated the news in the README and will update the code later. Best wishes for your future development!
Dear @Huage001,
Thank you for your kind words and for recognizing my work. Best wishes!
Thank you for your excellent work!
I maintain a library implementing bi-directional linear attention with Triton, which now supports your LinFusion.
Have fun, and best wishes!
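For readers of this thread: the bi-directional (non-causal) linear attention that such a kernel computes reduces to two matmuls in pure PyTorch. A minimal reference sketch matching the normalization used in the torch code earlier in this thread; the function name is illustrative, not the library's API:

```python
import torch

def linear_attention_reference(q, k, v, eps=1e-6):
    # q, k, v: (batch, heads, seq_len, head_dim).
    n = q.shape[-2]
    # kv = (K^T V) / n, applied as two n**-0.5 scalings for stability.
    kv = (k.transpose(-2, -1) * n**-0.5) @ (v * n**-0.5)
    # Normalizer z = Q @ mean(K)^T + eps, as in the torch reference above.
    z = q @ k.mean(dim=-2, keepdim=True).transpose(-2, -1) + eps
    return (q @ kv) / z  # O(n * d^2) rather than O(n^2 * d)
```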