google / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0
33 stars 14 forks

Optimize cache update. #151

Closed · wang2yn84 closed this 1 month ago

wang2yn84 commented 1 month ago

We used to insert into the cache inside attention and then use the updated cache for the calculation. With the help of flash attention/ragged attention, we can delay the cache insertion to the end of each step. By switching to a left-aligned, stacked cache, we minimize the data transfer to HBM and therefore improve performance: the decode step time dropped from 52 ms to 42 ms. The left-aligned cache also makes the insert itself more efficient. Overall benchmark performance is boosted by 15%.
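
For illustration only, here is a rough JAX-style sketch of the idea (the function names and shapes are hypothetical, not this PR's actual code): attention reads the existing left-aligned cache plus the current step's K/V without writing them, and a single stacked update inserts all layers' K/V once at the end of the step.

```python
# Hypothetical sketch (illustrative names/shapes, not this repo's API):
# attention consumes the left-aligned cache plus the current token's K/V
# without mutating the cache; one stacked write at the end of the step
# inserts all layers at once.
import jax
import jax.numpy as jnp


def attention_with_pending_kv(q, k_new, v_new, k_cache, v_cache, lengths):
  """Decode-step attention over cached K/V plus the not-yet-inserted K/V.

  q, k_new, v_new: [batch, heads, head_dim] for the current token.
  k_cache, v_cache: [batch, max_len, heads, head_dim], left aligned, so the
    valid entries for sequence b occupy positions [0, lengths[b]).
  """
  scores = jnp.einsum("bhd,bshd->bhs", q, k_cache)
  # Mask out the unused (right) part of the left-aligned cache.
  valid = jnp.arange(k_cache.shape[1])[None, None, :] < lengths[:, None, None]
  scores = jnp.where(valid, scores, -1e9)
  # Score the current token against itself; its K/V are not in the cache yet.
  new_score = jnp.einsum("bhd,bhd->bh", q, k_new)[..., None]
  probs = jax.nn.softmax(jnp.concatenate([scores, new_score], axis=-1), axis=-1)
  out = jnp.einsum("bhs,bshd->bhd", probs[..., :-1], v_cache)
  return out + probs[..., -1:] * v_new


def insert_at_end_of_step(stacked_k, stacked_v, new_ks, new_vs, lengths):
  """Single left-aligned insert for all layers after the decode step.

  stacked_k, stacked_v: [layers, batch, max_len, heads, head_dim].
  new_ks, new_vs: [layers, batch, heads, head_dim] collected during the step.
  """
  batch = jnp.arange(stacked_k.shape[1])
  stacked_k = stacked_k.at[:, batch, lengths].set(new_ks)
  stacked_v = stacked_v.at[:, batch, lengths].set(new_vs)
  return stacked_k, stacked_v, lengths + 1
```

Because every sequence's valid entries start at position 0, the insert position is just `lengths`, and stacking the layers lets the whole step's cache update go out as one write instead of many small ones.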

FanhaiLu1 commented 1 month ago

> we can delay the cache insertion to the end of each step

15% improvement is a great achievement! I assume the tested setup uses the stacked, left-aligned cache + ragged attention. Do you have any performance numbers with a left-aligned (unstacked) cache + ragged attention?

wang2yn84 commented 1 month ago

> Do you have any performance numbers with a left-aligned (unstacked) cache + ragged attention?

When the cache is left-aligned but unstacked, the data transfer overhead is non-negligible. I tried flash attention and each step took 90 ms. This overhead has nothing to do with which attention kernel you use.
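
To make the contrast concrete, a rough sketch (again with hypothetical names, not this repo's code) of what the unstacked layout implies: with one cache array per layer, every decode step issues one small update per layer instead of a single write on the stacked array, and that per-layer transfer overhead is paid regardless of the attention kernel.

```python
# Hypothetical sketch: per-layer inserts with an unstacked, left-aligned cache.
# Each layer's cache is its own array, so a decode step issues len(per_layer_k)
# small scattered writes instead of one update on a stacked array.
import jax.numpy as jnp


def insert_unstacked(per_layer_k, per_layer_v, new_ks, new_vs, lengths):
  """per_layer_k/v: lists of [batch, max_len, heads, head_dim] arrays."""
  batch = jnp.arange(per_layer_k[0].shape[0])
  out_k, out_v = [], []
  for k_cache, v_cache, k_new, v_new in zip(per_layer_k, per_layer_v,
                                            new_ks, new_vs):
    # One small HBM write per layer; the overhead adds up across layers.
    out_k.append(k_cache.at[batch, lengths].set(k_new))
    out_v.append(v_cache.at[batch, lengths].set(v_new))
  return out_k, out_v, lengths + 1
```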

wang2yn84 commented 1 month ago

Fixed the issues from your comments, as well as all the unit test and lint errors. Please let me know if you have any other comments or suggestions. @qihqi @FanhaiLu1

qihqi commented 1 month ago

There are some updates on deps/Jetstream, is that intentional?

FanhaiLu1 commented 1 month ago

> Fixed the issues from your comments, as well as all the unit test and lint errors.

There are new lint errors, can you fix them?

wang2yn84 commented 1 month ago

> There are new lint errors, can you fix them?

Fixed all the lint issues.

wang2yn84 commented 1 month ago

In a subsequent PR I will remove precompute_ragged_block_indices, clean up the ragged attention implementation (e.g. remove the one for the ring buffer), and simplify the flags for the non-ring-buffer case, which will in turn simplify the cache manager. I'll push this PR first since it's been standing alone for a while.