Huage001 / LinFusion

Official PyTorch and Diffusers Implementation of "LinFusion: 1 GPU, 1 Minute, 16K Image"
Apache License 2.0

Triton implementation #11

Closed: hp-l33 closed this issue 5 days ago

hp-l33 commented 1 week ago

Thank you for your excellent work!

I maintain a library implementing bi-directional linear attention with Triton, which now supports your LinFusion.

Have fun, and best wishes!

Huage001 commented 1 week ago

Dear @hp-l33,

What fantastic news! You implemented exactly what we wanted!

I will experiment with this more efficient implementation over the next few days and update the repo to support it natively.

Huage001 commented 6 days ago

Dear @hp-l33,

I have tried this Triton implementation, but there seem to be some issues: I encounter NaN values when running the LinFusion demo. Specifically, I use the following code to check whether the outputs of the torch and Triton implementations match, and there is a noticeable gap:

            # Torch reference: normalized linear attention as used in LinFusion
            z = query @ key.mean(dim=-2, keepdim=True).transpose(-2, -1) + 1e-6
            kv = (key.transpose(-2, -1) * (sequence_length**-0.5)) @ (
                value * (sequence_length**-0.5)
            )
            hidden_states_ = query @ kv / z

            # Triton kernel expects (batch, heads, length, dim)
            query, key, value = map(lambda x: rearrange(x, '(b h) l d -> b h l d', h=self.heads), [query, key, value])

            hidden_states = linear_attention(query, key, value)

            hidden_states = rearrange(hidden_states, 'b h l d -> (b h) l d')

            # Maximum absolute difference between the two implementations
            print(torch.max(torch.abs(hidden_states - hidden_states_)))

Can the demo run successfully on your end? If so, could you please share your environment with me so that I can try to get the correct results :)
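
For reference, here is a standalone sketch of the same check on random inputs, assuming the Triton kernel is importable as `linear_attention` (the actual import path depends on your library):

    import torch
    from einops import rearrange
    # Hypothetical import path; replace with the actual module from the Triton library.
    # from linear_attention_triton import linear_attention

    heads, batch, seq_len, dim = 8, 2, 1024, 64
    query, key, value = (
        torch.randn(batch * heads, seq_len, dim, device='cuda')
        for _ in range(3)
    )

    # Torch reference: normalized linear attention
    z = query @ key.mean(dim=-2, keepdim=True).transpose(-2, -1) + 1e-6
    kv = (key.transpose(-2, -1) * (seq_len ** -0.5)) @ (value * (seq_len ** -0.5))
    reference = query @ kv / z

    # Triton kernel: reshape to (batch, heads, length, dim) and back
    q, k, v = (rearrange(x, '(b h) l d -> b h l d', h=heads) for x in (query, key, value))
    out = rearrange(linear_attention(q, k, v), 'b h l d -> (b h) l d')

    print(torch.max(torch.abs(out - reference)))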

hp-l33 commented 6 days ago

Hi, I tested with a demo yesterday and got the expected results. However, after the latest update today (which added autotune support), a gap has appeared between the torch and Triton forward passes. I will roll back to the previous version and fix the bug as soon as possible.

hp-l33 commented 6 days ago

My torch version is 2.4, Triton is 3.0, and the versions of the other required libraries match those used by LinFusion.
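
For a quick sanity check of the environment (a minimal sketch):

    import torch
    import triton

    # Confirm the installed versions match the working setup
    print(torch.__version__)   # expected 2.4.x
    print(triton.__version__)  # expected 3.0.x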

Huage001 commented 6 days ago

Dear @hp-l33,

Thanks for the prompt reply! Indeed, the problem has been solved. I have updated the news in the README and will update the code later. Best wishes for your future development!

hp-l33 commented 5 days ago

Dear @Huage001,

Thank you for your kind words and for recognizing my work. Best wishes!