pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention
BSD 3-Clause "New" or "Revised" License

[Feature request] End-to-end transformer example with flex attention #42

Open vladkvit opened 2 months ago

vladkvit commented 2 months ago

I'd love to see a clean example of a transformer that integrates flex attention. I haven't found any samples that do this.

For reference, I have a transformer-based model that uses the TransformerEncoderLayer + TransformerEncoder API. I hacked together a copy-pasted TransformerEncoderLayer with a few tweaks (which gave a significant predictive improvement) that could instead be expressed as a score_mod() function. I'm guessing that using flex attention would make the code both cleaner and more performant. I was hoping this repo could include a best-practice reference.
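
To make it concrete, here is a rough sketch of the kind of thing I mean. The distance-based bias below is just a stand-in (my actual tweaks are different), and I'm assuming the `flex_attention` / `score_mod` API from `torch.nn.attention.flex_attention`, so treat this as illustrative rather than working code:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Hypothetical score_mod: penalize attention scores by query/key distance.
# This is only a placeholder for the kind of tweak currently hard-coded
# inside my copy-pasted TransformerEncoderLayer.
def distance_bias(score, b, h, q_idx, kv_idx):
    return score - 0.1 * (q_idx - kv_idx).abs()

B, H, S, D = 2, 8, 128, 64  # batch, heads, sequence length, head dim
q = torch.randn(B, H, S, D, device="cuda")
k = torch.randn(B, H, S, D, device="cuda")
v = torch.randn(B, H, S, D, device="cuda")

# Would replace the attention call inside the encoder layer's self-attention.
out = flex_attention(q, k, v, score_mod=distance_bias)
```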

drisspg commented 1 month ago

Yeah, I think this is a great idea and it would fit well in the examples folder. I find that nanoGPT is still one of the best/cleanest examples of the decoder-only transformer architecture: https://github.com/karpathy/nanoGPT/blob/9755682b981a45507f6eb9b11eadef8cb83cebd5/model.py#L29

The actual integration of FlexAttention into this would be very minimal, but I do think it would be a good idea to show the "best practices" way to do it.
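
Roughly, the only attention-specific change to nanoGPT's CausalSelfAttention.forward would be something like the sketch below. It's untested and assumes q/k/v already have nanoGPT's (B, n_head, T, head_dim) layout; the full example would flesh out the surrounding module:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# Causal mask expressed as a mask_mod instead of is_causal=True.
def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

B, H, T, D = 4, 12, 1024, 64  # batch, n_head, block_size, head dim
q = torch.randn(B, H, T, D, device="cuda")
k = torch.randn(B, H, T, D, device="cuda")
v = torch.randn(B, H, T, D, device="cuda")

# B=None / H=None broadcasts the mask across batch and heads.
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=T, KV_LEN=T)

# In nanoGPT this replaces the call to
# F.scaled_dot_product_attention(q, k, v, is_causal=True).
y = flex_attention(q, k, v, block_mask=block_mask)
```

For real performance you'd want to wrap `flex_attention` in `torch.compile` and build the block mask once up front rather than on every forward pass, which is exactly the kind of "best practices" detail the example should spell out.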