huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Issue #1751 Fix #1754

Open yash-srivastava19 opened 1 week ago

yash-srivastava19 commented 1 week ago

Issue #1751 reported that the TRL CLI does not correctly capture torch_dtype. Since the issue seemed urgent, I quickly put together a hacky fix that at least lets the SFTTrainer initialize.

Original issue: the error occurs when running the following command:

trl sft --model_name_or_path=facebook/opt-125m --dataset_name=imdb  --dataset_text_field=text --max_steps=1 --torch_dtype=bfloat16 --output_dir=./test

The error was that trl sft does not pass model_init_kwargs["torch_dtype"] as a string when calling getattr(torch, model_init_kwargs["torch_dtype"]), so the getattr call fails.
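To make the failure concrete, here is a minimal sketch that assumes the CLI has already turned "bfloat16" into torch.bfloat16 before the SFTTrainer sees it (the discussion below suggests this is what happens). old_conversion is a hypothetical wrapper around the original conditional expression (the removed lines in the diff further down), used only for illustration.

```python
import torch

# Assumption for illustration: the CLI has already converted "bfloat16"
# into torch.bfloat16 before SFTTrainer receives model_init_kwargs.
model_init_kwargs = {"torch_dtype": torch.bfloat16}


def old_conversion(kwargs):
    """Hypothetical wrapper around the original conversion expression."""
    return (
        kwargs["torch_dtype"]
        if kwargs["torch_dtype"] in ["auto", None]
        else getattr(torch, kwargs["torch_dtype"])
    )


try:
    old_conversion(model_init_kwargs)
    failed = False
except TypeError as exc:
    # getattr() requires the attribute name to be a str, not a torch.dtype.
    failed = True
    print(f"TypeError: {exc}")
```

The same function works when it receives the plain string "bfloat16", which is why the bug only shows up once the value has been converted upstream.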

This fix keeps the pipeline from breaking at this stage. It is admittedly hacky, and I'm happy to keep working on it :)

After that, the next error comes from the transformers library, which is unable to serialize the dtype object (screenshot attached):

...
...
TypeError: Object of type dtype is not JSON serializable
Traceback (most recent call last):

Screenshot 2024-06-18 171826
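The follow-up fix for this serialization error is not part of this thread. One possible approach, sketched under the assumption that the offending value is a torch.dtype stored in a dict that is later JSON-encoded, is to turn the dtype back into its string name before encoding. dtype_to_str is a hypothetical helper, not a transformers or TRL API.

```python
import json

import torch


def dtype_to_str(value):
    """Hypothetical helper: map torch.bfloat16 to "bfloat16"; leave other values alone."""
    if isinstance(value, torch.dtype):
        # str(torch.bfloat16) is "torch.bfloat16"; keep only the part after the dot.
        return str(value).split(".")[-1]
    return value


kwargs = {"torch_dtype": torch.bfloat16}

try:
    json.dumps(kwargs)
    raw_ok = True
except TypeError as exc:
    # Same kind of failure as in the traceback above.
    raw_ok = False
    print(f"TypeError: {exc}")

# Convert dtypes to strings, then encode.
safe_kwargs = {key: dtype_to_str(val) for key, val in kwargs.items()}
encoded = json.dumps(safe_kwargs)
```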

alvarobartt commented 1 week ago

Hi @yash-srivastava19, thanks for this PR! However, this isn't how we should fix it. Ideally we should handle it either by checking whether the received value is already a torch.dtype, or by ensuring that the string passed as torch_dtype via the CLI is not converted to a torch.dtype before the SFTTrainer is instantiated.

A more suitable fix would be the following:

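import torch

# Keep "auto", None, or an existing torch.dtype unchanged;
# otherwise convert a string such as "bfloat16" into torch.bfloat16.
model_init_kwargs["torch_dtype"] = (
    model_init_kwargs["torch_dtype"]
    if model_init_kwargs["torch_dtype"] in ["auto", None]
    or isinstance(model_init_kwargs["torch_dtype"], torch.dtype)
    else getattr(torch, model_init_kwargs["torch_dtype"])
)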

Anyway, I'll let the authors chime in with their thoughts on a potential fix! Thanks again 🤗
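A quick way to check the suggested expression is to wrap it in a function and feed it each kind of input a CLI or Python caller might pass. normalize_torch_dtype is a hypothetical name used only for this sketch; the body is the same conditional expression as the suggested fix.

```python
import torch


def normalize_torch_dtype(torch_dtype):
    """Hypothetical wrapper around the suggested conditional expression."""
    return (
        torch_dtype
        # "auto", None, and values that are already a torch.dtype pass through.
        if torch_dtype in ["auto", None] or isinstance(torch_dtype, torch.dtype)
        # Anything else is treated as a dtype name, e.g. "bfloat16".
        else getattr(torch, torch_dtype)
    )


for value in ["bfloat16", torch.bfloat16, "auto", None]:
    print(repr(value), "->", normalize_torch_dtype(value))
```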


yash-srivastava19 commented 1 week ago

Thanks a lot for this! I second what @alvarobartt said above; we can change this fix to something like:

diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index e739b2d..80e11ad 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -159,11 +159,13 @@ class SFTTrainer(Trainer):
             raise ValueError("You passed model_init_kwargs to the SFTConfig, but your model is already instantiated.")
         else:
             model_init_kwargs = args.model_init_kwargs
-            model_init_kwargs["torch_dtype"] = (
-                model_init_kwargs["torch_dtype"]
-                if model_init_kwargs["torch_dtype"] in ["auto", None]
-                else getattr(torch, model_init_kwargs["torch_dtype"])
-            )
+            torch_dtype = model_init_kwargs["torch_dtype"]
+
+            # Convert to `torch.dtype` if an str is passed
+            if isinstance(torch_dtype, str) and torch_dtype != "auto":
+                torch_dtype = getattr(torch, torch_dtype)
+
+            model_init_kwargs["torch_dtype"] = torch_dtype

         if infinite is not None:
             warnings.warn(

And it worked fine on my end! Would you be happy to apply these changes instead in this PR?
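The diff above only converts strings other than "auto", so None and values that are already a torch.dtype pass through unchanged. Below is a minimal sketch of the same logic as a standalone function. resolve_torch_dtype is a hypothetical name, and the check that the looked-up attribute really is a torch.dtype is an extra safeguard that is not part of the proposed patch.

```python
import torch


def resolve_torch_dtype(torch_dtype):
    """Hypothetical standalone version of the diff's conversion logic.

    Strings other than "auto" are looked up on the torch module; everything
    else ("auto", None, an existing torch.dtype) is returned unchanged.
    """
    if isinstance(torch_dtype, str) and torch_dtype != "auto":
        resolved = getattr(torch, torch_dtype, None)
        # Extra safeguard (not in the patch): reject names that are missing
        # or that resolve to something other than a dtype, e.g. "nn".
        if not isinstance(resolved, torch.dtype):
            raise ValueError(f"Unknown torch_dtype: {torch_dtype!r}")
        return resolved
    return torch_dtype


print(resolve_torch_dtype("bfloat16"))
print(resolve_torch_dtype(torch.float32))
```

The extra check matters because getattr(torch, name) succeeds for any attribute of the torch module, so a string such as "nn" would otherwise return the torch.nn module instead of raising an error.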

Agreed, this is a much better approach.

yash-srivastava19 commented 1 week ago

Was the JSON encoding error fixed as well, or does it persist even after the fix?

younesbelkada commented 1 week ago

Thanks! That's a separate issue that we can fix in a follow-up PR!