AI-Hypercomputer / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0

Replace repeat kv with proper GQA handling. #171

Closed: wang2yn84 closed this pull request 3 months ago

wang2yn84 commented 3 months ago

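`repeat_kv` in the original Llama model copies the key/value data in some cases. This PR replaces it by reshaping the query instead: the query heads that share one KV head are folded from the heads dimension into the tokens dimension (dim -2), so K and V never have to be duplicated.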
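A minimal sketch of the two approaches, assuming the Llama layout [batch, heads, tokens, head_dim] and that query head j uses KV head j // rep (the ordering produced by Llama's `repeat_kv`). This is not the PR's code: NumPy arrays stand in for torch tensors, masking is omitted, and the names `repeat_kv_np`, `attention_repeat_kv`, and `attention_gqa_reshape` are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along one axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def repeat_kv_np(x, n_rep):
    # Hypothetical stand-in for repeat_kv: [b, hkv, t, d] -> [b, hkv * n_rep, t, d].
    # Head j of the result is a copy of KV head j // n_rep (this materializes copies).
    return np.repeat(x, n_rep, axis=1)

def attention_repeat_kv(q, k, v):
    # Baseline: duplicate K/V so every query head has its own copy.
    b, h, s, d = q.shape
    n_rep = h // k.shape[1]
    k = repeat_kv_np(k, n_rep)
    v = repeat_kv_np(v, n_rep)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d)  # [b, h, s, t]
    return softmax(scores) @ v                          # [b, h, s, d]

def attention_gqa_reshape(q, k, v):
    # GQA without copies: fold the n_rep query heads of each KV group
    # into the tokens dimension (-2), attend, then unfold.
    b, h, s, d = q.shape
    hkv = k.shape[1]
    n_rep = h // hkv
    q = q.reshape(b, hkv, n_rep * s, d)                 # [b, hkv, n_rep * s, d]
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d)   # [b, hkv, n_rep * s, t]
    out = softmax(scores) @ v                           # [b, hkv, n_rep * s, d]
    return out.reshape(b, h, s, d)                      # back to [b, h, s, d]

rng = np.random.default_rng(0)
b, h, hkv, s, t, d = 2, 8, 2, 5, 7, 4
q = rng.standard_normal((b, h, s, d))
k = rng.standard_normal((b, hkv, t, d))
v = rng.standard_normal((b, hkv, t, d))
out_ref = attention_repeat_kv(q, k, v)
out_gqa = attention_gqa_reshape(q, k, v)
print(out_gqa.shape)                   # (2, 8, 5, 4)
print(np.allclose(out_ref, out_gqa))   # True
```

Because softmax is taken per row of the scores, folding the query heads into the token dimension leaves every row's computation unchanged; only the K/V duplication disappears. If one extended this sketch with a causal mask of shape [s, t], it would need to be tiled rep times along the token axis (for example `np.tile(mask, (n_rep, 1))`) to match the folded layout.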

lsy323 commented 3 months ago

LGTM! Thank you for the change. It seems the linter errors still need to be fixed.

wang2yn84 commented 3 months ago

Quoted review comment: "Smart change! So q k^T has shape [hkv, rep * seq_len, seq_len] and (q k^T) v has shape [hkv, rep * seq_len, d], and you reshape the output back to [h, seq_len, d] at the end."

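Correct. The reshape doesn't change the result: softmax is applied row by row, so each query row still attends to exactly the keys and values of its own KV head.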
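The shape bookkeeping in this exchange can be written out as a short derivation. As an assumption for generality, s denotes the query length and t the key/value length (both equal seq_len in the comment above), and rep = h / h_kv.

```latex
Q \in \mathbb{R}^{h \times s \times d}, \qquad K, V \in \mathbb{R}^{h_{kv} \times t \times d}, \qquad h = h_{kv} \cdot \mathrm{rep}

\tilde{Q} = \operatorname{reshape}(Q) \in \mathbb{R}^{h_{kv} \times (\mathrm{rep} \cdot s) \times d},
\qquad \tilde{Q}_{g,\; r s + i} = Q_{g \cdot \mathrm{rep} + r,\; i}

\tilde{Q} K^{\top} \in \mathbb{R}^{h_{kv} \times (\mathrm{rep} \cdot s) \times t}

\operatorname{softmax}\!\left(\tilde{Q} K^{\top} / \sqrt{d}\right) V \in \mathbb{R}^{h_{kv} \times (\mathrm{rep} \cdot s) \times d}
\;\xrightarrow{\ \operatorname{reshape}\ }\; \mathbb{R}^{h \times s \times d}
```

Row r s + i of group g is query row i of head g * rep + r, scored only against K_g and weighted only over V_g, which is the same computation repeat_kv performs after copying K_g and V_g to that head.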

wang2yn84 commented 3 months ago

In reply to lsy323's "LGTM! Thank you for the change. Seems linter needs to be fixed": Yup, fixed!