google / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0

Replace repeat kv with proper GQA handling. #171

Closed: wang2yn84 closed this 3 weeks ago

wang2yn84 commented 3 weeks ago

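repeat_kv in the original Llama model copies the KV data in some cases. This PR replaces it with a reshape of the query: the heads dimension is folded into the tokens dimension (dim -2), so each group of query heads attends directly to the KV head it shares instead of to a repeated copy.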
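The equivalence this change relies on can be checked with a small sketch. It is not the PR's code: plain Python lists stand in for PyTorch tensors, batching and the causal mask are omitted, and the names attention_repeat_kv, attention_grouped, attend, and make are hypothetical. It assumes query head j shares KV head j // n_rep, the ordering produced by the reference repeat_kv.

```python
import math
import random

# Plain nested lists stand in for tensors: a head is a list of rows,
# each row a list of floats of length head_dim.

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attend(q_rows, k_rows, v_rows):
    """Scaled dot-product attention for one head: softmax(q k^T / sqrt(d)) v."""
    scale = 1.0 / math.sqrt(len(k_rows[0]))
    out = []
    for q in q_rows:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_rows]
        probs = softmax(scores)
        out.append([sum(p * v_row[j] for p, v_row in zip(probs, v_rows))
                    for j in range(len(v_rows[0]))])
    return out

def attention_repeat_kv(q, k, v, n_rep):
    """Reference (hypothetical): give every query head its own repeated K/V head."""
    k_rep = [k[h // n_rep] for h in range(len(q))]
    v_rep = [v[h // n_rep] for h in range(len(q))]
    return [attend(q[h], k_rep[h], v_rep[h]) for h in range(len(q))]

def attention_grouped(q, k, v, n_rep):
    """GQA without repeating K/V (hypothetical): [h, s, d] -> [hkv, n_rep * s, d] and back."""
    seq_len = len(q[0])
    out = []
    for g in range(len(k)):
        # Fold: stack the rows of the n_rep query heads that share KV head g.
        folded_q = [row for h in range(g * n_rep, (g + 1) * n_rep) for row in q[h]]
        folded_out = attend(folded_q, k[g], v[g])  # shape [n_rep * s, d]
        # Unfold: split the rows back into n_rep heads of shape [s, d].
        for r in range(n_rep):
            out.append(folded_out[r * seq_len:(r + 1) * seq_len])
    return out

def make(shape, rng):
    """Random nested list with the given shape."""
    if len(shape) == 1:
        return [rng.uniform(-1.0, 1.0) for _ in range(shape[0])]
    return [make(shape[1:], rng) for _ in range(shape[0])]

rng = random.Random(0)
n_heads, n_kv_heads, seq_len, head_dim = 8, 2, 5, 4
n_rep = n_heads // n_kv_heads
q = make([n_heads, seq_len, head_dim], rng)
k = make([n_kv_heads, seq_len, head_dim], rng)
v = make([n_kv_heads, seq_len, head_dim], rng)

ref = attention_repeat_kv(q, k, v, n_rep)
grouped = attention_grouped(q, k, v, n_rep)
max_diff = max(abs(a - b)
               for head_r, head_g in zip(ref, grouped)
               for row_r, row_g in zip(head_r, head_g)
               for a, b in zip(row_r, row_g))
print(f"max abs difference: {max_diff:.2e}")
```

Because every query row is attended independently, folding heads into the token dimension only changes which rows are processed together, not what any row computes. The fold touches only the query, so K and V never have to be expanded to h heads.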

lsy323 commented 3 weeks ago

LGTM! Thank you for the change. It looks like the linter errors need to be fixed.

wang2yn84 commented 3 weeks ago

> Smart change! So q k^T has shape [hkv, rep * seq_len, seq_len] and (q k^T) v has shape [hkv, rep * seq_len, d]; you reshape the output to [h, seq_len, d] at the end.

Correct. The reshape doesn't affect the result.
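The shape bookkeeping in this exchange can be written out as follows. This is a sketch assuming a single batch element, h = h_kv · rep query heads, and query head j sharing KV head ⌊j / rep⌋ (the reference repeat_kv ordering); the 1/√d scaling is the standard attention scale and is not mentioned in the thread.

```latex
\begin{aligned}
Q &\in \mathbb{R}^{h \times s \times d}, \qquad K, V \in \mathbb{R}^{h_{kv} \times s \times d}, \qquad h = h_{kv} \cdot \mathrm{rep} \\
Q' &= \operatorname{reshape}(Q) \in \mathbb{R}^{h_{kv} \times (\mathrm{rep} \cdot s) \times d}, \qquad Q'_{g,\, r s + i} = Q_{g \cdot \mathrm{rep} + r,\, i} \\
S &= Q' K^{\top} / \sqrt{d} \in \mathbb{R}^{h_{kv} \times (\mathrm{rep} \cdot s) \times s} \\
O' &= \operatorname{softmax}(S)\, V \in \mathbb{R}^{h_{kv} \times (\mathrm{rep} \cdot s) \times d} \\
O &= \operatorname{reshape}(O') \in \mathbb{R}^{h \times s \times d}
\end{aligned}
```

The softmax runs row-wise over the key dimension, so row (g, rs + i) of O' depends only on Q_{g·rep+r, i} and the shared K_g, V_g. That is exactly what the repeat_kv version computes for query head g·rep + r at position i, so the final reshape recovers the same output. If a causal mask is applied, its [s, s] pattern has to be repeated rep times along the row dimension to line up with Q'.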

wang2yn84 commented 3 weeks ago

> LGTM! Thank you for the change. It looks like the linter errors need to be fixed.

Yup, fixed!