huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

push_to_hub from local model #1873

Closed: mmg10 closed this issue 6 days ago.

mmg10 commented 1 month ago

Who can help?

@muellerzr @SunMarc

Reproduction

I run https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py with the following arguments:

python sft.py \
    --model_name_or_path /home/ubuntu/work/Meta-Llama-3.1-8B \
    --dataset_name="HuggingFaceH4/no_robots" \
    --output_dir /home/ubuntu/work/Meta-Llama-3.1-8B-SFT \
    --report_to="wandb" \
    --push_to_hub true \
    --push_to_hub_model_id "llama3.1-8b-sft"

I understand that I am using SFTTrainer, but its push_to_hub delegates to the parent Trainer via super().push_to_hub, which is why I opened the issue in the transformers repository.

The error is:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/hf_api.py", line 3761, in create_commit
    hf_raise_for_status(response)
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/utils/_errors.py", line 358, in hf_raise_for_status
    raise BadRequestError(message, response=response) from e
huggingface_hub.utils._errors.BadRequestError:  (Request ID: Root=1-66a1dfab-4f8aaadd66a3afc164a238e4;a70493b5-d4ba-4022-978f-9947dcf083c0)

Bad request:
"base_model" with value "/home/ubuntu/work/Meta-Llama-3.1-8B" is not valid. Use a model id from https://hf.co/models.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ubuntu/work/sft.py", line 151, in <module>
    trainer.save_model(training_args.output_dir)
  File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 3458, in save_model
    self.push_to_hub(commit_message="Model save")
  File "/opt/conda/lib/python3.11/site-packages/trl/trainer/sft_trainer.py", line 475, in push_to_hub
    return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 4349, in push_to_hub
    return upload_folder(
           ^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/hf_api.py", line 1398, in _inner
    return fn(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/hf_api.py", line 4857, in upload_folder
    commit_info = self.create_commit(
                  ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/hf_api.py", line 1398, in _inner
    return fn(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/hf_api.py", line 3765, in create_commit
    raise ValueError(f"Invalid metadata in README.md.\n{message}") from e
ValueError: Invalid metadata in README.md.
- "base_model" with value "/home/ubuntu/work/Meta-Llama-3.1-8B" is not valid. Use a model id from https://hf.co/models.
wandb: | 0.037 MB of 0.037 MB uploaded

Expected behavior

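Training and wandb logging complete successfully, but the push fails. The cause seems to be that I load the model from the local directory /home/ubuntu/work/Meta-Llama-3.1-8B rather than from a Hub repo, so the generated README.md metadata sets base_model to that local path, which the Hub rejects. I would expect the push to succeed regardless of where the base model was loaded from.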
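Per the error above, the Hub rejects README.md metadata whose base_model is a local path such as /home/ubuntu/work/Meta-Llama-3.1-8B. A minimal sketch of a pre-push check that tells a local path apart from something shaped like a Hub model id, assuming a hypothetical helper looks_like_hub_repo_id (not part of transformers, TRL, or huggingface_hub) and an approximate repo-id pattern rather than the Hub's exact validation rules:

```python
import os
import re

# Approximate shape of a Hub model id: "name" or "namespace/name", made of
# letters, digits, "-", "_" and ".", each part starting with a letter or digit.
# This is a simplification (assumption), not the Hub's exact validation rule.
_REPO_ID_PATTERN = re.compile(
    r"^[A-Za-z0-9][A-Za-z0-9._-]*(/[A-Za-z0-9][A-Za-z0-9._-]*)?$"
)


def looks_like_hub_repo_id(name_or_path: str) -> bool:
    """Hypothetical helper: True if name_or_path plausibly names a Hub model
    rather than a local directory."""
    # Absolute paths and existing local directories are treated as local.
    if os.path.isabs(name_or_path) or os.path.isdir(name_or_path):
        return False
    return bool(_REPO_ID_PATTERN.match(name_or_path))


if __name__ == "__main__":
    for candidate in ["/home/ubuntu/work/Meta-Llama-3.1-8B", "facebook/opt-350m"]:
        print(candidate, looks_like_hub_repo_id(candidate))
```

A check like this could run before pushing, so a local base model path is caught early instead of surfacing as a BadRequestError at commit time.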

LysandreJik commented 1 month ago

I don't think the sft.py script supports push_to_hub_model_id.

Moving your issue to TRL, cc @kashif maybe :)

github-actions[bot] commented 3 weeks ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

qgallouedec commented 6 days ago

Use hub_model_id instead; see the TrainingArguments documentation.

Example (with --max_steps 100 for a quick demo):

python examples/scripts/sft.py \
    --model_name_or_path facebook/opt-350m \
    --dataset_name timdettmers/openassistant-guanaco \
    --output_dir opt-350m-sft \
    --dataset_text_field text \
    --push_to_hub \
    --max_steps 100 \
    --hub_model_id my_hub_model_id

https://huggingface.co/qgallouedec/my_hub_model_id
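The fix suggested above is a flag rename. A minimal sketch, assuming a hypothetical helper migrate_hub_flags (not part of TRL or transformers), that rewrites --push_to_hub_model_id to --hub_model_id in an existing argument list, such as the one from the original report:

```python
from typing import List

# Hypothetical mapping from the flag in the original command to the
# TrainingArguments field suggested in this thread.
_RENAMED_FLAGS = {"--push_to_hub_model_id": "--hub_model_id"}


def migrate_hub_flags(argv: List[str]) -> List[str]:
    """Hypothetical helper: return a copy of argv with renamed flags replaced.

    Handles both the "--flag value" and "--flag=value" forms.
    """
    migrated = []
    for arg in argv:
        # Split "--flag=value" into ("--flag", "=", "value"); plain args
        # come back as (arg, "", "").
        flag, sep, value = arg.partition("=")
        if flag in _RENAMED_FLAGS:
            arg = _RENAMED_FLAGS[flag] + sep + value
        migrated.append(arg)
    return migrated


if __name__ == "__main__":
    original = ["--push_to_hub", "true", "--push_to_hub_model_id", "llama3.1-8b-sft"]
    print(migrate_hub_flags(original))
```

The helper only renames the flag and leaves --model_name_or_path untouched; the demo command above loads facebook/opt-350m from the Hub, so it does not exercise the local-path base_model case.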