Lightning-AI / litgpt

20+ high-performance LLMs with recipes to pretrain, finetune and deploy at scale.
https://lightning.ai
Apache License 2.0

attention mask is incorrect when generate with softcapping #1672


twaka commented 1 month ago

Bug description

When using `litgpt generate` on models with softcapping, `build_mask_cache` creates the mask as `torch.bool` (https://github.com/Lightning-AI/litgpt/blob/ef9647cfa7cd73e03b0e29126bfe8b42cae509eb/litgpt/model.py#L465), and that boolean mask is then added directly to the attention scores (https://github.com/Lightning-AI/litgpt/blob/ef9647cfa7cd73e03b0e29126bfe8b42cae509eb/litgpt/model.py#L309). As a result, the attention mask is not actually applied: it only adds +0 or +1 to the scores.
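
A minimal illustration of the effect (a sketch, not the actual litgpt code): adding a boolean mask to float scores only shifts them by 0 or 1, whereas an additive mask needs 0 / -inf entries so that softmax zeroes out the masked positions.

```python
import torch

T = 4
scores = torch.randn(T, T)                               # raw attention scores
bool_mask = torch.ones(T, T, dtype=torch.bool).tril()    # True = "may attend"

# What effectively happens in the softcapping path: the bool mask is promoted
# to {0., 1.}, so disallowed positions get +0 and allowed ones get +1 --
# nothing is actually masked out.
wrong = scores + bool_mask
print(wrong[0])   # future positions are still finite, i.e. still attended to

# What an additive mask should look like: 0 where attention is allowed,
# -inf where it is not, so softmax sends those positions to zero.
additive = torch.zeros(T, T).masked_fill(~bool_mask, float("-inf"))
right = scores + additive
print(right[0])   # future positions are -inf and vanish after softmax
```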

What operating system are you using?

Unknown

LitGPT Version

rasbt commented 1 month ago

Thanks for reporting! @Andrei-Aksionov and I will take a look.

Andrei-Aksionov commented 1 month ago

Yes, thanks for the report @twaka 👍.

It indeed could have been done better on my side, like in the pseudocode for SDPA, where a separate path is taken if the mask is boolean. And overall, the mask creation can be improved: there is no need to recreate it for every single call. I wanted to do that after the main PR, but didn't have time for it.
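
For reference, a small sketch of such a boolean-vs-additive branch, along the lines of the reference pseudocode in the SDPA docs (an assumption about how the fix could look, not litgpt's actual implementation):

```python
import torch

def apply_attn_mask(scores: torch.Tensor, mask: torch.Tensor | None) -> torch.Tensor:
    """Handle boolean and additive masks separately, SDPA-pseudocode style."""
    if mask is None:
        return scores
    if mask.dtype == torch.bool:
        # boolean mask: True means "may attend", so fill everything else with -inf
        return scores.masked_fill(mask.logical_not(), float("-inf"))
    # additive float mask: 0 for allowed positions, -inf for masked ones
    return scores + mask
```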

Btw, do we need an attention mask during inference at all? If it's just a batch size of 1, the mask will be a vector of 1s, won't it?

rasbt commented 1 month ago

@Andrei-Aksionov The mask still needs to be causal (lower-triangular), otherwise a given token will attend to future tokens. But that part is already handled via triu / tril in the code. I am actually not sure I follow entirely; I need to think about this more.

Andrei-Aksionov commented 1 month ago

But during generation it cannot attend to a future token, since that token doesn't exist yet 😊

rasbt commented 1 month ago

It would for the context tokens, though. I.e., in a generation step, the output at the first position would depend on itself and on all future tokens. However, since we chop those off during generation and only keep the last token for the next round, it should be fine, just like you said. Never mind my concern.

Andrei-Aksionov commented 1 month ago

Your concern was correct. I just forgot to mention that it will work only for a batch size of 1 (no need to mask pad tokens) and with a kv-cache. With a kv-cache we "send" only the last generated token, get its q, k and v vectors, concatenate them with the cached k and v, and now this query vector (for a single token) attends to the k and v from all the previous steps.
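
A rough sketch of that single-token decode step (illustrative names and shapes, batch size 1; not litgpt's actual API):

```python
import torch

def decode_step(x_new, w_q, w_k, w_v, k_cache, v_cache):
    # x_new: (1, 1, d) -- only the last generated token is sent through
    q = x_new @ w_q                                   # (1, 1, d)
    k = torch.cat([k_cache, x_new @ w_k], dim=1)      # (1, t + 1, d)
    v = torch.cat([v_cache, x_new @ w_v], dim=1)      # (1, t + 1, d)
    # The single query attends to all cached positions; they are all in the
    # past, so no causal mask is needed for this step.
    scores = (q @ k.transpose(-2, -1)) / k.size(-1) ** 0.5
    out = torch.softmax(scores, dim=-1) @ v           # (1, 1, d)
    return out, k, v                                  # k, v become the new caches
```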

Update: I'm so stupid and forgot about the prefill stage 🐒. There we definitely need the mask. Nonetheless, since a mask provided to SDPA disables flash attention, maybe the mask should be provided only during that stage 🤔.

rasbt commented 1 month ago

That kv cache ... I still need to wrap my head around it. I should probably code it from scratch for myself some time to get a better grasp of how to best manage it.

But yeah, the mask thing is a good point regarding disabling flash attention. Hm, that's actually quite bad. Just CCing @apaz-cli to be aware of this for batching, because tok/sec will take a hit with batching then.

I don't know what the solution is; maybe compiled FlexAttention?
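
For what it's worth, a hedged sketch of what that could look like with FlexAttention (requires a recent PyTorch with `torch.nn.attention.flex_attention` and a CUDA device; the softcap value, shapes, and names below are illustrative, not litgpt's):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

SOFTCAP = 50.0  # illustrative value, not tied to any particular model config

def softcap(score, b, h, q_idx, kv_idx):
    # tanh softcapping applied to every attention score
    return SOFTCAP * torch.tanh(score / SOFTCAP)

def causal(b, h, q_idx, kv_idx):
    # a query may only attend to itself and earlier positions
    return q_idx >= kv_idx

B, H, S, D = 1, 8, 128, 64
q = torch.randn(B, H, S, D, device="cuda")
k = torch.randn(B, H, S, D, device="cuda")
v = torch.randn(B, H, S, D, device="cuda")

# The causal constraint goes into a block mask and the softcapping into a
# score_mod, so no additive mask tensor has to be materialized and added.
block_mask = create_block_mask(causal, B, H, S, S)
compiled_flex_attention = torch.compile(flex_attention)
out = compiled_flex_attention(q, k, v, score_mod=softcap, block_mask=block_mask)
```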

Andrei-Aksionov commented 1 month ago

It should be easy to do. Here we run the prefill stage for the prompt, so we need to provide pos_ids: https://github.com/Lightning-AI/litgpt/blob/b042058e239252af9494b548ab00ad836bc68f41/litgpt/generate/base.py#L134-L137

but here, for next-token generation, we could try passing pos_ids as None: https://github.com/Lightning-AI/litgpt/blob/b042058e239252af9494b548ab00ad836bc68f41/litgpt/generate/base.py#L139-L142

If pos_ids is None, then the mask is also None: https://github.com/Lightning-AI/litgpt/blob/b042058e239252af9494b548ab00ad836bc68f41/litgpt/model.py#L78-L87
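
Roughly, the proposal would look like this (a hypothetical generate loop with greedy sampling, just to show which calls would receive positions, and hence a mask):

```python
import torch

@torch.inference_mode()
def generate_sketch(model, prompt_ids: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    device = prompt_ids.device
    T = prompt_ids.size(0)

    # prefill: the whole prompt attends causally, so positions (and the mask) are needed
    input_pos = torch.arange(0, T, device=device)
    logits = model(prompt_ids.unsqueeze(0), input_pos)
    token = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy, for simplicity

    out = [token]
    for _ in range(max_new_tokens - 1):
        # decode: a single new token; the idea is to pass None here so no mask
        # is built (whether the kv-cache still indexes correctly without explicit
        # positions is exactly the open question)
        logits = model(token, None)
        token = logits[:, -1].argmax(dim=-1, keepdim=True)
        out.append(token)
    return torch.cat(out, dim=1)
```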

Curious whether it will lead to a performance improvement ... and whether it will affect compilation.