pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention

Question about FlexAttention for Tabular Data #41

Closed RaphaelMouravieff closed 2 months ago

RaphaelMouravieff commented 2 months ago

Hello,

I’m currently using FlexAttention for tabular data. I’ve created a custom mask that captures the structure of the table, so that attention is computed based on both the table’s structure and its content. In my use case, each observation in the dataset represents a different tabular structure, which means each attention mask has a different shape.
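For concreteness, here is a minimal sketch of the kind of mask I have in mind (names like `col_ids` and `pad_mask` are placeholders for the structural metadata, and all sequences are assumed to be padded to a common length):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, S, D = 4, 8, 256, 64  # batch, heads, padded sequence length, head dim
device = "cuda"

# Placeholder structural metadata: a column id per token and a padding flag.
# In practice these would be built from each table's actual layout.
col_ids = torch.randint(0, 16, (B, S), device=device)          # [B, S]
pad_mask = torch.ones(B, S, dtype=torch.bool, device=device)   # [B, S], True = real token

def table_mask(b, h, q_idx, kv_idx):
    # Example rule: two positions may attend to each other only if they share
    # a column and neither is padding. The batch index `b` is what makes the
    # mask differ from one example to the next.
    same_col = col_ids[b, q_idx] == col_ids[b, kv_idx]
    both_real = pad_mask[b, q_idx] & pad_mask[b, kv_idx]
    return same_col & both_real

# B=B builds one sparsity pattern per example; H=None broadcasts over heads.
block_mask = create_block_mask(table_mask, B=B, H=None, Q_LEN=S, KV_LEN=S, device=device)

q, k, v = (torch.randn(B, H, S, D, device=device) for _ in range(3))
out = flex_attention(q, k, v, block_mask=block_mask)
```

In this sketch the per-example structure is expressed through the batch index `b` of the mask_mod, with padding making the tensor shapes uniform across the batch.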

Is it possible to use FlexAttention with a batch size greater than 1, given that the mask shapes vary across observations? Additionally, do you think FlexAttention is a good fit in this context, or would another approach be more appropriate for handling these varying table structures?

Thank you for your insights!

Best regards, Raphaël Mouravieff