ContextualAI / HALOs

A library with extensible implementations of DPO, KTO, PPO, ORPO, and other human-aware loss functions (HALOs).
https://arxiv.org/abs/2402.01306
Apache License 2.0

Llama 3 compatibility #19

Closed roshansridhar closed 4 months ago

roshansridhar commented 4 months ago

I tried two Llama 3 8B models from Hugging Face by creating a new config/model/.yaml with

name_or_path: meta-llama/Meta-Llama-3-8B and also name_or_path: NousResearch/Meta-Llama-3-8B

I am able to train the SFT model using the command python train.py loss=sft model=llama_3_8b datasets=[shp,hh] exp_name=l38b_sft_0505 mode=train ++cache_dir=./data/models

However, I get the following error when I train the PPO model using the command python train.py loss=ppo model=llama_3_8b datasets=[shp,hh] exp_name=l38b_ppo_0517 mode=train ++cache_dir=./data/models ++model.load_from=l38b_sft_0505/LATEST/policy.pt

Stacktrace:

Making experiment directory ./data/models/l38b_ppo_0517
no FSDP port specified; using open port for FSDP: 46451
seed: 1
exp_name: l38b_ppo_0517
datasets:
- shp
- hh
mode: train
debug: false
use_fsdp: true
fsdp_port: 46451
wandb:
  enabled: true
  entity: null
  project: l38b_ppo_0517
cache_dir: ./data/models
local_run_dir: ./data/models/l38b_ppo_0517
do_first_eval: true
minimum_log_interval_secs: 1.0
intermediate_checkpoints: false
trainer: BasicTrainer
lr: 5.0e-07
n_epochs: 1
n_examples: null
optimizer: RMSprop
warmup_steps: 150
eval_every: 20000
n_samples: 128
samples_dir: samples/
n_eval_examples: 512
saved_policy: ./data/models/l38b_ppo_0517/LATEST/policy.pt
top_p: 0.95
human_prefix: '

  <|user|>

  '
assistant_prefix: '

  <|assistant|>

  '
human_suffix: ''
assistant_suffix: ''
frac_unique_desirable: 1.0
frac_unique_undesirable: 1.0
model:
  name_or_path: NousResearch/Meta-Llama-3-8B
  tokenizer_name_or_path: null
  load_from: l38b_sft_0505/LATEST/policy.pt
  block_name: LlamaDecoderLayer
  policy_dtype: bfloat16
  fsdp_policy_mp: null
  reference_dtype: bfloat16
  max_grad_norm: 10.0
  v_head_max_grad_norm: 0.1
  max_length: 2048
  max_prompt_length: 1024
  activation_checkpointing: true
  batch_size: 32
  gradient_accumulation_steps: 1
  eval_batch_size: 16
  use_flash_attention: true
loss:
  name: ppo
  ppo_epochs: 1
  cliprange: 0.5
  trainer: PPOTrainer
  dataloader: UnpairedPreferenceDataLoader
  lam: 0.95
  gamma: 0.99
  critic_coef: 0.01
  KL_coef: 0.1
  use_reference_model: true

================================================================================
Writing to aid-nrt-slurm-bm-gpu-b4-8-ad1-005:./data/models/l38b_ppo_0517
================================================================================
building policy
You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 10.43it/s]
Error executing job with overrides: ['loss=ppo', 'model=llama_3_8b', 'datasets=[shp,hh]', 'exp_name=l38b_ppo_0517', 'mode=train', '++cache_dir=./data/models', '++model.load_from=l38b_sft_0505/LATEST/policy.pt', '++wandb.project=l38b_ppo_0517']
Traceback (most recent call last):
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_errors.py", line 304, in hf_raise_for_status
    response.raise_for_status()
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/requests/models.py", line 1021, in raise_for_status
    raise HTTPError(http_error_msg, response=self)
requests.exceptions.HTTPError: 404 Client Error: Not Found for url: https://huggingface.co/NousResearch/Meta-Llama-3-8B/resolve/main/pytorch_model.bin

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/rosridha/POCs/rlhf/HALOs/models.py", line 95, in from_pretrained
    filename = hf_hub_download(pretrained_model_name_or_path, "pytorch_model.bin")
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 119, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 1261, in hf_hub_download
    metadata = get_hf_file_metadata(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 119, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 1674, in get_hf_file_metadata
    r = _request_wrapper(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 369, in _request_wrapper
    response = _request_wrapper(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 393, in _request_wrapper
    hf_raise_for_status(response)
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_errors.py", line 315, in hf_raise_for_status
    raise EntryNotFoundError(message, response) from e
huggingface_hub.utils._errors.EntryNotFoundError: 404 Client Error. (Request ID: Root=1-66483dd3-1fff68591f1d7b9b7a1e9d08;04442bd4-7883-4026-8e20-334b328b960d)

Entry Not Found for url: https://huggingface.co/NousResearch/Meta-Llama-3-8B/resolve/main/pytorch_model.bin.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_errors.py", line 304, in hf_raise_for_status
    response.raise_for_status()
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/requests/models.py", line 1021, in raise_for_status
    raise HTTPError(http_error_msg, response=self)
requests.exceptions.HTTPError: 404 Client Error: Not Found for url: https://huggingface.co/NousResearch/Meta-Llama-3-8B/resolve/main/pytorch_model.bin.index.json

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/rosridha/POCs/rlhf/HALOs/train.py", line 231, in <module>
    main()
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/main.py", line 94, in decorated_main
    _run_hydra(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
    _run_app(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/_internal/utils.py", line 457, in _run_app
    run_and_report(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
    raise ex
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
    return func()
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
    lambda: hydra.run(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/_internal/hydra.py", line 132, in run
    _ = ret.return_value
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/core/utils.py", line 260, in return_value
    raise self._return_value
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/hydra/core/utils.py", line 186, in run_job
    ret.return_value = task_function(task_cfg)
  File "/home/rosridha/POCs/rlhf/HALOs/train.py", line 132, in main
    policy = model_class.from_pretrained(
  File "/home/rosridha/POCs/rlhf/HALOs/models.py", line 101, in from_pretrained
    index_file_name = hf_hub_download(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 119, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 1261, in hf_hub_download
    metadata = get_hf_file_metadata(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 119, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 1674, in get_hf_file_metadata
    r = _request_wrapper(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 369, in _request_wrapper
    response = _request_wrapper(
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 393, in _request_wrapper
    hf_raise_for_status(response)
  File "/home/rosridha/POCs/rlhf/HALOs/halos3/lib/python3.10/site-packages/huggingface_hub/utils/_errors.py", line 315, in hf_raise_for_status
    raise EntryNotFoundError(message, response) from e
huggingface_hub.utils._errors.EntryNotFoundError: 404 Client Error. (Request ID: Root=1-66483dd3-495ad9e044bc65914d189248;d2584051-b677-439d-9982-cc1f1fd96021)

Entry Not Found for url: https://huggingface.co/NousResearch/Meta-Llama-3-8B/resolve/main/pytorch_model.bin.index.json.
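
Reading the traceback, from_pretrained in models.py asks the Hub for pytorch_model.bin and then for pytorch_model.bin.index.json, and both requests return 404. My guess (an assumption, not confirmed against the repo contents) is that this model repo only ships safetensors weights, so neither .bin file exists. Below is a minimal sketch of a fallback lookup that also considers the standard safetensors file names. The function name pick_weight_file and the CANDIDATES tuple are hypothetical and are not part of HALOs.

```python
from typing import Iterable, Optional

# File names to look for, in priority order. The first two are the ones the
# traceback shows models.py requesting; the last two are the usual
# safetensors equivalents (single-file and sharded index).
CANDIDATES = (
    "pytorch_model.bin",
    "pytorch_model.bin.index.json",
    "model.safetensors",
    "model.safetensors.index.json",
)


def pick_weight_file(repo_files: Iterable[str]) -> Optional[str]:
    """Return the first candidate weight file present in repo_files, else None."""
    files = set(repo_files)
    for name in CANDIDATES:
        if name in files:
            return name
    return None


# Hypothetical listing of a safetensors-only repo, for illustration only.
example_listing = [
    "config.json",
    "model-00001-of-00004.safetensors",
    "model.safetensors.index.json",
]
print(pick_weight_file(example_listing))
```

In practice the listing could come from huggingface_hub.list_repo_files("NousResearch/Meta-Llama-3-8B"), and the selected name could then be passed to hf_hub_download in place of the hard-coded "pytorch_model.bin". Loading the safetensors shards would need a different loader than the one used for .bin checkpoints, which this sketch does not cover.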