tlc4418 / llm_optimization

A repo for RLHF training and BoN over LLMs, with support for reward model ensembles.
https://arxiv.org/abs/2310.02743
MIT License

Unable to Run PPO Training Using HuggingFace Path of SFT'd language model #10

Open RylanSchaeffer opened 3 months ago

RylanSchaeffer commented 3 months ago

I'm trying to run vanilla PPO against either a single reward model or an ensemble of 5 reward models.

Command: accelerate launch --main_process_port=29503 --config_file configs/accelerate_config.yaml src/ppo/trainer_rl.py --configs defaults defaults_rlhf pythia_44m_rlhf_ensemble_mean

My config is here:

pythia_44m_rlhf_ensemble_mean:
  output_dir: runs/ppo_ensemble
  datasets:
    - alpaca_farm

  gold_config:
    model_name: alpaca_farm_models/reward-model-human
    is_alpacafarm_rm: true
    batch_size: 32

  rank_config:
    is_reward_model: true
    model_names: 
      - models/rm/switching_rms_pythia_rm_44m_sftseed0_seed0

    objective_name: mean # Change objective (mean, random, WCO, or UWO)
    uwo_weight: 0.1 # Change UWO weight (only for UWO)
    cache_dir: .cache
    pooling: last
    residual_dropout: 0.01
    use_flash_attention: false
    dtype: bf16
    batch_size: 128

  sft_config:
    is_reward_model: false
    model_name: RylanSchaeffer/switching_rms_pythia_sft_1p4b_seed0
    cache_dir: .cache
    quantization: false
    seq2seqmodel: false
    freeze_layer:
    num_layers_unfrozen: 2 
    residual_dropout: 0.2
    use_flash_attention: false
    dtype: bf16
    batch_size: 32

However, loading the model given by the SFT config's model_name fails with this error:

[rank0]:     index_file_name = hf_hub_download(
[rank0]:   File "/lfs/ampere8/0/rschaef/miniconda3/envs/reward_modeling_env/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/lfs/ampere8/0/rschaef/miniconda3/envs/reward_modeling_env/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 1221, in hf_hub_download
[rank0]:     return _hf_hub_download_to_cache_dir(
[rank0]:   File "/lfs/ampere8/0/rschaef/miniconda3/envs/reward_modeling_env/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 1282, in _hf_hub_download_to_cache_dir
[rank0]:     (url_to_download, etag, commit_hash, expected_size, head_call_error) = _get_metadata_or_catch_error(
[rank0]:   File "/lfs/ampere8/0/rschaef/miniconda3/envs/reward_modeling_env/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 1722, in _get_metadata_or_catch_error
[rank0]:     metadata = get_hf_file_metadata(url=url, proxies=proxies, timeout=etag_timeout, headers=headers)
[rank0]:   File "/lfs/ampere8/0/rschaef/miniconda3/envs/reward_modeling_env/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/lfs/ampere8/0/rschaef/miniconda3/envs/reward_modeling_env/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 1645, in get_hf_file_metadata
[rank0]:     r = _request_wrapper(
[rank0]:   File "/lfs/ampere8/0/rschaef/miniconda3/envs/reward_modeling_env/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 372, in _request_wrapper
[rank0]:     response = _request_wrapper(
[rank0]:   File "/lfs/ampere8/0/rschaef/miniconda3/envs/reward_modeling_env/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 396, in _request_wrapper
[rank0]:     hf_raise_for_status(response)
[rank0]:   File "/lfs/ampere8/0/rschaef/miniconda3/envs/reward_modeling_env/lib/python3.10/site-packages/huggingface_hub/utils/_errors.py", line 315, in hf_raise_for_status
[rank0]:     raise EntryNotFoundError(message, response) from e
[rank0]: huggingface_hub.utils._errors.EntryNotFoundError: 404 Client Error. (Request ID: Root=1-6697d66d-22a2a9095f1788b949b35ebc;622b8343-28b0-4609-aeeb-c5fbba1f0c30)

[rank0]: Entry Not Found for url: https://huggingface.co/RylanSchaeffer/switching_rms_pythia_sft_1p4b_seed0/resolve/main/pytorch_model.bin.index.json.

Comparing RylanSchaeffer/switching_rms_pythia_sft_1p4b_seed0 against tlc4418/pythia_1.4b_sft_policy/tree/main, I see that the SFT'd models I created have [model.safetensors](https://huggingface.co/RylanSchaeffer/switching_rms_pythia_sft_1p4b_seed0/blob/main/model.safetensors)


whereas your SFT'd models have [pytorch_model.bin](https://huggingface.co/tlc4418/pythia_1.4b_sft_policy/blob/main/pytorch_model.bin):


I suspect that something changed in transformers in the intervening time.
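The mismatch between the two repos can be checked mechanically. Below is a minimal sketch, assuming the repo's file listing is already available as a list of strings (for instance from `huggingface_hub.list_repo_files`). The function name `weight_format` and its return labels are hypothetical; only the file names follow the standard transformers checkpoint conventions.

```python
from typing import Iterable

# Standard transformers checkpoint file names.
BIN = "pytorch_model.bin"
BIN_INDEX = "pytorch_model.bin.index.json"
SAFETENSORS = "model.safetensors"
SAFETENSORS_INDEX = "model.safetensors.index.json"


def weight_format(repo_files: Iterable[str]) -> str:
    """Classify which weight files a model repo exposes (hypothetical helper)."""
    files = set(repo_files)
    if BIN in files:
        return "bin"  # single-file PyTorch checkpoint
    if BIN_INDEX in files:
        return "bin-sharded"  # sharded PyTorch checkpoint with an index file
    if SAFETENSORS in files or SAFETENSORS_INDEX in files:
        return "safetensors-only"  # no .bin weights at all
    return "no-weights"
```

If the loader only looks for the `.bin` files, which is what the traceback suggests but which I have not verified in the trlx source, a repo classified as `safetensors-only` would be expected to end in the 404 on `pytorch_model.bin.index.json` shown above.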

I'm going to open an issue with trlx, but can you suggest any workarounds?
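One possible workaround, untested against this repo, is to write a `pytorch_model.bin` next to the existing `model.safetensors` in a local copy of the model and point `model_name` at that directory. This is only a sketch under that assumption: the helper name `convert_safetensors_to_bin` is hypothetical, and it relies on `safetensors.torch.load_file` and `torch.save`, both imported lazily.

```python
from pathlib import Path


def convert_safetensors_to_bin(model_dir: str) -> Path:
    """Write pytorch_model.bin beside model.safetensors and return its path."""
    src = Path(model_dir) / "model.safetensors"
    dst = Path(model_dir) / "pytorch_model.bin"

    if dst.exists():
        return dst  # already converted, nothing to do
    if not src.exists():
        raise FileNotFoundError(f"no model.safetensors in {model_dir}")

    # Imported lazily so the early checks above run without these packages.
    import torch
    from safetensors.torch import load_file

    # Load the tensors on CPU and re-serialize them in the older .bin format.
    state_dict = load_file(str(src), device="cpu")
    torch.save(state_dict, dst)
    return dst
```

An alternative along the same lines would be to reload the model with transformers and call `save_pretrained(..., safe_serialization=False)`, which also writes `.bin` weights. Either way, whether trlx then loads the checkpoint cleanly is an assumption that would need checking.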

Perhaps it would be helpful to specify the exact library versions you used for your experiments :)

RylanSchaeffer commented 3 months ago

Oh, there's already an open issue with trlx here! https://github.com/CarperAI/trlx/issues/580