Open vincezengqiang opened 2 weeks ago
The dataset may not match the dataset-handling code. For reference, see https://github.com/OpenLLMAI/OpenRLHF/blob/9d8b3fdac345f6a18b37d73c53bfb95a652d1db2/openrlhf/datasets/unpaired_preference_dataset.py#L12. Beyond that, try disabling flash_attn.
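One way to rule out a dataset/code mismatch before training is to scan the JSONL file for records that are malformed or lack the fields the dataset class reads. The following is only a minimal sketch: the key names `prompt`, `response`, and `label` are assumptions for illustration, not taken from OpenRLHF, so replace them with whatever keys the linked `unpaired_preference_dataset.py` actually expects.

```python
import json
from typing import Iterable, List, Tuple


def find_bad_records(
    lines: Iterable[str],
    required_keys: Tuple[str, ...] = ("prompt", "response", "label"),  # hypothetical key names
) -> List[Tuple[int, str]]:
    """Return (line_number, reason) pairs for records that look unusable."""
    problems = []
    for lineno, raw in enumerate(lines, start=1):
        raw = raw.strip()
        if not raw:
            continue  # ignore blank lines
        try:
            record = json.loads(raw)
        except json.JSONDecodeError:
            problems.append((lineno, "invalid JSON"))
            continue
        if not isinstance(record, dict):
            problems.append((lineno, "not a JSON object"))
            continue
        for key in required_keys:
            value = record.get(key)
            # A label of 0 is valid, so only None and blank strings are flagged.
            if value is None or (isinstance(value, str) and not value.strip()):
                problems.append((lineno, f"missing or empty '{key}'"))
    return problems
```

To use it, open the training file (for example `kto_pair.jsonl`) and pass the file object to `find_bad_records`; an empty result means every record has the assumed fields, although it does not by itself rule out other causes of a NaN loss.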
It should not be a data problem: with the same data and qwen1.5, the loss is normal and the model converges as expected. I will try disabling flash_attn.
After SFT of qwen2 7b on our business dataset, KTO training on top of the SFT model produces a NaN loss. What could be causing this?
deepspeed --num_gpus 8 train_kto.py \
    --save_path xx \
    --save_steps 300 \
    --logging_steps 2 \
    --micro_train_batch_size 2 \
    --pretrain sft_model \
    --max_epochs 5 \
    --max_len 8192 \
    --zero_stage 3 \
    --beta 0.1 \
    --learning_rate 5e-7 \
    --dataset kto_pair.jsonl \
    --dataset_probs 1.0 \
    --gradient_checkpointing \
    --vanilla_loss
Package Version
absl-py 2.1.0 accelerate 0.28.0 aiohttp 3.9.5 aiosignal 1.3.1 annotated-types 0.6.0 apex 0.1+2305.ppu appdirs 1.4.4 astunparse 1.6.3 async-timeout 4.0.3 attrs 23.1.0 audioread 3.0.1 bitsandbytes 0.43.1 cachetools 5.3.3 certifi 2022.9.24 cffi 1.15.0 charset-normalizer 2.0.12 click 8.1.7 cloud-tpu-client 0.10 cnp 0.1 coloredlogs 15.0.1 dataclasses 0.6 datasets 2.19.0 dateparser 1.1.7 DateTime 5.5 decorator 5.1.1 deepspeed 0.14.3 dill 0.3.8 docstring_parser 0.16 einops 0.6.1 elastic-transport 8.4.0 elasticsearch 8.5.3 exceptiongroup 1.1.3 filelock 3.6.0 flash_attn 2.4.2 flatbuffers 23.1.4 frozenlist 1.4.1 fsspec 2024.3.1 future 0.18.2 google-api-core 1.34.1 google-api-python-client 1.8.0 google-auth 2.29.0 google-auth-httplib2 0.2.0 googleapis-common-protos 1.63.0 hjson 3.1.0 httplib2 0.22.0 huggingface-hub 0.23.4 humanfriendly 10.0 idna 3.3 importlib-metadata 4.11.3 iniconfig 2.0.0 intel-openmp 2021.1.1 iopath 0.1.10 Jinja2 3.1.2 joblib 1.1.0 jsonschema 4.22.0 jsonschema-specifications 2023.12.1 llvmlite 0.42.0 loguru 0.7.2 loralib 0.1.2 markdown-it-py 3.0.0 MarkupSafe 2.1.3 mdurl 0.1.2 mkl 2021.1.1 mkl-include 2021.1.1 model-prof 0.0.4 mpmath 1.2.1 msgpack 1.0.8 multidict 6.0.5 multiprocess 0.70.16 mypy-extensions 1.0.0 networkx 3.0 ninja 1.10.0 numba 0.59.0 numpy 1.23.5 nvidia-dali-cuda120 1.20.0 nvidia-ml-py 12.555.43 oauth2client 4.1.3 packaging 21.3 pandas 2.2.2 peft 0.11.1 Pillow 10.1.0 pip 22.0.4 pluggy 1.3.0 pooch 1.6.0 portalocker 2.6.0 protobuf 3.20.3 psutil 5.9.8 py-cpuinfo 9.0.0 pyarrow 16.0.0 pyarrow-hotfix 0.6 pyasn1 0.6.0 pyasn1_modules 0.4.0 pybind11 2.12.0 pycparser 2.21 pydantic 2.7.1 pydantic_core 2.18.2 Pygments 2.18.0 pynvml 11.5.0 pyparsing 3.0.8 pyre-extensions 0.0.29 pytest 7.2.0 python-dateutil 2.8.2 pytz 2022.1 pytz-deprecation-shim 0.1.0.post0 PyYAML 6.0 ray 2.24.0 referencing 0.35.1 regex 2022.10.31 requests 2.27.1 resampy 0.4.2 rich 13.7.1 rpds-py 0.18.1 rsa 4.9 safetensors 0.4.3 scikit-learn 1.1.3 scipy 1.9.3 sentencepiece 0.2.0 
setuptools 57.5.0 shtab 1.7.1 six 1.16.0 SoundFile 0.10.3.post1 sympy 1.11.1 tbb 2021.1.1 tensorboardX 2.6.2.2 threadpoolctl 3.1.0 tifffile 2022.10.10 tiktoken 0.5.2 timm 0.8.22.dev0 tokenizers 0.19.1 tomli 2.0.1 torch 2.2.2.2 torchaudio 2.2.2.2 torchdata 0.6.1+e1feeb2 torchtext 0.17.2.2 torchvision 0.17.2.2 tqdm 4.64.0 transformer_engine 1.4.0+0fbc76a transformers 4.41.2 transformers-stream-generator 0.0.4 triton 2.2.0 trl 0.8.6 typing 3.7.4.3 typing_extensions 4.11.0 typing-inspect 0.9.0 tyro 0.8.4 tzdata 2022.7 tzlocal 4.2 uritemplate 3.0.1 urllib3 1.26.16 wheel 0.40.0 xformers 0.0.22+02e68e1.d20240422 xxhash 3.4.1 yarl 1.9.4 zipp 3.8.0 zope.interface 6.3