eric-mitchell / direct-preference-optimization

Reference implementation for DPO (Direct Preference Optimization)
Apache License 2.0

When trying to reproduce the complete example, "NotImplementedError: offload_to_cpu=True and NO_SHARD is not supported yet" is thrown #91

Open ZSvedic opened 6 days ago

ZSvedic commented 6 days ago

I tried to reproduce the complete example on a Hyperstack cloud machine (A100-80G-PCIe, OS image Ubuntu Server 22.04 LTS, R535 CUDA 12.2). Since I am using a single A100, I reduced the batch size. This command starts the training:

python -u train.py model=pythia28 datasets=[hh] loss=sft exp_name=anthropic_dpo_pythia28 gradient_accumulation_steps=2 batch_size=16 eval_batch_size=16 trainer=FSDPTrainer sample_during_eval=false model.fsdp_policy_mp=bfloat16

Unfortunately, training fails while saving the first checkpoint (at 20,000 examples) with the following stack trace:

Error executing job with overrides: ['model=pythia28', 'datasets=[hh]', 'loss=sft', 'exp_name=anthropic_dpo_pythia28', 'gradient_accumulation_steps=2', 'batch_size=16', 'eval_batch_size=16', 'trainer=FSDPTrainer', 'sample_during_eval=false', 'model.fsdp_policy_mp=bfloat16']
Traceback (most recent call last):
  File "/home/ubuntu/dpo-examples/direct-preference-optimization/train.py", line 111, in main
    mp.spawn(worker_main, nprocs=world_size, args=(world_size, config, policy, reference_model), join=True)
  File "/home/ubuntu/dpo-examples/direct-preference-optimization/.venv-20230622/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 239, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/ubuntu/dpo-examples/direct-preference-optimization/.venv-20230622/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 197, in start_processes
    while not context.join():
  File "/home/ubuntu/dpo-examples/direct-preference-optimization/.venv-20230622/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/ubuntu/dpo-examples/direct-preference-optimization/.venv-20230622/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/ubuntu/dpo-examples/direct-preference-optimization/train.py", line 44, in worker_main
    trainer.train()
  File "/home/ubuntu/dpo-examples/direct-preference-optimization/trainers.py", line 352, in train
    self.save(output_dir, mean_eval_metrics)
  File "/home/ubuntu/dpo-examples/direct-preference-optimization/trainers.py", line 501, in save
    policy_state_dict = self.policy.state_dict()
  File "/home/ubuntu/dpo-examples/direct-preference-optimization/.venv-20230622/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1815, in state_dict
    self._save_to_state_dict(destination, prefix, keep_vars)
  File "/home/ubuntu/dpo-examples/direct-preference-optimization/.venv-20230622/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1722, in _save_to_state_dict
    hook(self, prefix, keep_vars)
  File "/home/ubuntu/dpo-examples/direct-preference-optimization/.venv-20230622/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/dpo-examples/direct-preference-optimization/.venv-20230622/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 669, in _pre_state_dict_hook
    _pre_state_dict_hook_fn[fsdp_state._state_dict_type](
  File "/home/ubuntu/dpo-examples/direct-preference-optimization/.venv-20230622/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 271, in _full_pre_state_dict_hook
    _common_unshard_pre_state_dict_hook(
  File "/home/ubuntu/dpo-examples/direct-preference-optimization/.venv-20230622/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 143, in _common_unshard_pre_state_dict_hook
    _enter_unshard_params_ctx(
  File "/home/ubuntu/dpo-examples/direct-preference-optimization/.venv-20230622/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 109, in _enter_unshard_params_ctx
    fsdp_state._unshard_params_ctx[module].__enter__()
  File "/usr/lib/python3.10/contextlib.py", line 135, in __enter__
    return next(self.gen)
  File "/home/ubuntu/dpo-examples/direct-preference-optimization/.venv-20230622/lib/python3.10/site-packages/torch/distributed/fsdp/_unshard_param_utils.py", line 171, in _unshard_fsdp_state_params
    _validate_unshard_params_args(
  File "/home/ubuntu/dpo-examples/direct-preference-optimization/.venv-20230622/lib/python3.10/site-packages/torch/distributed/fsdp/_unshard_param_utils.py", line 140, in _validate_unshard_params_args
    raise NotImplementedError(
NotImplementedError: offload_to_cpu=True and NO_SHARD is not supported yet

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
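
A possible reading of the final frames, offered as an unverified hypothesis: the trace goes through _full_pre_state_dict_hook, so the save requests a full state dict, and the message says CPU offload is enabled while the module uses NO_SHARD. With a single A100 the run has one process, and FSDP may fall back to NO_SHARD in that case because there is nothing to shard across. The sketch below only illustrates that rule in plain Python. The names `effective_strategy`, `check_unshard_args` and `can_save_full_state_dict` are hypothetical and are not PyTorch's API.

```python
def effective_strategy(world_size: int, requested: str = "FULL_SHARD") -> str:
    """Assumed behaviour: with a single process there is nothing to shard."""
    return "NO_SHARD" if world_size == 1 else requested


def check_unshard_args(offload_to_cpu: bool, strategy: str) -> None:
    """Mimic the rule implied by the error message in the trace (illustrative)."""
    if offload_to_cpu and strategy == "NO_SHARD":
        raise NotImplementedError(
            "offload_to_cpu=True and NO_SHARD is not supported yet"
        )


def can_save_full_state_dict(world_size: int, offload_to_cpu: bool = True) -> bool:
    """Return True if the (assumed) validation would let the save proceed."""
    try:
        check_unshard_args(offload_to_cpu, effective_strategy(world_size))
    except NotImplementedError:
        return False
    return True


# One process with CPU offload reproduces the failing combination.
print(can_save_full_state_dict(world_size=1))
# Two processes keep a sharded strategy, so the check passes.
print(can_save_full_state_dict(world_size=2))
```

If this reading holds, things to try would be running on more than one GPU or using trainer=BasicTrainer for the single-GPU run. Neither has been tested here.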

It seems that either I installed an incompatible version of a library, or an incompatible library came with the preinstalled cloud image. Oddly, I used pypi-timemachine to install only Python dependencies available as of 2023-06-22, the date of the last commit to requirements.txt, but it seems I still went wrong somewhere. Here are the versions on my cloud machine:

accelerate==0.20.3
aiohttp==3.8.4
aiosignal==1.3.1
antlr4-python3-runtime==4.9.3
appdirs==1.4.4
asttokens==2.2.1
async-timeout==4.0.2
attrs==23.1.0
backcall==0.2.0
beautifulsoup4==4.12.2
certifi==2023.5.7
charset-normalizer==3.1.0
click==8.1.3
cmake==3.26.4
comm==0.1.3
datasets==2.12.0
debugpy==1.6.7
decorator==5.1.1
dill==0.3.6
docker-pycreds==0.4.0
executing==1.2.0
filelock==3.12.2
frozenlist==1.3.3
fsspec==2023.6.0
gitdb==4.0.10
GitPython==3.1.31
huggingface-hub==0.15.1
hydra-core==1.3.2
idna==3.4
ipykernel==6.23.1
ipython==8.14.0
jedi==0.18.2
Jinja2==3.1.2
jupyter_client==8.2.0
jupyter_core==5.3.1
lit==16.0.6
MarkupSafe==2.1.3
matplotlib-inline==0.1.6
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.14
nest-asyncio==1.5.6
networkx==3.1
numpy==1.24.3
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
omegaconf==2.3.0
packaging==23.1
pandas==2.0.2
parso==0.8.3
pathtools==0.1.2
peft==0.3.0
pexpect==4.8.0
pickleshare==0.7.5
pip==22.0.2
platformdirs==3.7.0
prompt-toolkit==3.0.38
protobuf==4.23.3
psutil==5.9.5
ptyprocess==0.7.0
pure-eval==0.2.2
pyarrow==12.0.1
Pygments==2.15.1
python-dateutil==2.8.2
pytz==2023.3
PyYAML==6.0
pyzmq==25.1.0
regex==2023.6.3
requests==2.31.0
responses==0.18.0
sentry-sdk==1.25.1
setproctitle==1.3.2
setuptools==59.6.0
six==1.16.0
smmap==5.0.0
soupsieve==2.4.1
stack-data==0.6.2
sympy==1.12
tensor-parallel==1.2.4
tokenizers==0.13.3
torch==2.0.1
tornado==6.3.2
tqdm==4.65.0
traitlets==5.9.0
transformers==4.29.2
triton==2.0.0
typing_extensions==4.6.3
tzdata==2023.3
urllib3==2.0.3
wandb==0.15.3
wcwidth==0.2.6
wheel==0.40.0
xxhash==3.2.0
yarl==1.9.2
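
To narrow the search, the list above can be filtered down to the packages most likely to matter for this trace. This is a minimal sketch that assumes the list is available as pip freeze-style text. `pinned_versions` is a hypothetical helper, and the `SUSPECTS` set is only an illustrative guess at which packages are relevant.

```python
def pinned_versions(freeze_text: str, names: set[str]) -> dict[str, str]:
    """Extract `name==version` pins for the requested package names."""
    wanted = {n.lower() for n in names}
    found = {}
    for line in freeze_text.splitlines():
        name, sep, version = line.strip().partition("==")
        # Skip lines without an exact pin and packages we did not ask for.
        if sep and name.lower() in wanted:
            found[name] = version
    return found


# Illustrative guess at the packages worth comparing against requirements.txt.
SUSPECTS = {"torch", "transformers", "accelerate", "tensor-parallel"}

# A few lines copied from the version list above.
sample = """accelerate==0.20.3
numpy==1.24.3
tensor-parallel==1.2.4
torch==2.0.1
transformers==4.29.2
"""

print(pinned_versions(sample, SUSPECTS))
```

The result can then be compared by hand against the pins in requirements.txt.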

Do any of you good souls know what is wrong, and which library and version is causing the problem?