@kmehant the intent here is only to mimic the recent version of `trl.DataCollatorForCompletionOnlyLM`, which has the remove-padding logic.

The logic restricting us to these two data collators comes from the current implementation of fms-hf-tuning:

- `DataCollatorForSeq2Seq`: in this case we replace it with `DataCollatorWithFlattening`.
- `DataCollatorForCompletionOnlyLM`: in this case we activate the remove-padding logic.

So why do you expect to have another collator? A sketch of this restriction follows below.
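For concreteness, here is a minimal sketch of the restriction described above. It assumes the `padding_free` flag of recent trl versions of `DataCollatorForCompletionOnlyLM` and is not the actual fms-acceleration code:

```python
from transformers import DataCollatorForSeq2Seq, DataCollatorWithFlattening
from trl import DataCollatorForCompletionOnlyLM

def restrict_collator(collator):
    # Padding-free packing: swap the padding collator for the
    # flattening collator from transformers.
    if isinstance(collator, DataCollatorForSeq2Seq):
        return DataCollatorWithFlattening()
    # Completion-only: keep the collator but activate its
    # remove-padding logic (padding_free in recent trl versions).
    if isinstance(collator, DataCollatorForCompletionOnlyLM):
        collator.padding_free = True
        return collator
    # Any other collator is unsupported by design.
    raise NotImplementedError(f"unsupported collator: {type(collator).__name__}")
```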
Thanks @fabianlim
Makes sense!
At this point, fms-acceleration patches only the `torch_call` function. However, there are standard collators such as `DataCollatorForSeq2Seq` that do not implement a `torch_call` function and instead use the standard `__call__`; see: https://github.com/huggingface/transformers/blob/4d5b45870411053c9c72d24a8e1052e00fe62ad6/src/transformers/data/data_collator.py#L585
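To make the gap concrete, a quick check (against the transformers version linked above) of which collators actually route through `torch_call`:

```python
from transformers import (
    DataCollatorForLanguageModeling,
    DataCollatorForSeq2Seq,
)
from transformers.data.data_collator import DataCollatorMixin

# Collators built on DataCollatorMixin have a __call__ that dispatches
# to self.torch_call, so patching torch_call on the instance works.
print(issubclass(DataCollatorForLanguageModeling, DataCollatorMixin))  # True
print(hasattr(DataCollatorForLanguageModeling, "torch_call"))          # True

# DataCollatorForSeq2Seq is a plain dataclass that defines __call__
# directly; it never calls torch_call, so the current patch is a no-op.
print(issubclass(DataCollatorForSeq2Seq, DataCollatorMixin))           # False
print(hasattr(DataCollatorForSeq2Seq, "torch_call"))                   # False
```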
We need to update the `torch_call` patch function:
https://github.com/foundation-model-stack/fms-acceleration/blob/3cf092a656c94bac36917f228868403e9e9955b9/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/aadp_utils.py#L64
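A minimal sketch of what the updated patch could look like, assuming a hypothetical `postprocess` hook standing in for the remove-padding/flattening logic (this is not the actual aadp_utils implementation):

```python
import functools

from transformers.data.data_collator import DataCollatorMixin

def patch_collator(collator, postprocess):
    """Wrap a collator so `postprocess` (hypothetical hook) runs on each batch."""
    if isinstance(collator, DataCollatorMixin):
        # DataCollatorMixin.__call__ dispatches via self.torch_call, so an
        # instance-level patch is picked up on the next batch.
        original = collator.torch_call

        @functools.wraps(original)
        def torch_call(features):
            return postprocess(original(features))

        collator.torch_call = torch_call
        return collator

    # Collators like DataCollatorForSeq2Seq define __call__ directly.
    # Dunder lookup happens on the type, not the instance, so patching
    # collator.__call__ would be ignored; return a plain wrapper instead.
    @functools.wraps(collator.__call__)
    def wrapped(features, return_tensors=None):
        return postprocess(collator(features, return_tensors=return_tensors))

    return wrapped
```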
I am happy to raise a PR.