triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/

Flash attention 2 backward unusable for llama? #2046

Open · DachengLi1 opened this issue 1 year ago

DachengLi1 commented 1 year ago

Thanks for the amazing work!

I observed that the tutorial's Flash Attention 2 backward pass is very slow for hdim=128 (the LLaMA setting), e.g. about 7x slower. I also looked into similar issues but found no progress (e.g. https://github.com/openai/triton/issues/1975). Does this mean (1) the tutorial is unusable for actual training, and (2) if I want to use Triton to train a model, the best I can get is the Flash Attention 1 algorithm from Tri's repo?

Thanks a lot!

Jokeren commented 1 year ago

There are some compiler issues related to the backward pass when pipelining is involved for nested loops. I plan to investigate these issues soon. However, it is possible to circumvent these compiler issues by refactoring the code (bwd pass), which you might need to figure out on your own.
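For readers unfamiliar with the kernel structure being discussed: below is a minimal toy sketch (not the attention backward itself) of the nested-loop shape in question, i.e. one blocked loop inside another within a single program. All names, shapes, and the computation are illustrative assumptions, not the tutorial's actual kernel.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def nested_loop_row_sums(x_ptr, out_ptr, N, M,
                         BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr):
    # Outer loop over row blocks (think: the backward pass's loop over KV blocks)...
    for n0 in range(0, N, BLOCK_N):
        rows = n0 + tl.arange(0, BLOCK_N)
        acc = tl.zeros([BLOCK_N], dtype=tl.float32)
        # ...with an inner loop over column blocks (think: the loop over Q blocks).
        # It is this nested shape, with the inner loop being software-pipelined,
        # that is tied to the compiler issues mentioned above.
        for m0 in range(0, M, BLOCK_M):
            cols = m0 + tl.arange(0, BLOCK_M)
            mask = (rows[:, None] < N) & (cols[None, :] < M)
            tile = tl.load(x_ptr + rows[:, None] * M + cols[None, :],
                           mask=mask, other=0.0)
            acc += tl.sum(tile, axis=1)
        tl.store(out_ptr + rows, acc, mask=rows < N)


x = torch.randn(512, 256, device="cuda", dtype=torch.float32)
out = torch.empty(512, device="cuda", dtype=torch.float32)
nested_loop_row_sums[(1,)](x, out, 512, 256, BLOCK_N=64, BLOCK_M=64)
```

Refactoring so that each kernel (or each program) only contains a single flat loop is the kind of restructuring that sidesteps the pipelining issue.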

DachengLi1 commented 1 year ago

@Jokeren Great to hear that! I think the current version in ops, which uses sequence parallelism, should only use a single for loop? I tried it recently and rolled back the atomic add operation, but I'm still seeing a 5-6x slowdown compared to the flash attention backward. Will that be resolved soon? Thanks!
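For context on the atomic add mentioned above: in a sequence-parallel backward, different programs each handle one KV block and all need to add their contribution into the same dQ tile, which is typically done with `tl.atomic_add`. A minimal, purely illustrative sketch of that accumulation pattern (names and layout are assumptions, not the tutorial's actual kernel):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def accumulate_dq_sketch(dq_ptr, partial_ptr, D, BLOCK_D: tl.constexpr):
    # One program per KV block: load this block's partial dQ contribution and add
    # it into the shared dQ buffer. tl.atomic_add makes the concurrent updates from
    # different programs safe, at the cost of a non-deterministic summation order.
    pid = tl.program_id(0)
    offs = tl.arange(0, BLOCK_D)
    mask = offs < D
    part = tl.load(partial_ptr + pid * D + offs, mask=mask, other=0.0)
    tl.atomic_add(dq_ptr + offs, part, mask=mask)


num_kv_blocks, D = 8, 128
partials = torch.randn(num_kv_blocks, D, device="cuda", dtype=torch.float32)
dq = torch.zeros(D, device="cuda", dtype=torch.float32)
accumulate_dq_sketch[(num_kv_blocks,)](dq, partials, D, BLOCK_D=128)
# dq now equals partials.sum(dim=0), up to floating-point reordering.
```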

sbodenstein commented 1 year ago

@Jokeren: have you had a chance to look into the slow backward pass yet? As @DachengLi1 mentioned, the issue is there even with sequence parallelism.

Jokeren commented 1 year ago

> @Jokeren: have you had a chance to look into the slow backward pass yet? As @DachengLi1 mentioned, the issue is there even with sequence parallelism.

Yes, I'm starting to; however, I might be slow because I'm working on it part-time.

sbodenstein commented 1 year ago

@Jokeren: I see that there is a large rewrite of the backward kernel. Might that solve the issues? (https://github.com/openai/triton/pull/2332).

Jokeren commented 1 year ago

Yes! Unless you still want to write the backward kernel in nested loops.

@DachengLi1 @sbodenstein
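For reference alongside the nested-loop sketch earlier in the thread, here is the same toy computation with the outer loop moved onto the launch grid, leaving a single flat loop per program. This mirrors the kind of restructuring being alluded to (no nested loops for the pipeliner to handle); it is an illustrative sketch under the same assumed names and shapes, not the rewritten tutorial kernel.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def flat_loop_row_sums(x_ptr, out_ptr, N, M,
                       BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr):
    # One program per row block; the former outer loop is now the grid dimension.
    pid = tl.program_id(0)
    rows = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros([BLOCK_N], dtype=tl.float32)
    # Single flat loop over column blocks; no nesting inside the kernel.
    for m0 in range(0, M, BLOCK_M):
        cols = m0 + tl.arange(0, BLOCK_M)
        mask = (rows[:, None] < N) & (cols[None, :] < M)
        tile = tl.load(x_ptr + rows[:, None] * M + cols[None, :],
                       mask=mask, other=0.0)
        acc += tl.sum(tile, axis=1)
    tl.store(out_ptr + rows, acc, mask=rows < N)


x = torch.randn(512, 256, device="cuda", dtype=torch.float32)
out = torch.empty(512, device="cuda", dtype=torch.float32)
grid = (triton.cdiv(512, 64),)
flat_loop_row_sums[grid](x, out, 512, 256, BLOCK_N=64, BLOCK_M=64)
```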