yash-srivastava19 opened 1 week ago
Hi @yash-srivastava19, thanks for this PR, but this is not how we should fix that, since ideally we should catch it either by checking that the received type is a `torch.dtype`, or by ensuring that the str provided as `torch_dtype` via the CLI is not transformed to a `torch.dtype` before instantiating the `SFTTrainer`, for example.
So a more suitable fix should be the following:

```python
model_init_kwargs["torch_dtype"] = (
    model_init_kwargs["torch_dtype"]
    if model_init_kwargs["torch_dtype"] in ["auto", None]
    or isinstance(model_init_kwargs["torch_dtype"], torch.dtype)
    else getattr(torch, model_init_kwargs["torch_dtype"])
)
```
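Viewed on its own, that conditional can be exercised like the sketch below. Note that `FakeDtype` and the `torch` namespace here are stand-ins I'm using purely for illustration, so the example runs even without PyTorch installed; with the real library, `getattr(torch, "float16")` resolves to `torch.float16`:

```python
from types import SimpleNamespace

class FakeDtype:
    """Stand-in for torch.dtype so this sketch runs without PyTorch."""
    def __init__(self, name):
        self.name = name
    def __repr__(self):
        return f"torch.{self.name}"

# Stand-in for the torch module: exposes `dtype` (the class) and one dtype
# attribute, mirroring torch.dtype and torch.float16.
torch = SimpleNamespace(dtype=FakeDtype, float16=FakeDtype("float16"))

def normalize_torch_dtype(value):
    # Same logic as the suggested expression: leave "auto", None, and
    # existing dtype instances untouched; resolve dtype names via getattr.
    return (
        value
        if value in ["auto", None] or isinstance(value, torch.dtype)
        else getattr(torch, value)
    )

print(normalize_torch_dtype("auto"))          # "auto" passes through
print(normalize_torch_dtype(None))            # None passes through
print(normalize_torch_dtype("float16"))       # resolved to torch.float16
print(normalize_torch_dtype(torch.float16))   # already a dtype, untouched
```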
Anyway, I'll let the authors chime in with their thoughts and ideas about a potential fix! Thanks anyway 🤗
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks a lot for this! I second what @alvarobartt said above; we can change this fix to something like:
```diff
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index e739b2d..80e11ad 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -159,11 +159,13 @@ class SFTTrainer(Trainer):
             raise ValueError("You passed model_init_kwargs to the SFTConfig, but your model is already instantiated.")
         else:
             model_init_kwargs = args.model_init_kwargs
-            model_init_kwargs["torch_dtype"] = (
-                model_init_kwargs["torch_dtype"]
-                if model_init_kwargs["torch_dtype"] in ["auto", None]
-                else getattr(torch, model_init_kwargs["torch_dtype"])
-            )
+            torch_dtype = model_init_kwargs["torch_dtype"]
+
+            # Convert to `torch.dtype` if an str is passed
+            if isinstance(torch_dtype, str) and torch_dtype != "auto":
+                torch_dtype = getattr(torch, torch_dtype)
+
+            model_init_kwargs["torch_dtype"] = torch_dtype
         if infinite is not None:
             warnings.warn(
```
And it worked fine on my end! Would you be happy to apply these changes instead in this PR?
Yes, that is much better. Agreed.
Was the JSON encoding error rectified as well, or does it persist even after the fix?
Thanks! That's another issue we can fix in a follow-up PR!
#1751 mentioned that the TRL CLI is not completely capturing the `torch_dtype`. I thought the issue was urgent, so I quickly patched a hacky fix, which at least lets the `SFTTrainer` initialize.
Original issue: on running the following command:
The error was that `trl sft` does not identify it as a string when calling `getattr(torch, model_init_kwargs["torch_dtype"])`. The fix I made keeps the pipeline from breaking at this stage. Although it is a hacky fix, I'm willing to work on it further :)
The error after that comes from the transformers library, which isn't able to serialize the dtype object (screenshot attached):
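For context on that serialization failure: `json.dumps` raises a `TypeError` for objects it doesn't know how to encode, and a `torch.dtype` is one of them. The sketch below reproduces the failure with a stand-in object (so it runs without PyTorch) and shows one possible workaround, serializing the dtype's name instead. The `FakeDtype` class and the `removeprefix` trick are illustrative assumptions, not the actual transformers fix:

```python
import json

class FakeDtype:
    """Stand-in for torch.dtype; json.dumps cannot encode it either."""
    def __repr__(self):
        return "torch.float16"

dtype = FakeDtype()

# Reproduce the failure: arbitrary objects are not JSON serializable.
try:
    json.dumps({"torch_dtype": dtype})
except TypeError as exc:
    print(f"serialization failed: {exc}")

# One possible workaround: store the dtype's name ("float16") in the
# serialized config rather than the dtype object itself.
payload = json.dumps({"torch_dtype": repr(dtype).removeprefix("torch.")})
print(payload)  # {"torch_dtype": "float16"}
```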