heli-stand opened this issue 1 month ago
Thanks for reporting, @heli-stand. I can't reproduce this error. The following snippet runs without errors:
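```python
from trl import AutoModelForCausalLMWithValueHead

# Load a tiny test model from the Hub and wrap it with a value head
ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained("trl-internal-testing/tiny-random-GPT2LMHeadModel", torch_dtype="auto")
# Save it to a local directory
ppo_model.save_pretrained("my_ppo_model")
# Reload it from that directory only, without contacting the Hub
ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained("my_ppo_model", torch_dtype="auto", local_files_only=True)
```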
Could you share which version of trl you are using?
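One way to collect the installed versions is with the standard library's `importlib.metadata`. This is a small sketch; the package list (trl, transformers, huggingface_hub, torch) is an assumption about what is relevant to this issue, and `collect_versions` is a hypothetical helper, not part of trl.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages assumed to be relevant to this issue; adjust as needed.
PACKAGES = ["trl", "transformers", "huggingface_hub", "torch"]


def collect_versions(packages):
    """Return {package: installed version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in collect_versions(PACKAGES).items():
        print(f"{name}: {ver or 'not installed'}")
```

Pasting the printed lines into the issue makes it easier to match the report against a specific release.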
Given:
```python
from trl import AutoModelForCausalLMWithValueHead

ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained("path/to/my/AutoModelForCausalLM", torch_dtype="auto")
ppo_model.save_pretrained("ppo_model")
```
When I do:
```python
ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained("ppo_model", torch_dtype="auto", local_files_only=True)
```
Error:
```
Traceback (most recent call last):
  File "", line 1, in
  File "/home/.conda/envs/test_env/lib/python3.10/site-packages/trl/models/modeling_base.py", line 324, in from_pretrained
    filename = hf_hub_download(
  File "/home/.conda/envs/test_env/lib/python3.10/site-packages/huggingface_hub/utils/_deprecation.py", line 101, in inner_f
    return f(*args, **kwargs)
  File "/home/.conda/envs/test_env/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 106, in _inner_fn
    validate_repo_id(arg_value)
  File "/home/.conda/envs/test_env/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 154, in validate_repo_id
    raise HFValidationError(
huggingface_hub.errors.HFValidationError: Repo id must be in the form 'repo_name' or 'namespace/repo_name'
```
There is a logic error in `from_pretrained`: instead of loading the locally saved model, it calls `hf_hub_download` and tries to fetch the files from the Hub.
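To narrow this down, it may help to check which weight files `save_pretrained` actually wrote and whether a local-first lookup would find them. This is a hypothetical sketch: `find_local_checkpoint` is not part of trl, and the candidate file names (`model.safetensors`, `pytorch_model.bin`, and their sharded `.index.json` variants) are assumptions about what transformers typically writes, not a description of trl's actual loading logic.

```python
import os

# Candidate weight files; assumed to be what save_pretrained typically writes.
CANDIDATE_FILES = [
    "model.safetensors",
    "model.safetensors.index.json",
    "pytorch_model.bin",
    "pytorch_model.bin.index.json",
]


def find_local_checkpoint(path):
    """Return the first candidate weight file found under `path`, or None.

    A local-first loader would use this result directly and only fall back
    to the Hub (e.g. via hf_hub_download) when nothing is found locally.
    """
    if not os.path.isdir(path):
        return None
    for name in CANDIDATE_FILES:
        candidate = os.path.join(path, name)
        if os.path.isfile(candidate):
            return candidate
    return None


if __name__ == "__main__":
    saved_dir = "ppo_model"  # directory passed to save_pretrained in the report
    if os.path.isdir(saved_dir):
        print(sorted(os.listdir(saved_dir)))  # what was actually saved
    print(find_local_checkpoint(saved_dir))
```

If this finds a file but `from_pretrained` still calls `hf_hub_download`, that supports the report; if it returns None, the saved directory does not contain any of the file names assumed above, and the directory listing is worth sharing in the issue.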