alitirmizi23 opened 1 year ago
cc @awaelchli, if you'd like to answer
Just chiming in, from what I understand, this is not a simple feature to implement in general.
As one current example, the axolotl finetuning harness implements efficient sample packing with correct block-diagonal attention masking through a series of monkey patches for the underlying Hugging Face model definitions for a few of the very popular models like Llama and Mistral. Though I have not looked through the code in detail, I believe it leverages the fact that the FlashAttention API supports the masking required to implement this scheme.
It seems like the simplicity of the lit-gpt model definition might actually make this easier to implement as a first-class feature. It is relevant for efficient finetuning (the reason it's incorporated into axolotl), and general wisdom (and whispers from inside large corps) suggests that this type of block-diagonal masking is better for large-scale training code.
I (and other collaborators) would be very interested in this feature, and it would increase the attractiveness of lit-gpt's model-building code as an HF alternative. Just my 2c!
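To make the idea concrete, here is a minimal sketch (not axolotl or lit-gpt code; names and lengths are made up) of how such a block-diagonal causal mask for a packed sequence can be built from per-sample lengths:

```python
import torch

def block_diagonal_causal_mask(seq_lengths: list[int]) -> torch.Tensor:
    """Boolean mask where token i may attend to token j only if j <= i and
    both tokens belong to the same packed sample."""
    total = sum(seq_lengths)
    # doc_ids[t] = index of the sample that token t belongs to
    doc_ids = torch.repeat_interleave(
        torch.arange(len(seq_lengths)), torch.tensor(seq_lengths)
    )
    causal = torch.tril(torch.ones(total, total, dtype=torch.bool))
    same_doc = doc_ids.unsqueeze(0) == doc_ids.unsqueeze(1)
    return causal & same_doc

# e.g. three samples of lengths 3, 2, 4 packed into one row of 9 tokens
mask = block_diagonal_causal_mask([3, 2, 4])  # shape (9, 9)
```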
If a document, article, or instruction/output pair exceeds the max sequence length, how is it treated?
Depends on the data preparation, but our scripts trim it: https://github.com/Lightning-AI/lit-gpt/blob/0791c52a944f022a5cee91ed1e47288830efb72c/scripts/prepare_alpaca.py#L116-L117
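For context, the trimming boils down to something like this sketch (identifiers are illustrative, not the exact ones used in prepare_alpaca.py):

```python
# Hypothetical sketch: tokenize an example and drop everything past the limit.
def tokenize_and_trim(tokenizer, text: str, max_seq_length: int) -> list[int]:
    ids = tokenizer.encode(text)   # full token id sequence
    return ids[:max_seq_length]    # anything beyond max_seq_length is discarded
```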
What if a doc/article/instruction-output pair falls short of the max sequence length? Are the remaining time steps padded, or are more sequences packed until the max length is reached?
They are packed in the pretraining scripts: https://github.com/Lightning-AI/lit-gpt/blob/0791c52a944f022a5cee91ed1e47288830efb72c/pretrain/redpajama.py#L249-L257 and padded in the fine-tuning scripts: https://github.com/Lightning-AI/lit-gpt/blob/0791c52a944f022a5cee91ed1e47288830efb72c/finetune/full.py#L232-L246
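Roughly, the two strategies differ as in this simplified sketch (not the actual linked code; the pad_id and ignore_index values are assumptions):

```python
import torch

def pack_for_pretraining(token_streams: list[list[int]], block_size: int) -> list[list[int]]:
    """Concatenate all documents into one stream and cut it into fixed-size blocks."""
    flat = [t for stream in token_streams for t in stream]
    n_blocks = len(flat) // block_size
    return [flat[i * block_size:(i + 1) * block_size] for i in range(n_blocks)]

def pad_for_finetuning(batch_ids: list[list[int]], pad_id: int = 0, ignore_index: int = -100):
    """Pad every sample to the longest in the batch; padded label positions are ignored by the loss."""
    max_len = max(len(ids) for ids in batch_ids)
    inputs = torch.stack([torch.tensor(ids + [pad_id] * (max_len - len(ids))) for ids in batch_ids])
    labels = torch.stack([torch.tensor(ids + [ignore_index] * (max_len - len(ids))) for ids in batch_ids])
    return inputs, labels
```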
> suggests that this type of block-diagonal masking is better for large-scale training code.
The inconvenience is that torch.nn.functional.scaled_dot_product_attention will not use flash-attn if an attention mask is passed. It would be necessary to integrate this specific flavor of flash attention: https://github.com/Dao-AILab/flash-attention/issues/654, which would again require building it.
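Concretely, the constraint looks like this sketch (shapes and lengths are made up; with an explicit attn_mask, SDPA falls back to the memory-efficient or math backend instead of the FlashAttention kernel):

```python
import torch
import torch.nn.functional as F

B, H, T, D = 1, 8, 9, 64
q = k = v = torch.randn(B, H, T, D, device="cuda", dtype=torch.float16)

# Plain causal attention: eligible for the FlashAttention backend.
out_flash = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Block-diagonal causal mask for three packed samples of lengths 3, 2, 4.
doc_ids = torch.repeat_interleave(torch.arange(3), torch.tensor([3, 2, 4])).to("cuda")
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device="cuda"))
mask &= doc_ids.unsqueeze(0) == doc_ids.unsqueeze(1)

# Still numerically correct, but no longer dispatched to the FlashAttention kernel.
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```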
We'd also be very interested in this feature!
@carmocca let's revive this issue. It doesn't look like SDPA from PyTorch 2.3 has solved the underlying issue; if that's the case, let's add flash attention as an optional dependency.
PyTorch has added support for arbitrary custom masks, which are meant to be performant when used with torch.compile: https://github.com/pytorch/pytorch/pull/121845
They are also considering more generic API changes that are under discussion: https://github.com/pytorch/pytorch/issues/110681.
As of today, Tri Dao's package is the only option as far as I know.
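If the FlexAttention API that later shipped in torch.nn.attention.flex_attention (PyTorch 2.5+) is what this work grew into, a document-causal mask for packed samples could be expressed like this hedged sketch (exact signatures are an assumption):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, T, D = 1, 8, 1024, 64
q = k = v = torch.randn(B, H, T, D, device="cuda", dtype=torch.float16)

# document_id[t] = which packed sample token t belongs to (made-up: four equal docs)
document_id = torch.arange(T, device="cuda") // 256

def doc_causal(b, h, q_idx, kv_idx):
    # attend causally, and only within the same document
    return (q_idx >= kv_idx) & (document_id[q_idx] == document_id[kv_idx])

block_mask = create_block_mask(doc_causal, B=None, H=None, Q_LEN=T, KV_LEN=T)
attention = torch.compile(flex_attention)
out = attention(q, k, v, block_mask=block_mask)
```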
> As of today, Tri Dao's package is the only option as far as I know.
I think that xformers is doing it as well
@samsja CUDNN attention is most likely the best option today that supports attention masks (see the FlashAttention-3 paper figures). xformers is not as competitive, on H100s at least.
Oh I see. Any chance litgpt will integrate some of these options at this point?
By any chance, do you have some benchmarks comparing fa2/fa3/xformers/torch sdpa?
That's a good question. We don't have a benchmark but LitGPT already supports FlashAttention-2 via PyTorch's SDPA. The plan is to also support FlashAttention-3 (#1578)
Unfortunately, torch sdpa cannot leverage flash attention with custom masks (for context stuffing), contrary to xformers and the original flash attention implementation. Bit of a blocker. I am currently implementing it using xformers and litgpt.
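For anyone following along, the xformers route can look roughly like this sketch (assuming xformers' fmha.BlockDiagonalMask and memory_efficient_attention; shapes and lengths are made up, and this is not the poster's implementation):

```python
import torch
from xformers.ops import fmha, memory_efficient_attention

H, D = 8, 64
seq_lengths = [3, 2, 4]                 # per-sample lengths in the packed batch
T = sum(seq_lengths)

# xformers expects (batch, seq, heads, head_dim); packed samples share batch index 0
q = k = v = torch.randn(1, T, H, D, device="cuda", dtype=torch.float16)

# Block-diagonal causal bias: tokens attend only within their own sample
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seq_lengths).make_causal()
out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
```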
Hello @samsja, any updates on how we can apply data packing with masking to prevent data from different contexts from being used when computing attention?
I ended up using flash_attn_varlen_func from flash attention. It's probably a good idea to try the flex_attention API as well.
@samsja and you combined it with the code provided by LitGPT? If yes, do you have any code example?
I am reusing some litgpt code indeed, but had to change the attention layer as well as the dataloading. Unfortunately, I cannot share the code for now.
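Since the actual code can't be shared, here is a generic sketch of the flash_attn_varlen_func pattern mentioned above (assuming the flash-attn package; lengths and shapes are made up, and this is not the poster's implementation):

```python
import torch
from flash_attn import flash_attn_varlen_func

H, D = 8, 64
seq_lengths = [3, 2, 4]                          # per-sample lengths (made up)
total = sum(seq_lengths)

# Packed layout: (total_tokens, n_heads, head_dim), no batch dimension
q = k = v = torch.randn(total, H, D, device="cuda", dtype=torch.float16)

# cu_seqlens = [0, 3, 5, 9]: prefix sums marking where each sample starts/ends
cu_seqlens = torch.tensor([0, 3, 5, 9], dtype=torch.int32, device="cuda")

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seq_lengths),
    max_seqlen_k=max(seq_lengths),
    causal=True,  # causal within each sample; samples never attend to each other
)
```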
I was wondering if there are sample packing approaches defined somewhere for preprocessing and tokenization of datasets? I looked through the different prepare_*.py scripts, but couldn't find anything related to packing multiple sequences into max_length for efficiency, etc.
Also, I am wondering how the data prep works as of now in the lit-gpt framework.
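To make the question concrete, a greedy packing pass during preprocessing could look like this sketch (not an existing lit-gpt script; all names are illustrative):

```python
def pack_tokenized_samples(samples: list[list[int]], max_length: int, eos_id: int):
    """Greedily pack tokenized samples into blocks of at most max_length tokens,
    separating samples with EOS and recording per-block sample lengths so a
    block-diagonal attention mask can be built later."""
    blocks, block_lengths = [], []
    current, current_lengths = [], []
    for ids in samples:
        ids = ids + [eos_id]
        if current and len(current) + len(ids) > max_length:
            blocks.append(current)
            block_lengths.append(current_lengths)
            current, current_lengths = [], []
        # samples longer than max_length are simply trimmed here (not split)
        current.extend(ids[:max_length])
        current_lengths.append(min(len(ids), max_length))
    if current:
        blocks.append(current)
        block_lengths.append(current_lengths)
    return blocks, block_lengths
```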