Closed by ViktorooReps 5 days ago
Solved by removing the `_compile=True` argument from mask creation and `dynamic=True` from `torch.compile`. I don't really see speed improvements (especially in generation, which I suspect is due to mask recomputation), but I do see some minor memory usage improvements, which is the main reason I'm using FlexAttention.
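
A minimal sketch of what that change looks like (shapes, names, and the causal `mask_mod` below are illustrative, not taken from my actual script):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal(b, h, q_idx, kv_idx):
    # Standard causal mask_mod: a query attends only to itself and earlier positions.
    return q_idx >= kv_idx


S = 1024
# Before: block_mask = create_block_mask(causal, None, None, S, S, _compile=True)
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")

# Before: compiled_flex = torch.compile(flex_attention, dynamic=True)
compiled_flex = torch.compile(flex_attention)

q = torch.randn(1, 8, S, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
out = compiled_flex(q, k, v, block_mask=block_mask)
```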
Hi! I am experimenting with a char-level LLM with word masking. For now, I am struggling to get FlexAttention past compilation errors, even for simple causal masking.
What am I benchmarking:
The model:
What I am expecting to see:
What I actually see:
Can you help me? How do I fix it? Am I not compiling/calling FlexAttention correctly? Thank you so much in advance!
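
Roughly, the two paths I am comparing look like this (a trimmed-down, illustrative sketch with made-up shapes, not my actual model code):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal(b, h, q_idx, kv_idx):
    # Each query position attends only to itself and earlier key positions.
    return q_idx >= kv_idx


def attention(q, k, v, impl="sdpa", block_mask=None):
    if impl == "sdpa":
        # Baseline: fused SDPA with its built-in causal masking.
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # FlexAttention path: the same causal masking expressed as a BlockMask.
    return flex_attention(q, k, v, block_mask=block_mask)


B, H, S, D = 1, 8, 1024, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")

# The two paths should agree numerically on a plain causal mask.
print(torch.allclose(attention(q, k, v, "sdpa").float(),
                     attention(q, k, v, "flex", mask).float(), atol=1e-2))
```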
Setup
H100 GPU, CUDA 12.4
`pip freeze`:

Here is the full script to reproduce:
When I run it with `impl = 'sdpa'`, I get no compilation errors, and here are the results:

Here is the output for `impl = 'flex'`: