eric-mitchell / direct-preference-optimization

Reference implementation for DPO (Direct Preference Optimization)
Apache License 2.0

Error when training DPO on multiple cards using FSDPTrainer #14

Closed skepsun closed 1 year ago

skepsun commented 1 year ago

I got the error "ValueError: bad value(s) in fds_to_keep" when starting FSDP training. Full logs:

building policy
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:17<00:00,  8.58s/it]
building reference model
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:13<00:00,  7.00s/it]
loading pre-trained weights at step 19808 from .cache/chuxiong/anthropic_sft_baichuan_2023-07-10_18-47-58_047333/LATEST/policy.pt with metrics {}
loaded pre-trained weights
starting 8 processes for FSDP training
Error executing job with overrides: ['model=baichuan', 'datasets=[rsrt]', 'loss=dpo', 'loss.beta=0.1', 'exp_name=anthropic_dpo_baichuan', 'gradient_accumulation_steps=2', 'batch_size=32', 'eval_batch_size=16', 'trainer=FSDPTrainer', 'sample_during_eval=false', 'model.fsdp_policy_mp=bfloat16', 'model.archive=.cache/chuxiong/anthropic_sft_baichuan_2023-07-10_18-47-58_047333/LATEST/policy.pt']
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /d2/data/chuxiong/direct-preference-optimization/train.py:111 in <module>                        │
│                                                                                                  │
│   108                                                                                            │
│   109                                                                                            │
│   110 if __name__ == '__main__':                                                                 │
│ ❱ 111 │   main()                                                                                 │
│   112                                                                                            │
│                                                                                                  │
│ /d1/conda3/envs/scx_llm/lib/python3.10/site-packages/hydra/main.py:94 in decorated_main          │
│                                                                                                  │
│    91 │   │   │   │   else:                                                                      │
│    92 │   │   │   │   │   # no return value from run_hydra() as it may sometime actually run t   │
│    93 │   │   │   │   │   # multiple times (--multirun)                                          │
│ ❱  94 │   │   │   │   │   _run_hydra(                                                            │
│    95 │   │   │   │   │   │   args=args,                                                         │
│    96 │   │   │   │   │   │   args_parser=args_parser,                                           │
│    97 │   │   │   │   │   │   task_function=task_function,                                       │
│                                                                                                  │
│ /d1/conda3/envs/scx_llm/lib/python3.10/site-packages/hydra/_internal/utils.py:394 in _run_hydra  │
│                                                                                                  │
│   391 │   │                                                                                      │
│   392 │   │   if args.run or args.multirun:                                                      │
│   393 │   │   │   run_mode = hydra.get_mode(config_name=config_name, overrides=overrides)        │
│ ❱ 394 │   │   │   _run_app(                                                                      │
│   395 │   │   │   │   run=args.run,                                                              │
│   396 │   │   │   │   multirun=args.multirun,                                                    │
│   397 │   │   │   │   mode=run_mode,                                                             │
│                                                                                                  │
│ /d1/conda3/envs/scx_llm/lib/python3.10/site-packages/hydra/_internal/utils.py:457 in _run_app    │
│                                                                                                  │
│   454 │   │   │   overrides.extend(["hydra.mode=MULTIRUN"])                                      │
│   455 │                                                                                          │
│   456 │   if mode == RunMode.RUN:                                                                │
│ ❱ 457 │   │   run_and_report(                                                                    │
│   458 │   │   │   lambda: hydra.run(                                                             │
│   459 │   │   │   │   config_name=config_name,                                                   │
│   460 │   │   │   │   task_function=task_function,                                               │
│                                                                                                  │
│ /d1/conda3/envs/scx_llm/lib/python3.10/site-packages/hydra/_internal/utils.py:223 in             │
│ run_and_report                                                                                   │
│                                                                                                  │
│   220 │   │   return func()                                                                      │
│   221 │   except Exception as ex:                                                                │
│   222 │   │   if _is_env_set("HYDRA_FULL_ERROR") or is_under_debugger():                         │
│ ❱ 223 │   │   │   raise ex                                                                       │
│   224 │   │   else:                                                                              │
│   225 │   │   │   try:                                                                           │
│   226 │   │   │   │   if isinstance(ex, CompactHydraException):                                  │
│                                                                                                  │
│ /d1/conda3/envs/scx_llm/lib/python3.10/site-packages/hydra/_internal/utils.py:220 in             │
│ run_and_report                                                                                   │
│                                                                                                  │
│   217                                                                                            │
│   218 def run_and_report(func: Any) -> Any:                                                      │
│   219 │   try:                                                                                   │
│ ❱ 220 │   │   return func()                                                                      │
│   221 │   except Exception as ex:                                                                │
│   222 │   │   if _is_env_set("HYDRA_FULL_ERROR") or is_under_debugger():                         │
│   223 │   │   │   raise ex                                                                       │
│                                                                                                  │
│ /d1/conda3/envs/scx_llm/lib/python3.10/site-packages/hydra/_internal/utils.py:458 in <lambda>    │
│                                                                                                  │
│   455 │                                                                                          │
│   456 │   if mode == RunMode.RUN:                                                                │
│   457 │   │   run_and_report(                                                                    │
│ ❱ 458 │   │   │   lambda: hydra.run(                                                             │
│   459 │   │   │   │   config_name=config_name,                                                   │
│   460 │   │   │   │   task_function=task_function,                                               │
│   461 │   │   │   │   overrides=overrides,                                                       │
│                                                                                                  │
│ /d1/conda3/envs/scx_llm/lib/python3.10/site-packages/hydra/_internal/hydra.py:132 in run         │
│                                                                                                  │
│   129 │   │   callbacks.on_run_end(config=cfg, config_name=config_name, job_return=ret)          │
│   130 │   │                                                                                      │
│   131 │   │   # access the result to trigger an exception in case the job failed.                │
│ ❱ 132 │   │   _ = ret.return_value                                                               │
│   133 │   │                                                                                      │
│   134 │   │   return ret                                                                         │
│   135                                                                                            │
│                                                                                                  │
│ /d1/conda3/envs/scx_llm/lib/python3.10/site-packages/hydra/core/utils.py:260 in return_value     │
│                                                                                                  │
│   257 │   │   │   sys.stderr.write(                                                              │
│   258 │   │   │   │   f"Error executing job with overrides: {self.overrides}" + os.linesep       │
│   259 │   │   │   )                                                                              │
│ ❱ 260 │   │   │   raise self._return_value                                                       │
│   261 │                                                                                          │
│   262 │   @return_value.setter                                                                   │
│   263 │   def return_value(self, value: Any) -> None:                                            │
│                                                                                                  │
│ /d1/conda3/envs/scx_llm/lib/python3.10/site-packages/hydra/core/utils.py:186 in run_job          │
│                                                                                                  │
│   183 │   │   with env_override(hydra_cfg.hydra.job.env_set):                                    │
│   184 │   │   │   callbacks.on_job_start(config=config, task_function=task_function)             │
│   185 │   │   │   try:                                                                           │
│ ❱ 186 │   │   │   │   ret.return_value = task_function(task_cfg)                                 │
│   187 │   │   │   │   ret.status = JobStatus.COMPLETED                                           │
│   188 │   │   │   except Exception as e:                                                         │
│   189 │   │   │   │   ret.return_value = e                                                       │
│                                                                                                  │
│ /d2/data/chuxiong/direct-preference-optimization/train.py:104 in main                            │
│                                                                                                  │
│   101 │   if 'FSDP' in config.trainer:                                                           │
│   102 │   │   world_size = torch.cuda.device_count()                                             │
│   103 │   │   print('starting', world_size, 'processes for FSDP training')                       │
│ ❱ 104 │   │   mp.spawn(worker_main, nprocs=world_size, args=(world_size, config, policy, refer   │
│   105 │   else:                                                                                  │
│   106 │   │   print('starting single-process worker')                                            │
│   107 │   │   worker_main(0, 1, config, policy, reference_model)                                 │
│                                                                                                  │
│ /d1/conda3/envs/scx_llm/lib/python3.10/site-packages/torch/multiprocessing/spawn.py:239 in spawn │
│                                                                                                  │
│   236 │   │   │      'To use a different start_method use:\n\t\t'                                │
│   237 │   │   │      ' torch.multiprocessing.start_processes(...)' % start_method)               │
│   238 │   │   warnings.warn(msg)                                                                 │
│ ❱ 239 │   return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')           │
│   240                                                                                            │
│                                                                                                  │
│ /d1/conda3/envs/scx_llm/lib/python3.10/site-packages/torch/multiprocessing/spawn.py:188 in       │
│ start_processes                                                                                  │
│                                                                                                  │
│   185 │   │   │   args=(fn, i, args, error_queue),                                               │
│   186 │   │   │   daemon=daemon,                                                                 │
│   187 │   │   )                                                                                  │
│ ❱ 188 │   │   process.start()                                                                    │
│   189 │   │   error_queues.append(error_queue)                                                   │
│   190 │   │   processes.append(process)                                                          │
│   191                                                                                            │
│                                                                                                  │
│ /d1/conda3/envs/scx_llm/lib/python3.10/multiprocessing/process.py:121 in start                   │
│                                                                                                  │
│   118 │   │   assert not _current_process._config.get('daemon'), \                               │
│   119 │   │   │      'daemonic processes are not allowed to have children'                       │
│   120 │   │   _cleanup()                                                                         │
│ ❱ 121 │   │   self._popen = self._Popen(self)                                                    │
│   122 │   │   self._sentinel = self._popen.sentinel                                              │
│   123 │   │   # Avoid a refcycle if the target function holds an indirect                        │
│   124 │   │   # reference to the process object (see bpo-30775)                                  │
│                                                                                                  │
│ /d1/conda3/envs/scx_llm/lib/python3.10/multiprocessing/context.py:288 in _Popen                  │
│                                                                                                  │
│   285 │   │   @staticmethod                                                                      │
│   286 │   │   def _Popen(process_obj):                                                           │
│   287 │   │   │   from .popen_spawn_posix import Popen                                           │
│ ❱ 288 │   │   │   return Popen(process_obj)                                                      │
│   289 │   │                                                                                      │
│   290 │   │   @staticmethod                                                                      │
│   291 │   │   def _after_fork():                                                                 │
│                                                                                                  │
│ /d1/conda3/envs/scx_llm/lib/python3.10/multiprocessing/popen_spawn_posix.py:32 in __init__       │
│                                                                                                  │
│   29 │                                                                                           │
│   30 │   def __init__(self, process_obj):                                                        │
│   31 │   │   self._fds = []                                                                      │
│ ❱ 32 │   │   super().__init__(process_obj)                                                       │
│   33 │                                                                                           │
│   34 │   def duplicate_for_child(self, fd):                                                      │
│   35 │   │   self._fds.append(fd)                                                                │
│                                                                                                  │
│ /d1/conda3/envs/scx_llm/lib/python3.10/multiprocessing/popen_fork.py:19 in __init__              │
│                                                                                                  │
│   16 │   │   util._flush_std_streams()                                                           │
│   17 │   │   self.returncode = None                                                              │
│   18 │   │   self.finalizer = None                                                               │
│ ❱ 19 │   │   self._launch(process_obj)                                                           │
│   20 │                                                                                           │
│   21 │   def duplicate_for_child(self, fd):                                                      │
│   22 │   │   return fd                                                                           │
│                                                                                                  │
│ /d1/conda3/envs/scx_llm/lib/python3.10/multiprocessing/popen_spawn_posix.py:58 in _launch        │
│                                                                                                  │
│   55 │   │   │   cmd = spawn.get_command_line(tracker_fd=tracker_fd,                             │
│   56 │   │   │   │   │   │   │   │   │   │    pipe_handle=child_r)                               │
│   57 │   │   │   self._fds.extend([child_r, child_w])                                            │
│ ❱ 58 │   │   │   self.pid = util.spawnv_passfds(spawn.get_executable(),                          │
│   59 │   │   │   │   │   │   │   │   │   │      cmd, self._fds)                                  │
│   60 │   │   │   self.sentinel = parent_r                                                        │
│   61 │   │   │   with open(parent_w, 'wb', closefd=False) as f:                                  │
│                                                                                                  │
│ /d1/conda3/envs/scx_llm/lib/python3.10/multiprocessing/util.py:452 in spawnv_passfds             │
│                                                                                                  │
│   449 │   passfds = tuple(sorted(map(int, passfds)))                                             │
│   450 │   errpipe_read, errpipe_write = os.pipe()                                                │
│   451 │   try:                                                                                   │
│ ❱ 452 │   │   return _posixsubprocess.fork_exec(                                                 │
│   453 │   │   │   args, [os.fsencode(path)], True, passfds, None, None,                          │
│   454 │   │   │   -1, -1, -1, -1, -1, -1, errpipe_read, errpipe_write,                           │
│   455 │   │   │   False, False, None, None, None, -1, None)                                      │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: bad value(s) in fds_to_keep

eric-mitchell commented 1 year ago

Did you include the ulimit -n 64000 call in your launch command? There's an example usage in the README.

Can you also share the versions of Python and torch you're using? I will investigate more later today.
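
For anyone who wants to check the open-file limit from inside Python rather than the shell, here is a minimal sketch using the standard-library resource module (Unix only). The helper name raise_fd_limit is hypothetical and not part of this repo, and 64000 simply mirrors the ulimit value mentioned above. It can only raise the soft limit up to the hard limit, so it is not a full substitute for ulimit -n in the launch command.

```python
import resource


def raise_fd_limit(wanted=64000):
    """Try to raise the soft open-file limit toward `wanted`.

    Hypothetical helper: returns the soft limit in effect afterwards.
    """
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    # The soft limit can never exceed the hard limit for an unprivileged process.
    target = wanted if hard == resource.RLIM_INFINITY else min(wanted, hard)
    if soft != resource.RLIM_INFINITY and soft < target:
        try:
            resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
        except (ValueError, OSError):
            # Some platforms reject the request; keep the current limit.
            pass
    return resource.getrlimit(resource.RLIMIT_NOFILE)[0]


if __name__ == "__main__":
    print("soft open-file limit:", raise_fd_limit())
```

Calling this before mp.spawn would affect the spawned workers as well, since child processes inherit the parent's limits. As the next comment shows, raising the limit did not fix this particular error.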

skepsun commented 1 year ago

I just tried adding ulimit -n 64000 but still got the same error message. I am using Python 3.10.11 with torch 2.0.1. Here is a list of installed packages:

Package                  Version      Editable project location
------------------------ ------------ --------------------------
absl-py                  1.4.0
accelerate               0.20.3
aiofiles                 23.1.0
aiohttp                  3.8.4
aiosignal                1.3.1
aliyun-python-sdk-core   2.13.36
aliyun-python-sdk-kms    2.16.1
altair                   5.0.1
antlr4-python3-runtime   4.9.3
anyio                    3.7.0
appdirs                  1.4.4
arrow                    1.2.3
asttokens                2.0.5
async-timeout            4.0.2
attrs                    23.1.0
backcall                 0.2.0
bcrypt                   4.0.1
beautifulsoup4           4.12.2
bitsandbytes             0.39.0
blessed                  1.20.0
brotlipy                 0.7.0
cachetools               5.3.1
cattrs                   23.1.2
certifi                  2021.5.30
cffi                     1.15.1
cfgv                     3.3.1
chardet                  3.0.4
charset-normalizer       2.0.4
cheroot                  10.0.0
click                    8.1.3
cmake                    3.26.3
coati                    1.0.0
colossalai               0.2.8
contexttimer             0.3.3
contourpy                1.0.7
cpm-kernels              1.0.11
crcmod                   1.7
croniter                 1.3.15
cryptography             39.0.1
cycler                   0.11.0
data-serialize           0.2.1
dataclasses-json         0.5.7
datasets                 2.12.0
dateutils                0.6.12
debugpy                  1.5.1
decorator                5.1.1
deep-training            0.1.10.post1
deepdiff                 6.3.0
deepspeed                0.9.5
delta-center-client      0.0.4
dill                     0.3.6
distlib                  0.3.6
docker-pycreds           0.4.0
einops                   0.6.1
exceptiongroup           1.1.1
executing                0.8.3
fabric                   3.1.0
faiss-gpu                1.7.2
fastapi                  0.95.2
fastdatasets             0.9.7.post0
ffmpy                    0.3.0
filelock                 3.12.0
fire                     0.5.0
flash-attn               1.0.3.post0
fonttools                4.39.4
frozenlist               1.3.3
fschat                   0.2.15       /d1/data/chuxiong/FastChat
fsspec                   2023.5.0
functorch                1.13.1
gensim                   4.3.1
gitdb                    4.0.10
GitPython                3.1.31
gmpy2                    2.1.2
google-auth              2.20.0
google-auth-oauthlib     1.0.0
google-trans-new         1.1.9
gpustat                  1.1
gradio                   3.35.2
gradio_client            0.2.7
greenlet                 2.0.2
grpcio                   1.51.3
h11                      0.9.0
h2                       3.2.0
hjson                    3.1.0
hpack                    3.0.0
hstspreload              2023.1.1
httpcore                 0.9.1
httpx                    0.13.3
huggingface-hub          0.14.1
hydra-core               1.3.2
hyperframe               5.2.0
identify                 2.5.24
idna                     3.2
inquirer                 3.1.3
invoke                   2.1.2
ipykernel                6.15.0
ipython                  8.12.0
itsdangerous             2.1.2
jaraco.functools         3.7.0
jedi                     0.18.1
jieba                    0.42.1
Jinja2                   3.1.2
jmespath                 0.10.0
joblib                   1.2.0
jsonlines                3.1.0
jsonschema               4.17.3
jupyter_client           8.1.0
jupyter_core             5.3.0
kiwisolver               1.4.4
langchain                0.0.189
latex2mathml             3.76.0
lightning                2.0.4
lightning-cloud          0.5.37
lightning-utilities      0.8.0
linkify-it-py            2.0.2
lit                      16.0.5.post0
loguru                   0.7.0
loralib                  0.1.1
Markdown                 3.4.3
markdown-it-py           2.2.0
markdown2                2.4.8
MarkupSafe               2.1.1
marshmallow              3.19.0
marshmallow-enum         1.5.1
matplotlib               3.7.1
matplotlib-inline        0.1.6
mdit-py-plugins          0.3.3
mdtex2html               1.2.0
mdurl                    0.1.2
mkl-fft                  1.3.6
mkl-random               1.2.2
mkl-service              2.4.0
more-itertools           9.1.0
mpmath                   1.2.1
msgpack                  1.0.5
multidict                6.0.4
multiprocess             0.70.14
mypy-extensions          1.0.0
nest-asyncio             1.5.6
networkx                 2.8.4
nh3                      0.2.13
ninja                    1.11.1
nltk                     3.8.1
nodeenv                  1.8.0
numexpr                  2.8.4
numpy                    1.25.0
numpy-io                 0.0.3
nvidia-cublas-cu11       11.10.3.66
nvidia-cuda-nvrtc-cu11   11.7.99
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cudnn-cu11        8.5.0.96
nvidia-ml-py             11.525.112
oauthlib                 3.2.2
omegaconf                2.3.0
openai                   0.27.8
openapi-schema-pydantic  1.2.4
opendelta                0.3.2
optree                   0.9.1
ordered-set              4.1.0
orjson                   3.9.0
oss2                     2.15.0
packaging                23.0
pandas                   1.5.2
paramiko                 3.2.0
parso                    0.8.3
pathtools                0.1.2
peft                     0.4.0.dev0
pexpect                  4.8.0
pickleshare              0.7.5
Pillow                   9.4.0
pip                      23.1.2
platformdirs             2.5.2
pre-commit               3.3.2
prompt-toolkit           3.0.36
protobuf                 3.20.3
psutil                   5.9.0
ptyprocess               0.7.0
pure-eval                0.2.2
py-cpuinfo               9.0.0
pyarrow                  12.0.0
pyasn1                   0.5.0
pyasn1-modules           0.3.0
pycparser                2.21
pycryptodome             3.18.0
pydantic                 1.10.8
pydub                    0.25.1
Pygments                 2.15.1
PyJWT                    2.7.0
PyNaCl                   1.5.0
pyOpenSSL                23.0.0
pyparsing                3.0.9
pyre-extensions          0.0.23
pyrsistent               0.19.3
PySocks                  1.7.1
python-dateutil          2.8.2
python-dotenv            0.19.0
python-editor            1.0.4
python-multipart         0.0.6
python-rapidjson         1.10
pytorch-lightning        2.0.4
pytz                     2023.3
PyYAML                   6.0
pyzmq                    25.1.0
ray                      2.4.0
readchar                 4.0.5
regex                    2023.5.5
requests                 2.26.0
requests-oauthlib        1.3.1
responses                0.18.0
rfc3986                  1.5.0
rich                     13.4.1
rouge-chinese            1.0.3
rsa                      4.9
safetensors              0.3.1
scikit-learn             1.2.2
scipy                    1.10.1
seaborn                  0.12.2
semantic-version         2.10.0
sentencepiece            0.1.99
sentry-sdk               1.24.0
seqmetric                0.1.2
setproctitle             1.3.2
setuptools               67.8.0
shortuuid                1.0.11
six                      1.16.0
sklearn                  0.0.post5
smart-open               6.3.0
smmap                    5.0.0
sniffio                  1.3.0
soupsieve                2.4.1
SQLAlchemy               2.0.15
sse-starlette            1.6.1
stack-data               0.2.0
starlette                0.27.0
starsessions             1.3.0
svgwrite                 1.4.3
sympy                    1.11.1
tabulate                 0.9.0
tenacity                 8.2.2
tensor-parallel          1.2.8
tensorboard              2.13.0
tensorboard-data-server  0.7.1
tensorboardX             2.6
termcolor                2.3.0
text2vec                 1.2.1
tfrecords                0.2.6
threadpoolctl            3.1.0
tiktoken                 0.4.0
tokenizers               0.13.3
toolz                    0.12.0
torch                    2.0.1
torchaudio               2.0.2
torchinfo                1.8.0
torchmetrics             0.11.4
torchtyping              0.1.4
torchvision              0.15.2
tornado                  6.2
tqdm                     4.65.0
traitlets                5.7.1
transformers             4.30.2
translate-json           0.0.2
triton                   2.0.0
tritonclient             2.34.0
trl                      0.4.4
trlx                     0.6.0
typeguard                4.0.0
typing_extensions        4.6.3
typing-inspect           0.9.0
tzdata                   2023.3
uc-micro-py              1.0.2
urllib3                  1.25
uvicorn                  0.22.0
virtualenv               20.21.0
wandb                    0.15.3
wavedrom                 2.0.3.post3
wcwidth                  0.2.5
web.py                   0.62
websocket-client         1.6.1
websockets               11.0.3
Werkzeug                 2.3.6
wheel                    0.38.4
xformers                 0.0.16
xxhash                   3.2.0
yacs                     0.1.8
yarl                     1.9.2
zstandard                0.21.0

skepsun commented 1 year ago

I changed the low_cpu_mem_usage kwarg to False when loading reference_model, and the error disappeared!

eric-mitchell commented 1 year ago

Sorry for the slow update on this! Glad to hear your issue was resolved. That's pretty mysterious. Have you made any other changes to the codebase? I haven't seen this issue on our end.

eric-mitchell commented 1 year ago

@skepsun just wanted to follow up: is everything working as expected for you? Feel free to re-open if you have any other questions, but I'll close this issue for now.

GeekDream-x commented 8 months ago

I hit the same problem and solved it by setting low_cpu_mem_usage=False when loading both the policy and the reference_model.
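
The workaround reported above can be sketched as follows. This is an illustration under assumptions, not the repo's actual loading code: build_load_kwargs and load_model are hypothetical helper names, and the only library call relied on is transformers' AutoModelForCausalLM.from_pretrained with its low_cpu_mem_usage argument. Why this avoids the fds_to_keep error was not established in this thread.

```python
def build_load_kwargs(base_kwargs, low_cpu_mem_usage=False):
    """Return a copy of the loading kwargs with low_cpu_mem_usage forced.

    Hypothetical helper; the caller's dict is left unmodified.
    """
    kwargs = dict(base_kwargs)
    kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
    return kwargs


def load_model(name_or_path, **base_kwargs):
    """Load a causal LM with low_cpu_mem_usage=False (the reported workaround)."""
    # Imported lazily so the kwargs helper can be used without transformers installed.
    import transformers

    return transformers.AutoModelForCausalLM.from_pretrained(
        name_or_path, **build_load_kwargs(base_kwargs)
    )


if __name__ == "__main__":
    # Show the effective kwargs without downloading any model.
    print(build_load_kwargs({"low_cpu_mem_usage": True}))
```

Following the last comment, the same setting would be applied to both the policy and the reference model. Note that low_cpu_mem_usage=False may increase peak CPU memory while the checkpoint is loaded.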