Hello, I have already gotten SFT, RM and DPO training to run successfully, but PPO training fails with the following error:
Traceback (most recent call last):
  File "src/train_bash.py", line 14, in <module>
    main()
  File "src/train_bash.py", line 5, in main
    run_exp()
  File "/home/wt/llm_project/LLaMA-Efficient-Tuning/src/llmtuner/tuner/tune.py", line 30, in run_exp
    run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
  File "/home/wt/llm_project/LLaMA-Efficient-Tuning/src/llmtuner/tuner/ppo/workflow.py", line 67, in run_ppo
    ppo_trainer = CustomPPOTrainer(
  File "/home/wt/llm_project/LLaMA-Efficient-Tuning/src/llmtuner/tuner/ppo/trainer.py", line 39, in __init__
    PPOTrainer.__init__(self, **kwargs)
  File "/usr/local/python3.8/lib/python3.8/site-packages/trl/trainer/ppo_trainer.py", line 191, in __init__
    self.accelerator = Accelerator(
  File "/usr/local/python3.8/lib/python3.8/site-packages/accelerate/accelerator.py", line 369, in __init__
    trackers = filter_trackers(log_with, self.logging_dir)
  File "/usr/local/python3.8/lib/python3.8/site-packages/accelerate/tracking.py", line 725, in filter_trackers
    raise ValueError(
ValueError: Logging with `tensorboard` requires a `logging_dir` to be passed in.
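The traceback bottoms out in accelerate's tracker validation: trl's `PPOTrainer.__init__` builds an `Accelerator` with a `log_with` value, and because the `tensorboard` tracker writes event files to disk, accelerate refuses to create it when no `logging_dir` is available. The snippet below is a simplified, self-contained re-creation of that check for illustration only; `DIR_REQUIRING_TRACKERS` and `check_trackers` are hypothetical names, and this is not accelerate's actual source.

```python
from typing import List, Optional, Sequence, Union

# Hypothetical, simplified stand-in for accelerate's tracker validation.
# In this sketch only "tensorboard" is marked as needing a directory on disk.
DIR_REQUIRING_TRACKERS = {"tensorboard"}


def check_trackers(
    log_with: Optional[Union[str, Sequence[str]]],
    logging_dir: Optional[str],
) -> List[str]:
    """Return the requested trackers, raising if one needs a logging_dir that is missing."""
    if not log_with:
        return []  # no tracker requested, nothing to validate
    if isinstance(log_with, str):
        log_with = [log_with]  # accept a single tracker name as well as a list
    trackers = []
    for name in log_with:
        if name in DIR_REQUIRING_TRACKERS and logging_dir is None:
            # Same wording as the error in the traceback.
            raise ValueError(f"Logging with `{name}` requires a `logging_dir` to be passed in.")
        trackers.append(name)
    return trackers


if __name__ == "__main__":
    try:
        check_trackers(["tensorboard"], None)
    except ValueError as err:
        print(err)  # reproduces the reported failure mode
    print(check_trackers(["tensorboard"], "./logs"))  # passes once a directory is given
```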
My launch script is:

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage ppo \
    --model_name_or_path /home/wt/llm_project/model/chatglm2-6b-20230828 \
    --do_train \
    --dataset alpaca_gpt4_zh \
    --template chatglm2 \
    --finetuning_type lora \
    --lora_target query_key_value \
    --resume_lora_training False \
    --checkpoint_dir ./output/chatglm2-6b-sft-test-1012/checkpoint-20 \
    --reward_model ./output/chatglm2-6b-rm-test-1012/checkpoint-20 \
    --output_dir ./output/chatglm2-6b-ppo-test-1012 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 10 \
    --learning_rate 1e-5 \
    --num_train_epochs 1.0 \
    --plot_loss
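One possible direction, sketched under assumptions: trl 0.7.x's `PPOConfig` is assumed to accept a `log_with` tracker name and a `project_kwargs` dict that trl forwards to accelerate's `ProjectConfiguration` (which has a `logging_dir` field). Deriving both from the usual `report_to` and `logging_dir` training arguments would ensure a tensorboard tracker always receives a directory. The helper name `build_ppo_logging_kwargs` is hypothetical, it returns a plain dict so it runs without trl installed, and whether this repository version builds its `PPOConfig` this way is not confirmed here.

```python
from typing import Any, Dict, List, Optional, Union


def build_ppo_logging_kwargs(
    report_to: Optional[Union[str, List[str]]],
    logging_dir: Optional[str],
) -> Dict[str, Any]:
    """Hypothetical helper: derive PPOConfig-style logging fields from training args.

    Assumes PPOConfig accepts `log_with` (a tracker name or None) and
    `project_kwargs` (forwarded to accelerate's ProjectConfiguration).
    """
    if isinstance(report_to, str):
        report_to = [report_to]
    # Treat "none" or an empty/missing value as "no tracker at all".
    trackers = [t for t in (report_to or []) if t != "none"]
    log_with = trackers[0] if trackers else None
    kwargs: Dict[str, Any] = {"log_with": log_with, "project_kwargs": {}}
    if logging_dir is not None:
        # Forwarding the directory is what a tensorboard tracker needs.
        kwargs["project_kwargs"]["logging_dir"] = logging_dir
    return kwargs


if __name__ == "__main__":
    print(build_ppo_logging_kwargs(["tensorboard", "wandb"], "./logs"))
    print(build_ppo_logging_kwargs("none", None))
```

For example, `build_ppo_logging_kwargs(["tensorboard", "wandb"], "./logs")` yields `{"log_with": "tensorboard", "project_kwargs": {"logging_dir": "./logs"}}`, which could then be unpacked into `PPOConfig(...)` alongside the other settings. Requesting no tracker at all (a `report_to` of `"none"`) is the other way around the check, though whether the `--report_to` flag reaches `log_with` depends on how this version of the repository wires it.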
Environment (`pip list`):

Package Version
absl-py 0.13.0
accelerate 0.21.0
aio-pika 6.8.2
aiofiles 23.1.0
aiohttp 3.8.5
aiormq 3.3.1
aiosignal 1.3.1
altair 4.2.2
anyio 3.7.0
appdirs 1.4.4
APScheduler 3.7.0
astunparse 1.6.3
async-generator 1.10
async-timeout 4.0.3
attrs 21.2.0
backports.zoneinfo 0.2.1
bidict 0.22.1
blinker 1.6.2
blis 0.7.10
boto3 1.28.40
botocore 1.31.40
CacheControl 0.12.14
cachetools 5.3.1
catalogue 2.0.9
certifi 2023.5.7
cffi 1.15.1
chardet 3.0.4
charset-normalizer 3.1.0
clang 5.0
click 8.1.3
cloudpickle 1.6.0
colorclass 2.2.2
coloredlogs 15.0.1
colorhash 1.0.4
confection 0.1.1
contourpy 1.0.7
cpm-kernels 1.0.11
cryptography 41.0.3
cycler 0.11.0
cymem 2.0.7
dask 2021.11.2
dataclasses-json 0.5.14
datasets 2.12.0
decorator 5.1.1
deepspeed 0.9.3
dill 0.3.6
dm-tree 0.1.8
dnspython 1.16.0
docker-pycreds 0.4.0
docopt 0.6.2
entrypoints 0.4
exceptiongroup 1.1.1
fastapi 0.95.1
fbmessenger 6.0.0
ffmpy 0.3.0
filelock 3.12.0
fire 0.5.0
flatbuffers 1.12
fonttools 4.40.0
frozenlist 1.3.3
fsspec 2023.5.0
future 0.18.3
gast 0.4.0
gitdb 4.0.10
GitPython 3.1.31
google-auth 2.22.0
google-auth-oauthlib 1.0.0
google-pasta 0.2.0
gradio 3.47.1
gradio_client 0.6.0
greenlet 2.0.2
grpcio 1.57.0
h11 0.14.0
h5py 3.1.0
hjson 3.1.0
httpcore 0.17.2
httptools 0.6.0
httpx 0.24.1
huggingface-hub 0.17.3
humanfriendly 10.0
icetk 0.0.4
idna 3.4
importlib-metadata 6.6.0
importlib-resources 5.12.0
jieba 0.42.1
Jinja2 3.1.2
jmespath 1.0.1
joblib 1.0.1
jsonpickle 2.0.0
jsonschema 3.2.0
kafka-python 2.0.2
keras 2.6.0
Keras-Preprocessing 1.1.2
kiwisolver 1.4.4
langchain 0.0.279
langcodes 3.3.0
langsmith 0.0.33
latex2mathml 3.76.0
linkify-it-py 2.0.2
locket 1.0.0
Markdown 3.4.3
markdown-it-py 2.2.0
MarkupSafe 2.1.2
marshmallow 3.20.1
matplotlib 3.3.4
mattermostwrapper 2.2
mdit-py-plugins 0.3.3
mdtex2html 1.2.0
mdurl 0.1.2
mpmath 1.3.0
msgpack 1.0.5
multidict 5.2.0
multiprocess 0.70.14
murmurhash 1.0.9
mypy-extensions 1.0.0
networkx 2.6.3
ninja 1.11.1.1
nltk 3.8.1
numexpr 2.8.5
numpy 1.20.3
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.18.1
nvidia-nvjitlink-cu12 12.2.140
nvidia-nvtx-cu12 12.1.105
oauthlib 3.2.2
opt-einsum 3.3.0
orjson 3.9.1
packaging 20.9
pamqp 2.3.0
pandas 2.0.2
partd 1.4.0
pathtools 0.1.2
pathy 0.10.2
peft 0.4.0
Pillow 9.5.0
pip 23.2.1
pkgutil_resolve_name 1.3.10
preshed 3.0.8
prompt-toolkit 2.0.10
protobuf 3.20.0
psutil 5.9.5
psycopg2-binary 2.9.7
py-cpuinfo 9.0.0
pyarrow 12.0.0
pyasn1 0.5.0
pyasn1-modules 0.3.0
pycparser 2.21
pydantic 1.10.11
pydeck 0.8.1b0
pydot 1.4.2
pydub 0.25.1
Pygments 2.15.1
PyJWT 2.8.0
pykwalify 1.8.0
pymongo 3.10.1
Pympler 1.0.1
pyparsing 3.0.9
pyrsistent 0.19.3
pyTelegramBotAPI 3.8.3
python-crfsuite 0.9.9
python-dateutil 2.8.2
python-engineio 4.7.0
python-multipart 0.0.6
python-socketio 5.9.0
pytz 2021.3
PyYAML 6.0
questionary 1.10.0
randomname 0.1.5
rasa 3.0.4
rasa-sdk 3.1.0
redis 3.5.3
regex 2023.10.3
requests 2.31.0
requests-oauthlib 1.3.1
requests-toolbelt 1.0.0
responses 0.18.0
rich 13.4.2
rocketchat-API 1.16.0
rouge-chinese 1.0.3
rsa 4.9
ruamel.yaml 0.16.13
ruamel.yaml.clib 0.2.7
s3transfer 0.6.2
safetensors 0.3.3
sanic 21.9.3
Sanic-Cors 1.0.1
sanic-jwt 1.8.0
sanic-plugin-toolkit 1.2.1
sanic-routing 0.7.2
scikit-learn 0.24.2
scipy 1.10.1
semantic-version 2.10.0
sentencepiece 0.1.99
sentry-sdk 1.3.1
setproctitle 1.3.2
setuptools 49.2.1
six 1.15.0
sklearn-crfsuite 0.3.6
slackclient 2.9.4
smart-open 6.3.0
smmap 5.0.0
sniffio 1.3.0
spacy 3.6.1
spacy-legacy 3.0.12
spacy-loggers 1.0.4
spacy-pkuseg 0.0.32
SQLAlchemy 1.4.49
srsly 2.4.7
sse-starlette 1.6.5
starlette 0.26.1
streamlit 1.22.0
sympy 1.12
tabulate 0.9.0
tarsafe 0.0.3
tenacity 8.2.2
tensorboard 2.14.0
tensorboard-data-server 0.7.1
tensorflow 2.6.1
tensorflow-addons 0.14.0
tensorflow-estimator 2.6.0
tensorflow-hub 0.12.0
tensorflow-probability 0.13.0
tensorflow-text 2.6.0
termcolor 1.1.0
terminaltables 3.1.10
thinc 8.1.12
threadpoolctl 3.2.0
tiktoken 0.5.1
tokenizers 0.13.3
toml 0.10.2
toolz 0.12.0
torch 2.1.0
torchaudio 0.13.0+cu117
torchvision 0.14.0+cu116
tornado 6.3.2
tqdm 4.65.0
transformers 4.31.0
triton 2.1.0
trl 0.7.1
twilio 6.50.1
typeguard 2.13.3
typer 0.9.0
typing_extensions 4.7.1
typing-inspect 0.9.0
typing-utils 0.1.0
tzdata 2023.3
tzlocal 2.1
uc-micro-py 1.0.2
ujson 4.3.0
urllib3 1.26.16
uvicorn 0.22.0
uvloop 0.17.0
validators 0.20.0
wandb 0.15.4
wasabi 1.1.2
watchdog 3.0.0
wcwidth 0.2.6
webexteamssdk 1.6.1
websockets 10.0
Werkzeug 2.3.7
wheel 0.41.2
wrapt 1.12.1
xxhash 3.2.0
yarl 1.9.2
zh-core-web-md 3.6.0
zipp 3.15.0
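Because the failure sits at the boundary between trl, accelerate and transformers, the versions of just those packages are the most useful part of the environment when comparing against a working setup. The short script below is a sketch using the standard-library `importlib.metadata` module (available from Python 3.8, which matches the paths in the traceback); the `RELEVANT` package selection and the `installed_versions` name are assumptions made for illustration.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Packages on the failing call path in the traceback (selection is an assumption).
RELEVANT = ("trl", "accelerate", "transformers", "peft", "torch", "tensorboard")


def installed_versions(names: Iterable[str] = RELEVANT) -> Dict[str, Optional[str]]:
    """Map each package name to its installed version, or None if it is absent."""
    result: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None  # not installed in this environment
    return result


if __name__ == "__main__":
    for pkg, ver in installed_versions().items():
        print(f"{pkg:<14} {ver or 'not installed'}")
```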