hiyouga / LLaMA-Factory

Unified Efficient Fine-Tuning of 100+ LLMs (ACL 2024)
https://arxiv.org/abs/2403.13372
Apache License 2.0

After updating the code, PPO training fails with: Logging with `tensorboard` requires a `logging_dir` to be passed in. #1163

Closed tianwang2021 closed 1 year ago

tianwang2021 commented 1 year ago

Hello, I have already run sft, rm, and dpo successfully, but PPO training fails with the following error:

Traceback (most recent call last):
  File "src/train_bash.py", line 14, in <module>
    main()
  File "src/train_bash.py", line 5, in main
    run_exp()
  File "/home/wt/llm_project/LLaMA-Efficient-Tuning/src/llmtuner/tuner/tune.py", line 30, in run_exp
    run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
  File "/home/wt/llm_project/LLaMA-Efficient-Tuning/src/llmtuner/tuner/ppo/workflow.py", line 67, in run_ppo
    ppo_trainer = CustomPPOTrainer(
  File "/home/wt/llm_project/LLaMA-Efficient-Tuning/src/llmtuner/tuner/ppo/trainer.py", line 39, in __init__
    PPOTrainer.__init__(self, **kwargs)
  File "/usr/local/python3.8/lib/python3.8/site-packages/trl/trainer/ppo_trainer.py", line 191, in __init__
    self.accelerator = Accelerator(
  File "/usr/local/python3.8/lib/python3.8/site-packages/accelerate/accelerator.py", line 369, in __init__
    trackers = filter_trackers(log_with, self.logging_dir)
  File "/usr/local/python3.8/lib/python3.8/site-packages/accelerate/tracking.py", line 725, in filter_trackers
    raise ValueError(
ValueError: Logging with `tensorboard` requires a `logging_dir` to be passed in.

My run script is:

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage ppo \
    --model_name_or_path /home/wt/llm_project/model/chatglm2-6b-20230828 \
    --do_train \
    --dataset alpaca_gpt4_zh \
    --template chatglm2 \
    --finetuning_type lora \
    --lora_target query_key_value \
    --resume_lora_training False \
    --checkpoint_dir ./output/chatglm2-6b-sft-test-1012/checkpoint-20 \
    --reward_model ./output/chatglm2-6b-rm-test-1012/checkpoint-20 \
    --output_dir ./output/chatglm2-6b-ppo-test-1012 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 10 \
    --learning_rate 1e-5 \
    --num_train_epochs 1.0 \
    --plot_loss

My environment is as follows:

Package Version
absl-py 0.13.0
accelerate 0.21.0
aio-pika 6.8.2
aiofiles 23.1.0
aiohttp 3.8.5
aiormq 3.3.1
aiosignal 1.3.1
altair 4.2.2
anyio 3.7.0
appdirs 1.4.4
APScheduler 3.7.0
astunparse 1.6.3
async-generator 1.10
async-timeout 4.0.3
attrs 21.2.0
backports.zoneinfo 0.2.1
bidict 0.22.1
blinker 1.6.2
blis 0.7.10
boto3 1.28.40
botocore 1.31.40
CacheControl 0.12.14
cachetools 5.3.1
catalogue 2.0.9
certifi 2023.5.7
cffi 1.15.1
chardet 3.0.4
charset-normalizer 3.1.0
clang 5.0
click 8.1.3
cloudpickle 1.6.0
colorclass 2.2.2
coloredlogs 15.0.1
colorhash 1.0.4
confection 0.1.1
contourpy 1.0.7
cpm-kernels 1.0.11
cryptography 41.0.3
cycler 0.11.0
cymem 2.0.7
dask 2021.11.2
dataclasses-json 0.5.14
datasets 2.12.0
decorator 5.1.1
deepspeed 0.9.3
dill 0.3.6
dm-tree 0.1.8
dnspython 1.16.0
docker-pycreds 0.4.0
docopt 0.6.2
entrypoints 0.4
exceptiongroup 1.1.1
fastapi 0.95.1
fbmessenger 6.0.0
ffmpy 0.3.0
filelock 3.12.0
fire 0.5.0
flatbuffers 1.12
fonttools 4.40.0
frozenlist 1.3.3
fsspec 2023.5.0
future 0.18.3
gast 0.4.0
gitdb 4.0.10
GitPython 3.1.31
google-auth 2.22.0
google-auth-oauthlib 1.0.0
google-pasta 0.2.0
gradio 3.47.1
gradio_client 0.6.0
greenlet 2.0.2
grpcio 1.57.0
h11 0.14.0
h5py 3.1.0
hjson 3.1.0
httpcore 0.17.2
httptools 0.6.0
httpx 0.24.1
huggingface-hub 0.17.3
humanfriendly 10.0
icetk 0.0.4
idna 3.4
importlib-metadata 6.6.0
importlib-resources 5.12.0
jieba 0.42.1
Jinja2 3.1.2
jmespath 1.0.1
joblib 1.0.1
jsonpickle 2.0.0
jsonschema 3.2.0
kafka-python 2.0.2
keras 2.6.0
Keras-Preprocessing 1.1.2
kiwisolver 1.4.4
langchain 0.0.279
langcodes 3.3.0
langsmith 0.0.33
latex2mathml 3.76.0
linkify-it-py 2.0.2
locket 1.0.0
Markdown 3.4.3
markdown-it-py 2.2.0
MarkupSafe 2.1.2
marshmallow 3.20.1
matplotlib 3.3.4
mattermostwrapper 2.2
mdit-py-plugins 0.3.3
mdtex2html 1.2.0
mdurl 0.1.2
mpmath 1.3.0
msgpack 1.0.5
multidict 5.2.0
multiprocess 0.70.14
murmurhash 1.0.9
mypy-extensions 1.0.0
networkx 2.6.3
ninja 1.11.1.1
nltk 3.8.1
numexpr 2.8.5
numpy 1.20.3
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.18.1
nvidia-nvjitlink-cu12 12.2.140
nvidia-nvtx-cu12 12.1.105
oauthlib 3.2.2
opt-einsum 3.3.0
orjson 3.9.1
packaging 20.9
pamqp 2.3.0
pandas 2.0.2
partd 1.4.0
pathtools 0.1.2
pathy 0.10.2
peft 0.4.0
Pillow 9.5.0
pip 23.2.1
pkgutil_resolve_name 1.3.10
preshed 3.0.8
prompt-toolkit 2.0.10
protobuf 3.20.0
psutil 5.9.5
psycopg2-binary 2.9.7
py-cpuinfo 9.0.0
pyarrow 12.0.0
pyasn1 0.5.0
pyasn1-modules 0.3.0
pycparser 2.21
pydantic 1.10.11
pydeck 0.8.1b0
pydot 1.4.2
pydub 0.25.1
Pygments 2.15.1
PyJWT 2.8.0
pykwalify 1.8.0
pymongo 3.10.1
Pympler 1.0.1
pyparsing 3.0.9
pyrsistent 0.19.3
pyTelegramBotAPI 3.8.3
python-crfsuite 0.9.9
python-dateutil 2.8.2
python-engineio 4.7.0
python-multipart 0.0.6
python-socketio 5.9.0
pytz 2021.3
PyYAML 6.0
questionary 1.10.0
randomname 0.1.5
rasa 3.0.4
rasa-sdk 3.1.0
redis 3.5.3
regex 2023.10.3
requests 2.31.0
requests-oauthlib 1.3.1
requests-toolbelt 1.0.0
responses 0.18.0
rich 13.4.2
rocketchat-API 1.16.0
rouge-chinese 1.0.3
rsa 4.9
ruamel.yaml 0.16.13
ruamel.yaml.clib 0.2.7
s3transfer 0.6.2
safetensors 0.3.3
sanic 21.9.3
Sanic-Cors 1.0.1
sanic-jwt 1.8.0
sanic-plugin-toolkit 1.2.1
sanic-routing 0.7.2
scikit-learn 0.24.2
scipy 1.10.1
semantic-version 2.10.0
sentencepiece 0.1.99
sentry-sdk 1.3.1
setproctitle 1.3.2
setuptools 49.2.1
six 1.15.0
sklearn-crfsuite 0.3.6
slackclient 2.9.4
smart-open 6.3.0
smmap 5.0.0
sniffio 1.3.0
spacy 3.6.1
spacy-legacy 3.0.12
spacy-loggers 1.0.4
spacy-pkuseg 0.0.32
SQLAlchemy 1.4.49
srsly 2.4.7
sse-starlette 1.6.5
starlette 0.26.1
streamlit 1.22.0
sympy 1.12
tabulate 0.9.0
tarsafe 0.0.3
tenacity 8.2.2
tensorboard 2.14.0
tensorboard-data-server 0.7.1
tensorflow 2.6.1
tensorflow-addons 0.14.0
tensorflow-estimator 2.6.0
tensorflow-hub 0.12.0
tensorflow-probability 0.13.0
tensorflow-text 2.6.0
termcolor 1.1.0
terminaltables 3.1.10
thinc 8.1.12
threadpoolctl 3.2.0
tiktoken 0.5.1
tokenizers 0.13.3
toml 0.10.2
toolz 0.12.0
torch 2.1.0
torchaudio 0.13.0+cu117
torchvision 0.14.0+cu116
tornado 6.3.2
tqdm 4.65.0
transformers 4.31.0
triton 2.1.0
trl 0.7.1
twilio 6.50.1
typeguard 2.13.3
typer 0.9.0
typing_extensions 4.7.1
typing-inspect 0.9.0
typing-utils 0.1.0
tzdata 2023.3
tzlocal 2.1
uc-micro-py 1.0.2
ujson 4.3.0
urllib3 1.26.16
uvicorn 0.22.0
uvloop 0.17.0
validators 0.20.0
wandb 0.15.4
wasabi 1.1.2
watchdog 3.0.0
wcwidth 0.2.6
webexteamssdk 1.6.1
websockets 10.0
Werkzeug 2.3.7
wheel 0.41.2
wrapt 1.12.1
xxhash 3.2.0
yarl 1.9.2
zh-core-web-md 3.6.0
zipp 3.15.0

hiyouga commented 1 year ago

Please update the code again and try.
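
For readers who hit the same message: the traceback shows accelerate's `filter_trackers` raising because the `tensorboard` tracker was requested while `logging_dir` was empty. The sketch below is only an illustration of that condition as a pre-flight check in plain Python. The names `resolve_logging_dir` and `DIR_BASED_TRACKERS`, and the `runs` fallback subdirectory, are hypothetical and are not taken from accelerate, trl, or this repository, and this is not a description of what the updated code does.

```python
import os
from typing import Optional, Sequence

# Trackers that write event files to disk and therefore need a directory.
# Assumption for illustration only; this is not accelerate's own list.
DIR_BASED_TRACKERS = {"tensorboard"}


def resolve_logging_dir(
    log_with: Optional[Sequence[str]],
    output_dir: str,
    logging_dir: Optional[str] = None,
) -> Optional[str]:
    """Pick a logging directory before the trainer is constructed.

    Returns the explicit logging_dir when one is given, a fallback under
    output_dir when a directory-based tracker is requested without one,
    and None when no directory is needed.
    """
    requested = set(log_with or [])
    if logging_dir:
        # An explicit directory always wins.
        return logging_dir
    if requested & DIR_BASED_TRACKERS:
        # Hypothetical fallback: keep the logs next to the checkpoints.
        return os.path.join(output_dir, "runs")
    return None


if __name__ == "__main__":
    print(resolve_logging_dir(["tensorboard"], "./output/chatglm2-6b-ppo-test-1012"))
```

A value chosen this way would still have to be handed to whatever builds the `Accelerator`; whether a command-line option reaches it depends on the version of the code in use.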