Closed: wang2yn84 closed this 1 month ago
We used to insert into the cache inside attention, then use the updated cache for the calculation. With flash attention/ragged attention, we can delay the cache insertion to the end of each step. By switching to a left-aligned stacked cache, we minimize data transfer to HBM and therefore improve performance. The decode step time dropped from 52ms to 42ms, and the left-aligned cache also improves insert efficiency. Overall benchmark performance improved by 15%.
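To make the idea concrete, here is a minimal NumPy sketch of a decode step with delayed cache insertion. All names, shapes, and the `decode_step` helper are hypothetical, for illustration only; the attention call itself is elided, since the point is only *when* the cache write happens.

```python
import numpy as np

# Hypothetical shapes, not taken from the PR.
BATCH, HEADS, MAX_SEQ, HEAD_DIM = 2, 4, 16, 8

def decode_step(k_cache, v_cache, new_k, new_v, pos):
    """Sketch of one decode step with delayed cache insertion.

    Attention reads the existing left-aligned cache plus the new token's
    K/V directly; the cache itself is written only once, at the end of
    the step, instead of inside the attention computation.
    """
    # Flash/ragged attention would consume k_cache[:, :, :pos] together
    # with new_k/new_v here; the in-place cache update is skipped during
    # attention, avoiding an extra round trip through HBM.
    attn_out = ...  # placeholder for the attention kernel

    # Single insertion at the end of the step, at left-aligned slot `pos`.
    k_cache[:, :, pos] = new_k
    v_cache[:, :, pos] = new_v
    return k_cache, v_cache

k_cache = np.zeros((BATCH, HEADS, MAX_SEQ, HEAD_DIM))
v_cache = np.zeros((BATCH, HEADS, MAX_SEQ, HEAD_DIM))
new_k = np.ones((BATCH, HEADS, HEAD_DIM))
new_v = np.ones((BATCH, HEADS, HEAD_DIM))
k_cache, v_cache = decode_step(k_cache, v_cache, new_k, new_v, pos=3)
```

Because the cache is left-aligned, each step's write lands at a known contiguous slot rather than a scattered ring-buffer position, which is what makes the deferred bulk insert cheap.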
we can delay the cache insertion to the end of each step
A 15% improvement is a great achievement! I assume the tested side uses a stacked, left-aligned cache + ragged attention; do you have any performance numbers for left-aligned (without stacking) + ragged attention?
When the cache is left-aligned + unstacked, the data transfer overhead is non-negligible. I tried flash attention, which took 90ms per step. This overhead has nothing to do with which attention kernel you use.
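A small sketch of why stacking matters, with hypothetical shapes (none of these names come from the PR): an unstacked cache is a collection of per-layer arrays, each updated and transferred separately, while a stacked cache adds a leading layer axis so the per-step insert for every layer is one bulk write.

```python
import numpy as np

# Illustrative sizes only.
LAYERS, BATCH, HEADS, MAX_SEQ, HEAD_DIM = 4, 2, 4, 16, 8

# Unstacked: one array per layer; each layer's insert is a separate
# update (and on an accelerator, a separate transfer to HBM).
unstacked = [np.zeros((BATCH, HEADS, MAX_SEQ, HEAD_DIM)) for _ in range(LAYERS)]

# Stacked: a single array with a leading layer axis, so the per-step
# insert for all layers can be expressed as one bulk write.
stacked = np.zeros((LAYERS, BATCH, HEADS, MAX_SEQ, HEAD_DIM))

new_kv = np.ones((LAYERS, BATCH, HEADS, HEAD_DIM))
pos = 3

# One write covers every layer at the left-aligned slot `pos`.
stacked[:, :, :, pos] = new_kv
```

With the unstacked layout, the equivalent update is a Python loop over `LAYERS` separate writes, which is where the extra transfer overhead comes from.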
Fixed based on your comments, along with all the unit tests and lint errors. Please let me know if you have any other comments/suggestions. @qihqi @FanhaiLu1
There are some updates under deps/Jetstream; is that intentional?
There are new lint errors, can you fix them?
Fixed all the lint issues.
I will remove precompute_ragged_block_indices, clean up the ragged attention impl (e.g. remove the one for the ring buffer), and simplify the flags for the non-ring-buffer case, thereby simplifying the cache manager, in a subsequent PR. Will push this PR first since it's been standing alone for a while.