iree-org / iree-jax

Apache License 2.0

Add experimental code for custom attention op #81

Status: Open. NatashaKnk opened this pull request 5 months ago.


This code can (and should) be expanded upon, and would ideally be integrated into the IREE-JAX workflow. That workflow is currently broken because jaxlib has removed ml_program from its Python dependencies, so I created a new folder for work that doesn't fit cleanly anywhere else. A follow-up IREE-side PR lowers this custom op to IREE's attention op.

For now, the custom op maps 1:1 onto the IREE version of the attention op, but it can be extended as needed.