foundation-model-stack / fms-hf-tuning

🚀 Collection of tuning recipes with HuggingFace SFTTrainer and PyTorch FSDP.
Apache License 2.0
22 stars, 41 forks

Update trl #213

Closed: alex-jw-brooks closed this 3 months ago

alex-jw-brooks commented 3 months ago

Description of the change

Updates TRL to a more recent version. Some changes are needed here because more recent versions of TRL check whether the training arguments object is of type TrainingArguments. That check runs into a naming collision: we have a custom class with the same name as TrainingArguments in Transformers, which is the superclass of the SFT config (SFTConfig).
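The PR does not show the TRL check itself. Assuming it compares the class *name* rather than the class object, the sketch below shows why a project class that happens to be called TrainingArguments is caught by it, and why the resulting conversion can then fail on fields the SFT config does not define. All classes and field names here are hypothetical stand-ins, not the real Transformers, TRL, or fms-hf-tuning classes.

```python
from dataclasses import asdict, dataclass


@dataclass
class SFTConfig:
    # Hypothetical stand-in for the SFT trainer's config class.
    output_dir: str = "out"
    max_seq_length: int = 1024


@dataclass
class TrainingArguments:
    # Hypothetical stand-in for the project's custom class: same name as the
    # Transformers class, plus a field the SFT trainer does not know about.
    output_dir: str = "out"
    max_seq_length: int = 1024
    extra_project_field: str = "unused"  # hypothetical extra field


def maybe_convert(args):
    # Assumed name-based check: any class named "TrainingArguments" is treated
    # as the plain Transformers class and converted into an SFTConfig.
    if type(args).__name__ == "TrainingArguments":
        return SFTConfig(**asdict(args))
    return args


try:
    maybe_convert(TrainingArguments())
except TypeError as err:
    # The extra field is not a valid SFTConfig keyword argument.
    print("conversion failed:", err)
```

A check based on `isinstance` against a specific class object would not match an unrelated class purely because the names are equal; a string comparison cannot tell the two apart.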

In the future, we should explore renaming this class so it is less confusing, and potentially splitting it up so that fields the trainer does not use are no longer passed as part of the trainer args; IMO, both the confusing name and passing unused fields are bad ideas. For now, though, I've updated the code to build an SFTConfig from the class and drop the fields that the SFT trainer doesn't know about. This is a bit ugly, but it is a contained solution that doesn't break the API.
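The workaround described above (build an SFT config from the custom class and drop unknown fields) can be sketched with plain dataclasses. This is a minimal illustration, not the code from the PR: `SFTConfig` and `TrainingArguments` below are hypothetical stand-ins, `extra_project_field` is an invented field, and `to_sft_config` is a hypothetical helper name.

```python
from dataclasses import asdict, dataclass, fields


@dataclass
class SFTConfig:
    # Hypothetical stand-in for the SFT trainer's config class.
    output_dir: str = "out"
    max_seq_length: int = 1024


@dataclass
class TrainingArguments:
    # Hypothetical stand-in for the project's custom arguments class.
    output_dir: str = "out"
    max_seq_length: int = 1024
    extra_project_field: str = "unused"  # not known to the SFT trainer


def to_sft_config(train_args):
    """Build an SFTConfig, dropping fields the SFT trainer doesn't know about."""
    # Only fields accepted by SFTConfig's __init__ are valid keyword arguments.
    accepted = {f.name for f in fields(SFTConfig) if f.init}
    kwargs = {k: v for k, v in asdict(train_args).items() if k in accepted}
    return SFTConfig(**kwargs)


cfg = to_sft_config(TrainingArguments(output_dir="ckpt", max_seq_length=2048))
print(cfg)  # SFTConfig(output_dir='ckpt', max_seq_length=2048)
```

Because the filter is driven by the target class's own field list, fields added to the SFT config later are picked up without code changes. The trade-off is that dropped fields are ignored silently, so a misspelled field name on the custom class would not raise an error.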

Related issue number

https://github.com/foundation-model-stack/fms-hf-tuning/issues/206

How to verify the PR

Unit tests pass with no argument errors from the SFT trainer.

Was the PR tested