sustcsonglin / flash-linear-attention

Efficient implementations of state-of-the-art linear attention models in PyTorch and Triton
MIT License

Why is delta_net so slow in inference? #61

Closed ching-sui1995 closed 1 month ago

ching-sui1995 commented 1 month ago

Feature request

Can you explain why delta_net is so slow in inference?

I tried it, and autoregressive (AR) generation takes a long time.

Is the code here correct for inference?

https://github.com/sustcsonglin/flash-linear-attention/blob/main/fla/models/delta_net/modeling_delta_net.py

Motivation

It's not a proposal, just a question.

Your contribution

no

yzhangcs commented 1 month ago

@ching-sui1995 Hello, could you report exact inference latencies? Check out Fig. 4(c) in https://arxiv.org/pdf/2409.07146; you can compare DeltaNet's speed with the models presented in that plot.
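
(For reference, a minimal way to get a rough latency number is to time `model.generate` directly. The sketch below assumes `fla` exposes `DeltaNetConfig` and `DeltaNetForCausalLM` under `fla.models`, as the modeling file linked above suggests, and uses a randomly initialized model purely for timing; adjust imports and config fields to your installed version.)

```python
# Rough AR-generation latency sketch (assumes fla exposes DeltaNetConfig /
# DeltaNetForCausalLM under fla.models; a randomly initialized model is enough
# for timing purposes).
import torch
from fla.models import DeltaNetConfig, DeltaNetForCausalLM

config = DeltaNetConfig()  # default config; shrink it for a quicker check
model = DeltaNetForCausalLM(config).cuda().to(torch.bfloat16).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 128), device='cuda')

with torch.inference_mode():
    # warm-up pass so Triton kernel compilation is not counted
    model.generate(input_ids, max_new_tokens=16, do_sample=False)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    model.generate(input_ids, max_new_tokens=512, do_sample=False)
    end.record()
    torch.cuda.synchronize()
    print(f'512-token generation latency: {start.elapsed_time(end) / 1e3:.2f} s')
```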

ching-sui1995 commented 1 month ago

@yzhangcs What is the relation between delta_net and the paper you posted (GSA)?

I am talking about the delta update rule.

@sustcsonglin can you comment ?

ching-sui1995 commented 1 month ago

@yzhangcs I also read the paper you posted (GSA). Now I have some questions:

(1) In Section 4.2.1, you show results on MQAR and other tasks. Why is your model good at recall-intensive tasks? You don't use the delta update rule like delta_net, so why should it perform better?

(2) How can I reproduce Figure 4.1.2 (a) on MQAR?

yzhangcs commented 1 month ago

Hi, I just mean that you can follow the recipes in the paper to test your delta-rule models; if the inference latencies are similar, then the problem is likely not the delta rule itself. Regarding the MQAR results, you can follow https://github.com/HazyResearch/prefix-linear-attention/tree/main/synthetic/zoology. In my settings, I kept the state size of GSA the same as GLA's for a fair comparison, and I believe GSA benefits from its context-aware queries via the 2-pass recurrence.

ching-sui1995 commented 1 month ago

Thanks @yzhangcs. So you mean we don't need the delta rule, and the 2-pass recurrence alone can address the poor performance on recall-intensive tasks?

Have you compared against delta_net on these recall-intensive tasks? I don't see it in the table.

yzhangcs commented 1 month ago

@ching-sui1995 It's recommended to use DeltaNet if you are concerned about recall-intensive tasks, as DeltaNet performs nearly perfectly on them. Regarding DeltaNet inference, I'm not sure what problems you encountered. Could you report its latencies by running https://github.com/sustcsonglin/flash-linear-attention/blob/main/benchmarks/benchmark_generation.py?
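
(If the numbers look off, one way to see where decoding time goes, independent of the benchmark script above, is to wrap a short generation in `torch.profiler`. This is just a generic diagnostic sketch reusing the hypothetical `model` and `input_ids` from the earlier timing example, not part of the repo's benchmark.)

```python
# Generic profiling sketch to see which kernels dominate decoding time;
# `model` and `input_ids` are the same hypothetical objects as in the
# timing sketch above.
import torch
from torch.profiler import profile, ProfilerActivity

with torch.inference_mode(), profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
    model.generate(input_ids, max_new_tokens=64, do_sample=False)

# Sort by CUDA time to spot whether the recurrent-state update kernels
# or something else (e.g. Python overhead) dominates.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```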

ching-sui1995 commented 1 month ago

Thank you!

It is resolved!