google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

Create attention layers with external kv cache and prepare an example. #24

Closed: haozha111 closed this 1 month ago

haozha111 commented 1 month ago

This PR contains the following changes:

1. Fork the current attention code into layers/experimental/attention.py, which supports the external KV cache style (for the GPU work). This avoids polluting the current internal-KV attention implementation; a sketch of the style follows this list.
2. Put the SDPA functions into separate files to allow code reuse.
3. Prepare a toy model with 2 transformer blocks that uses the external KV cache layers.
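For readers unfamiliar with the distinction: a minimal sketch of an external-KV-cache attention layer, assuming PyTorch. The class name, cache layout, and `index_copy`-based update are illustrative assumptions, not this PR's actual API; the point is that the cache is passed in and the updated cache is returned, rather than being mutated as an internal module buffer.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ExternalKVCacheAttention(nn.Module):
    """Self-attention that takes the KV cache as an input and returns the
    updated cache, instead of mutating an internal buffer (hypothetical
    sketch, not the layer defined in this PR)."""

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, k_cache, v_cache, input_pos):
        # x: [B, T, D]; caches: [B, H, max_seq_len, head_dim]; input_pos: [T]
        b, t, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)

        # Functional cache update: write the new K/V at input_pos and return
        # the updated tensors to the caller. An out-of-place update like this
        # is the kind of op that lowers to a dynamic-update-slice in TFLite.
        k_cache = k_cache.index_copy(2, input_pos, k)
        v_cache = v_cache.index_copy(2, input_pos, v)

        # Masking of not-yet-filled cache slots is omitted for brevity.
        y = F.scaled_dot_product_attention(q, k_cache, v_cache)
        y = y.transpose(1, 2).reshape(b, t, -1)
        return self.out(y), k_cache, v_cache
```

Keeping the cache external makes the layer a pure function of its inputs, which is generally friendlier to graph conversion and to delegating cache storage to a backend such as the GPU.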

Note: conversion to TFLite currently fails because TFLite's dynamic update slice does not support int64 indices; a fix will follow in a later change.
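For context, PyTorch index tensors default to int64, which is what trips the converter here. A minimal illustration (the cast-to-int32 workaround is only a guess at what the eventual fix might look like):

```python
import torch

# torch.arange produces int64 by default, so cache-update indices trace
# through the exported graph as int64...
input_pos = torch.arange(0, 2)
print(input_pos.dtype)  # torch.int64

# ...but TFLite's dynamic update slice only accepts int32 indices, so
# conversion fails. A hypothetical workaround is to build the position
# tensor as int32 from the start:
input_pos_i32 = torch.arange(0, 2, dtype=torch.int32)
print(input_pos_i32.dtype)  # torch.int32
```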

BUG=b/338453350