google / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0

Fix the performance regression with ragged attention enabled for llama2 7b #172

Closed wang2yn84 closed 3 weeks ago

wang2yn84 commented 3 weeks ago

In this PR we did the following:

  1. We split the attention calculation in the generate step into two parts: self attention over the existing KV cache, and attention over the newly computed cache entry. For the second part, since the cache length is 1, we should not use ragged attention, for performance reasons; this part accounts for ~15% of the attention calculation time. (See the sketch after this list.)
  2. Replace the global attention output calculation with a more numerically stable formulation (also covered in the sketch below).
  3. Replace positional args with kwargs where possible to avoid potential issues in the Attention class.
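
For context, here is a minimal sketch of the idea behind items 1 and 2. This is not the repository's code: `ragged_attention`, the tensor layouts, and the returned (output, max, sum) triple are assumptions for illustration. Attention over the existing KV cache and attention over the single new cache entry are computed separately, and the two partial results are merged with a numerically stable rescaling.

```python
import torch

def dense_attention(q, k, v):
    # q: [batch, heads, 1, head_dim]; k, v: [batch, heads, kv_len, head_dim]
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) / (q.shape[-1] ** 0.5)
    m = scores.amax(dim=-1, keepdim=True)        # per-query max logit
    p = torch.exp(scores - m)                    # shifted exponentials
    s = p.sum(dim=-1, keepdim=True)              # softmax denominator
    o = torch.einsum("bhqk,bhkd->bhqd", p, v)    # unnormalized weighted sum
    return o, m, s

def merge_attention(o1, m1, s1, o2, m2, s2):
    # Numerically stable merge of two partial attention results:
    # rescale each unnormalized output by exp(m_i - m) before summing,
    # so no exponential of a large positive logit is ever taken.
    m = torch.maximum(m1, m2)
    a1, a2 = torch.exp(m1 - m), torch.exp(m2 - m)
    return (o1 * a1 + o2 * a2) / (s1 * a1 + s2 * a2)

# Generate step (hypothetical usage; `ragged_attention` is a placeholder for
# the ragged kernel and is assumed to return the same (out, max, sum) triple):
#   o_c, m_c, s_c = ragged_attention(q, k_cache, v_cache, seq_lens)  # existing cache
#   o_n, m_n, s_n = dense_attention(q, k_new, v_new)                 # new entry, length 1
#   attn_out = merge_attention(o_c, m_c, s_c, o_n, m_n, s_n)
```

Expressing the merge in terms of exp(m_i - m) keeps every exponential at or below 1, which is what makes the combined output numerically stable even when individual logits are large.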
wang2yn84 commented 3 weeks ago

> Can you add a short description about this PR? Did this PR fix the performance issue of ragged attention? I also saw some repeat kv change in this PR.

That's because it's based on another PR. After that one is pushed, I'll rebase and it will be clearer.

wang2yn84 commented 3 weeks ago

> Can you add a short description about this PR? Did this PR fix the performance issue of ragged attention? I also saw some repeat kv change in this PR.

Rebased.