Lightning-AI / litgpt

20+ high-performance LLMs with recipes to pretrain, finetune and deploy at scale.
https://lightning.ai
Apache License 2.0

Sample packing for pretraining/fine-tuning #620

Open alitirmizi23 opened 1 year ago

alitirmizi23 commented 1 year ago

I was wondering if there are sample packing approaches defined somewhere for preprocessing and tokenization of datasets? I looked through the different prepare_*.py scripts, but couldn't find anything related to packing multiple sequences into max_length for efficiency, etc.

Also, I'm wondering how the data prep currently works in the lit-gpt framework:

If a document, article, instruction/output pair exceeds the max sequence length, how is it treated?

What about if a doc/article/instruction-output pair falls short of max seq. length? Are the remaining time steps padded, or are more sequences packed until max length is achieved?

carmocca commented 1 year ago

cc @awaelchli, if you'd like to answer

jwkirchenbauer commented 12 months ago

Just chiming in, from what I understand, this is not a simple feature to implement in general.

As one current example, the axolotl finetuning harness implements efficient sample packing with correct block-diagonal attention masking through a series of monkey patches on the underlying Hugging Face model definitions for a few very popular models like Llama and Mistral. Though I have not looked through the code in detail, I believe it leverages the fact that the FlashAttention API supports the masking required to implement this scheme.

It seems like the simplicity of the lit-gpt model definition might actually make this easier to implement as a first-class feature. It is relevant for efficient finetuning (the reason it's incorporated into axolotl), and general wisdom (and whispers from inside large corps) suggests that this type of block-diagonal masking is better for large-scale training code.
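
For concreteness, here is a minimal sketch (not axolotl or lit-gpt code) of what block-diagonal causal masking means for packed samples: each token may only attend within its own sample, and causally at that.

```python
import torch

def block_diagonal_causal_mask(seq_lens: list[int]) -> torch.Tensor:
    """Boolean (T, T) mask where True means 'may attend', for samples packed back to back."""
    total = sum(seq_lens)
    mask = torch.zeros(total, total, dtype=torch.bool)
    offset = 0
    for length in seq_lens:
        # Causal (lower-triangular) block for this sample only.
        mask[offset:offset + length, offset:offset + length] = torch.tril(
            torch.ones(length, length, dtype=torch.bool)
        )
        offset += length
    return mask

# Three samples of lengths 3, 2 and 4 packed into one sequence of 9 tokens.
print(block_diagonal_causal_mask([3, 2, 4]).int())
```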

I (and other collaborators) would be very interested in this feature, and it would increase the attractiveness of lit-gpt's model-building code as an HF alternative. Just my 2c!

carmocca commented 10 months ago

> If a document, article, instruction/output pair exceeds the max sequence length, how is it treated?

Depends on the data preparation, but our scripts trim it: https://github.com/Lightning-AI/lit-gpt/blob/0791c52a944f022a5cee91ed1e47288830efb72c/scripts/prepare_alpaca.py#L116-L117

> What about if a doc/article/instruction-output pair falls short of max seq. length? Are the remaining time steps padded, or are more sequences packed until max length is achieved?

They are packed in the pretraining scripts: https://github.com/Lightning-AI/lit-gpt/blob/0791c52a944f022a5cee91ed1e47288830efb72c/pretrain/redpajama.py#L249-L257 and padded in the fine-tuning scripts: https://github.com/Lightning-AI/lit-gpt/blob/0791c52a944f022a5cee91ed1e47288830efb72c/finetune/full.py#L232-L246
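
In other words, the two paths behave roughly like this simplified sketch (not the actual script code; `max_seq_length` and `pad_id` are placeholder names):

```python
from typing import Iterable, Iterator

def pack(token_streams: Iterable[list[int]], max_seq_length: int) -> Iterator[list[int]]:
    """Pretraining-style: concatenate documents and cut fixed-size blocks."""
    buffer: list[int] = []
    for tokens in token_streams:
        buffer.extend(tokens)
        while len(buffer) >= max_seq_length:
            yield buffer[:max_seq_length]
            buffer = buffer[max_seq_length:]

def pad_or_trim(tokens: list[int], max_seq_length: int, pad_id: int = 0) -> list[int]:
    """Finetuning-style: truncate samples that are too long, pad the rest."""
    tokens = tokens[:max_seq_length]
    return tokens + [pad_id] * (max_seq_length - len(tokens))
```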

> suggests that this type of block-diagonal masking is better for large-scale training code.

The inconvenience is that torch.nn.functional.scaled_dot_product_attention will not use flash-attn if an attention mask is passed. It would be necessary to integrate this specific flavor of flash attention: https://github.com/Dao-AILab/flash-attention/issues/654, which would again require building it.
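
To illustrate the limitation, a minimal sketch (not litgpt code; it assumes a CUDA device with fp16, and the exact error/fallback behaviour can vary across PyTorch versions):

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)

# Block-diagonal mask for two packed documents of 2048 tokens each (non-causal for brevity).
mask = torch.zeros(4096, 4096, dtype=torch.bool, device="cuda")
mask[:2048, :2048] = True
mask[2048:, 2048:] = True

# With the backend restricted to FlashAttention, passing an explicit attn_mask fails;
# without the restriction, SDPA silently falls back to a slower kernel instead.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    except RuntimeError as err:
        print("FlashAttention backend rejected the mask:", err)
```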

vgoklani commented 10 months ago

https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1698610752

corbt commented 7 months ago

We'd also be very interested in this feature!

lantiga commented 7 months ago

@carmocca let's revive this issue. It doesn't look like SDPA from PyTorch 2.3 has solved the underlying issue; if that's the case, let's add flash attention as an optional dependency.

carmocca commented 7 months ago

PyTorch has added support for arbitrary custom masks which are meant to be performant when used with torch.compile: https://github.com/pytorch/pytorch/pull/121845

They are also considering more generic API changes that are in discussion: https://github.com/pytorch/pytorch/issues/110681.

As of today, Tri Dao's package is the only option as far as I know.
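
Assuming that PR corresponds to the flex_attention API exposed in newer PyTorch builds (torch.nn.attention.flex_attention), a packed-sample document mask would look roughly like this sketch (not litgpt code; CUDA + fp16 assumed):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, T, D = 1, 8, 4096, 64
q = k = v = torch.randn(B, H, T, D, device="cuda", dtype=torch.float16)

# document_id[t] says which packed sample token t belongs to, e.g. two documents of 2048 tokens.
document_id = torch.arange(T, device="cuda") // 2048

def document_causal(b, h, q_idx, kv_idx):
    # Attend only within the same document, and causally.
    return (document_id[q_idx] == document_id[kv_idx]) & (q_idx >= kv_idx)

block_mask = create_block_mask(document_causal, B=None, H=None, Q_LEN=T, KV_LEN=T, device="cuda")
out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
```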

samsja commented 4 months ago

> PyTorch has added support for arbitrary custom masks which are meant to be performant when used with torch.compile: pytorch/pytorch#121845
>
> They are also considering more generic API changes that are in discussion: pytorch/pytorch#110681.
>
> As of today, Tri Dao's package is the only option as far as I know.

I think that xformers supports this as well.

carmocca commented 4 months ago

@samsja cuDNN attention is most likely the best option today that supports attention masks (see the FlashAttention-3 paper figures). xformers is not as competitive, on H100s at least.

samsja commented 4 months ago

Oh I see, any chance litgpt will integrate some of these options at this point?

By any chance do you have some benchmarks comparing FA2/FA3/xformers/torch SDPA?

rasbt commented 4 months ago

That's a good question. We don't have a benchmark, but LitGPT already supports FlashAttention-2 via PyTorch's SDPA. The plan is to also support FlashAttention-3 (#1578).

samsja commented 4 months ago

> That's a good question. We don't have a benchmark, but LitGPT already supports FlashAttention-2 via PyTorch's SDPA. The plan is to also support FlashAttention-3 (#1578).

Unfortunately torch SDPA cannot leverage FlashAttention with custom masks (for context stuffing), unlike xformers and the original FlashAttention implementation. Bit of a blocker. I am currently implementing it using xformers and litgpt.
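
The xformers route looks roughly like this (a minimal sketch of the idea, not the actual litgpt integration; CUDA + fp16 assumed): all samples are concatenated into a single "batch of 1" sequence, and a block-diagonal causal bias keeps them from attending to each other.

```python
import torch
import xformers.ops as xops

seq_lens = [3000, 1096]            # two packed samples, 4096 tokens total
H, D = 8, 64
total = sum(seq_lens)
q = k = v = torch.randn(1, total, H, D, device="cuda", dtype=torch.float16)

# Block-diagonal + causal bias built from the per-sample lengths.
attn_bias = xops.fmha.BlockDiagonalCausalMask.from_seqlens(seq_lens)
out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
```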

ali-issa99 commented 1 month ago

Hello @samsja, any updates on how we can apply data packing with masking to prevent tokens from different contexts from attending to each other?

samsja commented 1 month ago

> Hello @samsja, any updates on how we can apply data packing with masking to prevent tokens from different contexts from attending to each other?

I ended up using flash_attn_varlen_func from flash-attention. It's probably a good idea to try the flex_attention API as well.
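
For reference, the flash_attn_varlen_func call looks roughly like this (an illustrative sketch, not my actual litgpt changes; CUDA + fp16 assumed):

```python
import torch
from flash_attn import flash_attn_varlen_func

seq_lens = [3000, 1096]                 # two packed samples
H, D = 8, 64
total = sum(seq_lens)
q = k = v = torch.randn(total, H, D, device="cuda", dtype=torch.float16)

# cu_seqlens marks the sample boundaries in the flattened token dimension.
cu_seqlens = torch.tensor([0, 3000, 4096], dtype=torch.int32, device="cuda")
max_seqlen = max(seq_lens)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
    causal=True,   # causal within each packed sample only
)
```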

ali-issa99 commented 1 month ago

@samsja and you combined it with the code provided by LitGPT? If yes, do you have any code example?

samsja commented 1 month ago

> @samsja and you combined it with the code provided by LitGPT? If yes, do you have any code example?

I am reusing some litgpt code indeed, but I had to change the attention layer as well as the data loading. Unfortunately I cannot share the code for now.