foundation-model-stack / fms-acceleration

🚀 Collection of libraries used with fms-hf-tuning to accelerate fine-tuning and training of large models.
Apache License 2.0

Expand support for Collator caller functions #88

Closed kmehant closed 2 months ago

kmehant commented 2 months ago

At this point, fms-acceleration patches only the torch_call function. However, there are standard collators such as DataCollatorForSeq2Seq that do not implement torch_call and instead use the standard __call__.

see: https://github.com/huggingface/transformers/blob/4d5b45870411053c9c72d24a8e1052e00fe62ad6/src/transformers/data/data_collator.py#L585

We need to update the torch_call patch function accordingly.

https://github.com/foundation-model-stack/fms-acceleration/blob/3cf092a656c94bac36917f228868403e9e9955b9/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/aadp_utils.py#L64
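A minimal sketch of the proposed fix: dispatch the patch to whichever entry point the collator class actually defines, falling back to __call__ when torch_call is absent. All names here (patch_collator_class, the stand-in collator classes) are illustrative, not the fms-acceleration API.

```python
class CallOnlyCollator:
    """Stand-in for collators like DataCollatorForSeq2Seq that
    implement only __call__."""
    def __call__(self, features):
        return {"input_ids": [f["input_ids"] for f in features]}


class TorchCallCollator:
    """Stand-in for collators that implement torch_call."""
    def torch_call(self, features):
        return {"input_ids": [f["input_ids"] for f in features]}


def patch_collator_class(cls, transform):
    """Patch `cls` so every batch it collates passes through `transform`.

    Dispatches on whether the class defines torch_call (the path
    fms-acceleration patches today) or only __call__ (the
    DataCollatorForSeq2Seq case raised in this issue).
    """
    attr = "torch_call" if hasattr(cls, "torch_call") else "__call__"
    original = getattr(cls, attr)

    def wrapped(self, features):
        # Run the original collation, then apply the patch logic
        # (e.g. padding removal) to the resulting batch.
        return transform(original(self, features))

    # Patch at the class level: Python resolves dunders like __call__
    # on the type, not the instance.
    setattr(cls, attr, wrapped)
    return cls
```

Patching at the class level (rather than on the instance) keeps both paths uniform, since instance-level overrides of __call__ are ignored by Python's special-method lookup.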

I am happy to raise a PR.

fabianlim commented 2 months ago

@kmehant the intent here is only to mimic the recent version of trl.DataCollatorForCompletionOnlyLM, which has the remove-padding logic.

The logic to restrict to these two data collators only comes from the current implementation of fms-hf-tuning:

So why do you expect to have another collator?

kmehant commented 2 months ago

Thanks @fabianlim

Makes sense!