OpenLLMAI / OpenRLHF

An Easy-to-use, Scalable and High-performance RLHF Framework (70B+ PPO Full Tuning & Iterative DPO & LoRA & Mixtral)
https://openrlhf.readthedocs.io/
Apache License 2.0

Loss becomes NaN when training the Qwen2 SFT model with KTO #326

Open vincezengqiang opened 2 weeks ago

vincezengqiang commented 2 weeks ago

After running SFT on Qwen2 7B with our business dataset, KTO training produces a NaN loss. What could be causing this?

deepspeed --num_gpus 8 train_kto.py \
    --save_path xx \
    --save_steps 300 \
    --logging_steps 2 \
    --micro_train_batch_size 2 \
    --pretrain sft_model \
    --max_epochs 5 \
    --max_len 8192 \
    --zero_stage 3 \
    --beta 0.1 \
    --learning_rate 5e-7 \
    --dataset kto_pair.jsonl \
    --dataset_probs 1.0 \
    --gradient_checkpointing \
    --vanilla_loss

Package                        Version
------------------------------ ------------------------
absl-py                        2.1.0
accelerate                     0.28.0
aiohttp                        3.9.5
aiosignal                      1.3.1
annotated-types                0.6.0
apex                           0.1+2305.ppu
appdirs                        1.4.4
astunparse                     1.6.3
async-timeout                  4.0.3
attrs                          23.1.0
audioread                      3.0.1
bitsandbytes                   0.43.1
cachetools                     5.3.3
certifi                        2022.9.24
cffi                           1.15.0
charset-normalizer             2.0.12
click                          8.1.7
cloud-tpu-client               0.10
cnp                            0.1
coloredlogs                    15.0.1
dataclasses                    0.6
datasets                       2.19.0
dateparser                     1.1.7
DateTime                       5.5
decorator                      5.1.1
deepspeed                      0.14.3
dill                           0.3.8
docstring_parser               0.16
einops                         0.6.1
elastic-transport              8.4.0
elasticsearch                  8.5.3
exceptiongroup                 1.1.3
filelock                       3.6.0
flash_attn                     2.4.2
flatbuffers                    23.1.4
frozenlist                     1.4.1
fsspec                         2024.3.1
future                         0.18.2
google-api-core                1.34.1
google-api-python-client       1.8.0
google-auth                    2.29.0
google-auth-httplib2           0.2.0
googleapis-common-protos       1.63.0
hjson                          3.1.0
httplib2                       0.22.0
huggingface-hub                0.23.4
humanfriendly                  10.0
idna                           3.3
importlib-metadata             4.11.3
iniconfig                      2.0.0
intel-openmp                   2021.1.1
iopath                         0.1.10
Jinja2                         3.1.2
joblib                         1.1.0
jsonschema                     4.22.0
jsonschema-specifications      2023.12.1
llvmlite                       0.42.0
loguru                         0.7.2
loralib                        0.1.2
markdown-it-py                 3.0.0
MarkupSafe                     2.1.3
mdurl                          0.1.2
mkl                            2021.1.1
mkl-include                    2021.1.1
model-prof                     0.0.4
mpmath                         1.2.1
msgpack                        1.0.8
multidict                      6.0.5
multiprocess                   0.70.16
mypy-extensions                1.0.0
networkx                       3.0
ninja                          1.10.0
numba                          0.59.0
numpy                          1.23.5
nvidia-dali-cuda120            1.20.0
nvidia-ml-py                   12.555.43
oauth2client                   4.1.3
packaging                      21.3
pandas                         2.2.2
peft                           0.11.1
Pillow                         10.1.0
pip                            22.0.4
pluggy                         1.3.0
pooch                          1.6.0
portalocker                    2.6.0
protobuf                       3.20.3
psutil                         5.9.8
py-cpuinfo                     9.0.0
pyarrow                        16.0.0
pyarrow-hotfix                 0.6
pyasn1                         0.6.0
pyasn1_modules                 0.4.0
pybind11                       2.12.0
pycparser                      2.21
pydantic                       2.7.1
pydantic_core                  2.18.2
Pygments                       2.18.0
pynvml                         11.5.0
pyparsing                      3.0.8
pyre-extensions                0.0.29
pytest                         7.2.0
python-dateutil                2.8.2
pytz                           2022.1
pytz-deprecation-shim          0.1.0.post0
PyYAML                         6.0
ray                            2.24.0
referencing                    0.35.1
regex                          2022.10.31
requests                       2.27.1
resampy                        0.4.2
rich                           13.7.1
rpds-py                        0.18.1
rsa                            4.9
safetensors                    0.4.3
scikit-learn                   1.1.3
scipy                          1.9.3
sentencepiece                  0.2.0
setuptools                     57.5.0
shtab                          1.7.1
six                            1.16.0
SoundFile                      0.10.3.post1
sympy                          1.11.1
tbb                            2021.1.1
tensorboardX                   2.6.2.2
threadpoolctl                  3.1.0
tifffile                       2022.10.10
tiktoken                       0.5.2
timm                           0.8.22.dev0
tokenizers                     0.19.1
tomli                          2.0.1
torch                          2.2.2.2
torchaudio                     2.2.2.2
torchdata                      0.6.1+e1feeb2
torchtext                      0.17.2.2
torchvision                    0.17.2.2
tqdm                           4.64.0
transformer_engine             1.4.0+0fbc76a
transformers                   4.41.2
transformers-stream-generator  0.0.4
triton                         2.2.0
trl                            0.8.6
typing                         3.7.4.3
typing_extensions              4.11.0
typing-inspect                 0.9.0
tyro                           0.8.4
tzdata                         2022.7
tzlocal                        4.2
uritemplate                    3.0.1
urllib3                        1.26.16
wheel                          0.40.0
xformers                       0.0.22+02e68e1.d20240422
xxhash                         3.4.1
yarl                           1.9.4
zipp                           3.8.0
zope.interface                 6.3
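For readers debugging the same symptom, here is a minimal pure-Python sketch of the batch KTO loss as described in the KTO paper (Ethayarajh et al., 2024), using the `--beta 0.1` value from the command above. It is not OpenRLHF's implementation: `kto_loss`, `kl_estimate`, and the `guard_empty` switch are hypothetical names for illustration. It shows one generic way a KTO-style loss can turn NaN, namely averaging a group (desirable or undesirable) that is empty in a given micro-batch, because `torch.mean` on an empty tensor returns NaN instead of raising. The thread does not establish that this is the cause here.

```python
import math


def stable_sigmoid(x: float) -> float:
    """Logistic function that avoids overflow in math.exp for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def mean_or_nan(values):
    """Mimics torch.mean on an empty tensor, which yields NaN instead of raising."""
    return sum(values) / len(values) if values else float("nan")


def kto_loss(policy_logps, ref_logps, desirable, kl_estimate,
             beta=0.1, lambda_d=1.0, lambda_u=1.0, guard_empty=True):
    """Batch KTO loss (hypothetical helper, not OpenRLHF's API).

    policy_logps / ref_logps: summed log-probs of each response under the
    policy and the frozen reference model.
    desirable: True for a desirable example, False for an undesirable one.
    kl_estimate: batch estimate of KL(policy || reference), clamped at 0.
    """
    z0 = max(kl_estimate, 0.0)
    chosen, rejected = [], []
    for p, r, good in zip(policy_logps, ref_logps, desirable):
        reward = p - r  # implicit reward: log pi_theta(y|x) - log pi_ref(y|x)
        if good:
            chosen.append(lambda_d * (1.0 - stable_sigmoid(beta * (reward - z0))))
        else:
            rejected.append(lambda_u * (1.0 - stable_sigmoid(beta * (z0 - reward))))
    if guard_empty:
        # Average over every example that exists; one-sided batches stay finite.
        losses = chosen + rejected
        return sum(losses) / len(losses) if losses else 0.0
    # Naive variant: average each group separately, then combine.
    # A micro-batch with no undesirable (or no desirable) example gives NaN.
    return (mean_or_nan(chosen) + mean_or_nan(rejected)) / 2


# Policy and reference agree, so every per-example loss is 1 - sigmoid(0) = 0.5.
print(kto_loss([-10.0, -12.0], [-10.0, -12.0], [True, False], kl_estimate=0.0))  # 0.5
# An all-desirable micro-batch:
print(kto_loss([-1.0, -2.0], [-1.0, -2.0], [True, True], 0.0, guard_empty=False))  # nan
print(kto_loss([-1.0, -2.0], [-1.0, -2.0], [True, True], 0.0))  # 0.5
```

Averaging over all present examples is only one way to keep a one-sided batch finite; weighting or skipping such batches are other options, and which one a given trainer uses has to be checked in its source.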

hijkzzz commented 2 weeks ago

The dataset may not match the dataset-handling code; see https://github.com/OpenLLMAI/OpenRLHF/blob/9d8b3fdac345f6a18b37d73c53bfb95a652d1db2/openrlhf/datasets/unpaired_preference_dataset.py#L12. Beyond that, try disabling flash_attn.
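A quick sanity check of the training file can catch a data/code mismatch before launching DeepSpeed. The sketch below is stdlib-only Python and assumes each JSONL line carries `prompt`, `response`, and `label` fields; these field names and the `check_kto_jsonl` helper are hypothetical, so replace the keys with whatever `unpaired_preference_dataset.py` at the linked commit actually reads. It also tallies the label balance.

```python
import json
from collections import Counter


def check_kto_jsonl(lines, text_keys=("prompt", "response"), label_key="label"):
    """Return (problems, label_counts) for an iterable of JSONL lines.

    Key names are hypothetical placeholders; problems is a list of
    (line_number, message) tuples.
    """
    problems = []
    label_counts = Counter()
    for lineno, raw in enumerate(lines, start=1):
        raw = raw.strip()
        if not raw:
            continue  # tolerate blank lines
        try:
            record = json.loads(raw)
        except json.JSONDecodeError as err:
            problems.append((lineno, f"invalid JSON: {err.msg}"))
            continue
        if not isinstance(record, dict):
            problems.append((lineno, "record is not a JSON object"))
            continue
        missing = [k for k in (*text_keys, label_key) if k not in record]
        if missing:
            problems.append((lineno, f"missing keys: {missing}"))
            continue
        for key in text_keys:
            if not isinstance(record[key], str) or not record[key].strip():
                problems.append((lineno, f"empty or non-string {key!r}"))
        label = record[label_key]
        # Accept booleans or 0/1; anything else is ambiguous for a binary label.
        if label not in (True, False, 0, 1):
            problems.append((lineno, f"unexpected label {label!r}"))
        else:
            label_counts[bool(label)] += 1
    return problems, label_counts


# Usage with a real file (path taken from the command above):
# with open("kto_pair.jsonl", encoding="utf-8") as f:
#     problems, counts = check_kto_jsonl(f)
sample = [
    '{"prompt": "Hi", "response": "Hello!", "label": true}',
    '{"prompt": "Hi", "response": "", "label": false}',
    '{"prompt": "Hi", "label": 1}',
    "not json",
]
problems, counts = check_kto_jsonl(sample)
for lineno, msg in problems:
    print(lineno, msg)
print(dict(counts))
```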

vincezengqiang commented 2 weeks ago


It's probably not a data problem: with the same data on Qwen1.5, the loss is normal and the model converges as expected. I'll try disabling flash_attn.