Open alitirmizi23 opened 9 months ago
cc @awaelchli, if you'd like to answer
Just chiming in, from what I understand, this is not a simple feature to implement in general.
As one current example, the axolotl finetuning harness implements efficient sample packing with correct block diagonal attention masking through a series of monkey patches for the underlying huggingface model definitions for a few of the very popular models like llama and mistral. Though I have not looked through the code in detail, I believe it leverages the fact that the flash attention api supports the masking required to implement this scheme.
It seems like the simplicity of the lit-gpt model definition might actually make this easier to implement as a first-class feature. It is relevant for efficient finetuning (the reason it's incorporated into axolotl), and general wisdom (and whispers from inside large corps) suggests that this type of block-diagonal masking is better for large-scale training code.
I (and other collaborators) would be very interested in this feature, and it would increase the attractiveness of lit-gpt's model-building code as an HF alternative. Just my 2c!
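For anyone unfamiliar with the scheme being discussed: with sample packing, several short samples share one sequence, and the mask must prevent them from attending to each other. A minimal sketch of such a block-diagonal causal mask (the function name and signature are illustrative, not from lit-gpt or axolotl):

```python
import torch

def block_diagonal_causal_mask(seq_lens: list[int]) -> torch.Tensor:
    """Boolean mask (True = may attend) for packed samples: each token
    attends only to itself and earlier tokens of its *own* sample."""
    total = sum(seq_lens)
    # sample id per position, e.g. lens [3, 2] -> [0, 0, 0, 1, 1]
    ids = torch.repeat_interleave(torch.arange(len(seq_lens)), torch.tensor(seq_lens))
    same_sample = ids.unsqueeze(0) == ids.unsqueeze(1)
    causal = torch.tril(torch.ones(total, total, dtype=torch.bool))
    return same_sample & causal
```

For `seq_lens=[3, 2]` this yields a 5x5 mask where positions 3-4 (the second sample) cannot see positions 0-2, on top of the usual causal triangle.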
If a document, article, or instruction/output pair exceeds the max sequence length, how is it treated?
Depends on the data preparation, but our scripts trim it: https://github.com/Lightning-AI/lit-gpt/blob/0791c52a944f022a5cee91ed1e47288830efb72c/scripts/prepare_alpaca.py#L116-L117
What if a doc/article/instruction-output pair falls short of the max sequence length? Are the remaining time steps padded, or are more sequences packed until max length is reached?
They are packed in the pretraining scripts: https://github.com/Lightning-AI/lit-gpt/blob/0791c52a944f022a5cee91ed1e47288830efb72c/pretrain/redpajama.py#L249-L257 and padded in the fine-tuning scripts: https://github.com/Lightning-AI/lit-gpt/blob/0791c52a944f022a5cee91ed1e47288830efb72c/finetune/full.py#L232-L246
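Concretely, the two strategies look roughly like this (a simplified sketch; the function names, `block_size`, and `pad_id` are illustrative, not the actual script internals):

```python
def pack(token_streams: list[list[int]], block_size: int) -> list[list[int]]:
    """Pretraining-style packing: concatenate all tokens and slice into
    fixed-size blocks; a block may span sample boundaries, and any
    leftover tokens that don't fill a block are dropped."""
    flat = [t for stream in token_streams for t in stream]
    n_blocks = len(flat) // block_size
    return [flat[i * block_size:(i + 1) * block_size] for i in range(n_blocks)]

def pad(tokens: list[int], max_seq_length: int, pad_id: int = 0) -> list[int]:
    """Fine-tuning-style padding: each sample is its own sequence,
    trimmed and then padded up to max_seq_length."""
    tokens = tokens[:max_seq_length]
    return tokens + [pad_id] * (max_seq_length - len(tokens))
```

Note that plain packing as above is exactly what motivates this issue: without block-diagonal masking, tokens in one block can attend across the sample boundaries inside it.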
> suggest that this type of block diagonal masking is better for large scale training code.
The inconvenience is that torch.nn.functional.scaled_dot_product_attention will not use flash-attn if an attention mask is passed. It would be necessary to integrate this specific flavor of flash attention: https://github.com/Dao-AILab/flash-attention/issues/654, which would again require building it.
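To illustrate the limitation: SDPA still computes the right result when you pass an `attn_mask`, it just cannot dispatch to the flash kernel and falls back to the slower math/memory-efficient backends. A small sketch (shapes are arbitrary; on CPU everything runs the math path anyway, so this only demonstrates the API, not the speed difference):

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq, head_dim)
q = k = v = torch.randn(1, 4, 8, 16)

# Without a mask, SDPA is free to pick the flash kernel (on CUDA).
out_fast = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# With an explicit attn_mask (e.g. a block-diagonal one for packing),
# the flash backend rejects the call and a slower kernel is used instead.
mask = torch.tril(torch.ones(8, 8, dtype=torch.bool))
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Here the explicit mask is just the causal mask, so the results agree.
assert torch.allclose(out_fast, out_masked, atol=1e-5)
```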
We'd also be very interested in this feature!
@carmocca let’s revive this issue. It doesn’t look like SDPA from PyTorch 2.3 has solved the underlying issue; if that’s the case, let’s add flash attention as an optional dependency.
PyTorch has added support for arbitrary custom masks, which are meant to be performant when used with torch.compile: https://github.com/pytorch/pytorch/pull/121845
They are also considering more generic API changes, currently under discussion: https://github.com/pytorch/pytorch/issues/110681.
As of today, Tri Dao's package is the only option as far as I know.
I was wondering if there are sample packing approaches defined somewhere for the preprocessing and tokenization of datasets? I looked through the different prepare_*.py scripts, but couldn't find anything related to packing multiple sequences into max_length for efficiency, etc.
Also, wondering how the data prep works as of now in the lit-gpt framework: