huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Error in AutoModelForCausalLMWithValueHead.load_pretrained("path/to/my/ppo_model") #1917

Open heli-stand opened 1 month ago

heli-stand commented 1 month ago

Given:

ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained('path/to/my/AutoModelForCausalLM', torch_dtype="auto")
ppo_model.save_pretrained("ppo_model")

When I run:

ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained("ppo_model", torch_dtype="auto", local_files_only=True)

Error:

Traceback (most recent call last):
  File "", line 1, in
  File "/home/.conda/envs/test_env/lib/python3.10/site-packages/trl/models/modeling_base.py", line 324, in from_pretrained
    filename = hf_hub_download(
  File "/home/.conda/envs/test_env/lib/python3.10/site-packages/huggingface_hub/utils/_deprecation.py", line 101, in inner_f
    return f(*args, **kwargs)
  File "/home/.conda/envs/test_env/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 106, in _inner_fn
    validate_repo_id(arg_value)
  File "/home/.conda/envs/test_env/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 154, in validate_repo_id
    raise HFValidationError(
huggingface_hub.errors.HFValidationError: Repo id must be in the form 'repo_name' or 'namespace/repo_name'
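The validator at the bottom of the traceback only accepts the two shapes named in the message. The sketch below is a rough approximation of that shape rule (my own simplification, not huggingface_hub's actual validator, which applies further checks). It illustrates why a multi-segment local path such as `path/to/my/ppo_model` is rejected once it is passed to `hf_hub_download`, while a bare name like `ppo_model` would pass this particular check. That difference suggests, though does not prove, that the failing call used a multi-segment path like the one in the issue title.

```python
import re

# Simplified approximation (an assumption, not huggingface_hub's code) of the
# shape rule quoted in the error: 'repo_name' or 'namespace/repo_name'.
_REPO_ID_SHAPE = re.compile(r"^([\w\-.]+/)?[\w\-.]+$")


def looks_like_repo_id(value: str) -> bool:
    """Return True if `value` has the 'repo_name' or 'namespace/repo_name' shape."""
    return bool(_REPO_ID_SHAPE.match(value))


# Compare a bare name, a namespaced id, and two multi-segment local paths.
for candidate in ["ppo_model", "my-org/ppo_model", "path/to/my/ppo_model", "/home/user/ppo_model"]:
    print(f"{candidate!r}: {looks_like_repo_id(candidate)}")
```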

There is a logic error in from_pretrained: instead of loading the locally saved model, it tries to download the model from the Hub.
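Before concluding that the local copy is ignored, it may help to confirm what `save_pretrained` actually wrote. Below is a small standard-library sketch that lists the usual Transformers checkpoint filenames present in the directory. An empty result, or only a sharded `*.index.json` file, would be worth mentioning in the report. Which of these files trl's `from_pretrained` looks for before contacting the Hub is my assumption, not something confirmed in this thread. `find_local_weights` is a hypothetical helper.

```python
import os

# Standard single-file and sharded-index checkpoint names written by
# transformers' save_pretrained. Which of these trl's loader checks before
# falling back to the Hub is an assumption here, not confirmed.
WEIGHT_FILES = (
    "model.safetensors",
    "model.safetensors.index.json",
    "pytorch_model.bin",
    "pytorch_model.bin.index.json",
)


def find_local_weights(model_dir: str) -> list[str]:
    """Hypothetical helper: return the known checkpoint files present in `model_dir`.

    Returns an empty list if the directory is missing or contains none of them.
    """
    if not os.path.isdir(model_dir):
        return []
    return [name for name in WEIGHT_FILES if os.path.isfile(os.path.join(model_dir, name))]


# Example usage against the directory passed to from_pretrained:
# print(find_local_weights("ppo_model"))
```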

qgallouedec commented 2 weeks ago

Thanks for reporting, @heli-stand. I can't reproduce this error; the following runs without issue:

from trl import AutoModelForCausalLMWithValueHead

ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained("trl-internal-testing/tiny-random-GPT2LMHeadModel", torch_dtype="auto")
ppo_model.save_pretrained("my_ppo_model")
ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained("my_ppo_model", torch_dtype="auto", local_files_only=True)

Can you share the version of the package that you use?
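A quick way to collect the relevant versions for the reply is a standard-library sketch like the one below. The package list (trl, transformers, huggingface_hub, torch) is my assumption about what matters for this traceback. `package_versions` is a hypothetical helper.

```python
from importlib.metadata import PackageNotFoundError, version


def package_versions(names=("trl", "transformers", "huggingface_hub", "torch")) -> dict[str, str]:
    """Hypothetical helper: map each distribution name to its installed version.

    Packages that are not installed are reported as 'not installed'.
    """
    versions = {}
    for name in names:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = "not installed"
    return versions


for name, ver in package_versions().items():
    print(f"{name}: {ver}")
```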