Open · twaka opened this issue 3 months ago
Thanks for reporting! @Andrei-Aksionov and I will take a look
Yes, thanks for the report @twaka 👍.
It indeed could have been done better on my side. Like in the pseudocode for SDPA, where a different path is taken if the mask is boolean. And overall, the mask creation can be improved; there is no need to recreate it for every single call. I wanted to do it after the main PR, but didn't have time for it.
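For reference, the mask handling in the SDPA docs' pseudocode looks roughly like this (a simplified sketch of the semantics, not the fused kernel): a boolean mask selects which positions may be attended, while a float mask is added to the scores as a bias.

```python
import torch

# Simplified sketch of SDPA's mask semantics, following the pseudocode
# in the PyTorch docs (not the fused kernel itself).
def sdpa_reference(q, k, v, attn_mask=None):
    scale = q.size(-1) ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Boolean path: True means "may attend", False positions get -inf.
            scores = scores.masked_fill(~attn_mask, float("-inf"))
        else:
            # Float path: the mask is an additive bias on the scores.
            scores = scores + attn_mask
    return torch.softmax(scores, dim=-1) @ v
```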
Btw, do we need an attention mask during inference at all? If it's just a batch size of 1, then the mask will be a vector of 1s, won't it?
@Andrei-Aksionov The mask still needs to be diagonal, otherwise a given token will attend to a future token. But that part is already handled via triu / tril in the code. I'm actually not sure if I follow at all; I need to think about this more.
But during generation it cannot attend to a future token, since that token doesn't exist yet 😊
It would for the context tokens, though. I.e., in a generation step, the first generated output token would depend on itself and all future tokens. However, since we chop those off during generation and only keep the last token for the next round, it should be fine, just like you said. Never mind my concern.
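(For reference, the causal mask being discussed is just a lower-triangular boolean matrix; a tiny illustrative example:)

```python
import torch

T = 5  # illustrative sequence length
# Row i is True for columns <= i, i.e. token i may only attend to itself
# and to earlier tokens.
causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
print(causal_mask)
```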
Your concern was correct. I just forgot to mention that it will only work with a batch size of 1 (no need to mask pad tokens) and with a kv-cache. With a kv-cache we "send" only the last generated token, get its q, k and v vectors, concat those with the cached k and v, and now this query vector (for a single token) attends to k and v from all the previous steps.
Update: I'm so stupid and forgot about the prefill stage 🐒. Here we definitely need it. Nonetheless, since a mask provided to SDPA disables flash attention, maybe the mask should be provided only during that stage 🤔.
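To illustrate the kv-cache point, here is a toy single-head, batch-of-1 decoding step (an illustrative sketch, not litgpt's implementation):

```python
import torch

def decode_step(x_new, w_q, w_k, w_v, k_cache, v_cache):
    """One decoding step with a kv-cache (simplified, single head, batch of 1).

    x_new: embedding of the last generated token, shape (1, d_model)
    k_cache, v_cache: keys/values from all previous steps, shape (t, d_head)
    """
    q = x_new @ w_q                          # (1, d_head) -- query for the new token only
    k = torch.cat([k_cache, x_new @ w_k])    # (t + 1, d_head)
    v = torch.cat([v_cache, x_new @ w_v])    # (t + 1, d_head)
    # The single query row attends to every cached position plus itself, so no
    # causal mask is needed here; during prefill (many queries at once) the
    # usual triangular mask is still required.
    scores = (q @ k.T) / k.size(-1) ** 0.5   # (1, t + 1)
    out = torch.softmax(scores, dim=-1) @ v  # (1, d_head)
    return out, k, v

# Toy usage with made-up sizes.
d_model, d_head, t = 16, 8, 4
w_q, w_k, w_v = (torch.randn(d_model, d_head) for _ in range(3))
k_cache, v_cache = torch.randn(t, d_head), torch.randn(t, d_head)
out, k_cache, v_cache = decode_step(torch.randn(1, d_model), w_q, w_k, w_v, k_cache, v_cache)
```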
That kv-cache ... I still need to wrap my head around it. I should probably code it from scratch for myself some time to get a better grasp on how best to manage it.
But yeah, the mask thing is a good point regarding disabling flash attention. Hm, that's actually quite bad. Just CCing @apaz-cli to be aware of this for batching, because tok/sec will take a hit with batching then.
I don't know what the solution is, maybe compiled FlexAttention?
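A rough sketch of what that could look like with FlexAttention (this assumes PyTorch >= 2.5 with `torch.nn.attention.flex_attention` and a CUDA device; the soft-cap value and shapes are illustrative, not Gemma-2's actual config): the causal/padding mask goes into a block mask and the logit soft-capping into a `score_mod`, while still using a fused, compilable kernel.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

SOFTCAP = 50.0  # illustrative value only

def softcap_mod(score, b, h, q_idx, kv_idx):
    # Logit soft-capping expressed as a score modifier.
    return torch.tanh(score / SOFTCAP) * SOFTCAP

def causal(b, h, q_idx, kv_idx):
    # Token q_idx may only attend to positions <= q_idx.
    return q_idx >= kv_idx

B, H, T, D = 1, 8, 128, 64
q = torch.randn(B, H, T, D, device="cuda")
k = torch.randn(B, H, T, D, device="cuda")
v = torch.randn(B, H, T, D, device="cuda")

block_mask = create_block_mask(causal, B, H, T, T, device="cuda")
out = torch.compile(flex_attention)(q, k, v, score_mod=softcap_mod, block_mask=block_mask)
```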
It should be easy to do.
Here we are doing a prefill stage for a prompt, so we need to provide `pos_ids`:

https://github.com/Lightning-AI/litgpt/blob/b042058e239252af9494b548ab00ad836bc68f41/litgpt/generate/base.py#L134-L137

but here, for the next-token generation, we can try to pass `pos_ids` as None:

https://github.com/Lightning-AI/litgpt/blob/b042058e239252af9494b548ab00ad836bc68f41/litgpt/generate/base.py#L139-L142

If `pos_ids` is None, then the mask is also None:

https://github.com/Lightning-AI/litgpt/blob/b042058e239252af9494b548ab00ad836bc68f41/litgpt/model.py#L78-L87

Curious whether it will lead to a performance improvement ... and whether it will affect compilation.
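In other words, something along these lines (a sketch of the idea with made-up helper names, not the actual litgpt code linked above):

```python
import torch

def build_mask(pos_ids, max_seq_length, device):
    # Sketch of the conditional described above: no pos_ids -> no mask,
    # so SDPA can take the flash-attention path during single-token decoding.
    if pos_ids is None:
        return None
    ones = torch.ones(max_seq_length, max_seq_length, dtype=torch.bool, device=device)
    mask = torch.tril(ones).unsqueeze(0).unsqueeze(0)  # (1, 1, T, T) causal mask
    return mask.index_select(2, pos_ids)               # keep only the rows for the current positions

# Prefill: positions 0..3 of the prompt -> a proper causal mask is returned.
prefill_mask = build_mask(torch.arange(4), max_seq_length=8, device="cpu")
# Decode step: pos_ids=None -> mask is None, so nothing is passed to SDPA.
decode_mask = build_mask(None, max_seq_length=8, device="cpu")
```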
Bug description
When using `litgpt generate` on models with softcapping, `build_mask_cache` creates the mask as `torch.bool`:

https://github.com/Lightning-AI/litgpt/blob/ef9647cfa7cd73e03b0e29126bfe8b42cae509eb/litgpt/model.py#L465

and then it's added to the scores:

https://github.com/Lightning-AI/litgpt/blob/ef9647cfa7cd73e03b0e29126bfe8b42cae509eb/litgpt/model.py#L309

Therefore, the attention mask is not accounted for, and it just does +0 or +1 on the scores.
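A minimal illustration of the problem and one possible fix (illustrative only, not necessarily the patch that should land in litgpt):

```python
import torch

scores = torch.randn(1, 1, 4, 4)                                # pre-softmax attention scores
bool_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))      # causal mask as torch.bool

# What the report describes: adding a boolean mask just adds 0 or 1 to the
# scores, so future positions are NOT actually masked out.
broken = scores + bool_mask

# One possible fix: convert the boolean mask into -inf at masked positions
# (or build an additive float mask in the first place).
fixed = scores.masked_fill(~bool_mask, float("-inf"))

print(torch.softmax(broken, dim=-1)[0, 0, 0])  # still leaks attention to future tokens
print(torch.softmax(fixed, dim=-1)[0, 0, 0])   # only the first position gets any weight
```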
What operating system are you using?

Unknown
LitGPT Version