Closed: c3ianwu closed this issue 1 week ago.
OK, I think I have found the issue.
In my code I was entering the `unwrap_model_for_generation(model, accelerator)` context manager multiple times in different places, opening and closing it each time, e.g.:
```python
with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
    # generate A
    ...
# do some stuff
with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
    # generate B
    ...
# do more stuff
```
This problem seems to go away by just calling the context manager once:
```python
with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
    # generate A
    # do some stuff
    # generate B
    ...
# do more stuff
```
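The error itself isn't quoted in this thread, so the following does not explain it. It is only a minimal, self-contained sketch of what changes between the two patterns, assuming the real `unwrap_model_for_generation` does setup work on entry and teardown on exit (in recent trl versions, when DeepSpeed ZeRO-3 is enabled, it gathers the sharded parameters on entry and releases them on exit). `ToyShardedModel`, `toy_unwrap_for_generation`, `two_contexts` and `one_context` are hypothetical stand-ins; no trl, accelerate or DeepSpeed code is involved.

```python
from contextlib import contextmanager


class ToyShardedModel:
    """Hypothetical stand-in for a model whose weights are sharded across ranks."""

    def __init__(self):
        self.gathered = False
        self.gather_count = 0  # how many times the full weights were materialised

    def generate(self, prompt):
        # This toy model refuses to generate unless its weights are gathered.
        if not self.gathered:
            raise RuntimeError("parameters are not gathered")
        return f"completion for {prompt!r}"


@contextmanager
def toy_unwrap_for_generation(model):
    """Hypothetical analogue of unwrap_model_for_generation: gather on entry, release on exit."""
    model.gathered = True
    model.gather_count += 1
    try:
        yield model
    finally:
        model.gathered = False


def two_contexts(model):
    # Original pattern: enter and exit the context manager once per generation.
    with toy_unwrap_for_generation(model) as unwrapped:
        a = unwrapped.generate("A")
    # ... other work runs here while the weights are released ...
    with toy_unwrap_for_generation(model) as unwrapped:
        b = unwrapped.generate("B")
    return a, b


def one_context(model):
    # Workaround from this thread: a single context around both generations.
    with toy_unwrap_for_generation(model) as unwrapped:
        a = unwrapped.generate("A")
        # ... other work now runs while the weights are still gathered ...
        b = unwrapped.generate("B")
    return a, b


m1, m2 = ToyShardedModel(), ToyShardedModel()
two_contexts(m1)
one_context(m2)
print(m1.gather_count, m2.gather_count)  # 2 1
```

The trade-off of the single-context version is that any work done between the two generations now runs while the context is open. If the real context manager gathers ZeRO-3 parameters, the full weights stay materialised during that work, which costs memory.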
> Code based on a forked version of trl.
>
> As this is based on my own modified version of trl I realise you might not be of much help
Indeed. Unfortunately we don't have time for tech support. The most we can do is help with the original codebase.
Great that you've found the solution! Thanks for sharing it!
System Info
Code based on a forked version of trl.
```
Package Version
accelerate 0.33.0
aiohappyeyeballs 2.4.3
aiohttp 3.10.10
aiosignal 1.3.1
annotated-types 0.7.0
anyio 4.6.2.post1
asttokens 2.0.5
async-timeout 4.0.3
attrs 24.2.0
bitsandbytes 0.44.1
certifi 2024.8.30
charset-normalizer 3.4.0
click 8.1.7
cloudpickle 3.1.0
comm 0.2.1
compressed-tensors 0.6.0
contourpy 1.3.0
cycler 0.12.1
datasets 3.0.2
debugpy 1.6.7
decorator 5.1.1
deepspeed 0.15.3
dill 0.3.8
diskcache 5.6.3
distro 1.9.0
docker-pycreds 0.4.0
docstring_parser 0.16
einops 0.8.0
exceptiongroup 1.2.0
executing 0.8.3
fastapi 0.115.4
filelock 3.16.1
fonttools 4.54.1
frozenlist 1.5.0
fsspec 2024.9.0
gguf 0.10.0
gitdb 4.0.11
GitPython 3.1.43
h11 0.14.0
hjson 3.1.0
httpcore 1.0.6
httptools 0.6.4
httpx 0.27.2
huggingface-hub 0.26.1
idna 3.10
importlib_metadata 8.5.0
interegular 0.3.3
ipykernel 6.29.5
ipython 8.27.0
jedi 0.19.1
Jinja2 3.1.4
jiter 0.6.1
jsonschema 4.23.0
jsonschema-specifications 2024.10.1
jupyter_client 8.6.0
jupyter_core 5.7.2
kiwisolver 1.4.7
lark 1.2.2
llvmlite 0.43.0
lm-format-enforcer 0.10.6
markdown-it-py 3.0.0
MarkupSafe 3.0.2
matplotlib 3.9.2
matplotlib-inline 0.1.6
mdurl 0.1.2
mistral_common 1.4.4
mpmath 1.3.0
msgpack 1.1.0
msgspec 0.18.6
multidict 6.1.0
multiprocess 0.70.16
nest-asyncio 1.6.0
networkx 3.4.2
ninja 1.11.1.1
numba 0.60.0
numpy 1.26.4
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-ml-py 12.560.30
nvidia-nccl-cu12 2.20.5
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.1.105
openai 1.52.2
opencv-python-headless 4.10.0.84
outlines 0.0.46
packaging 24.1
pandas 2.2.3
parso 0.8.3
partial-json-parser 0.2.1.1.post4
peft 0.13.2
pexpect 4.8.0
pillow 10.4.0
pip 24.2
platformdirs 3.10.0
prometheus_client 0.21.0
prometheus-fastapi-instrumentator 7.0.0
prompt-toolkit 3.0.43
propcache 0.2.0
protobuf 5.28.3
psutil 5.9.0
ptyprocess 0.7.0
pure-eval 0.2.2
py-cpuinfo 9.0.0
pyairports 2.1.1
pyarrow 17.0.0
pycountry 24.6.1
pydantic 2.9.2
pydantic_core 2.23.4
Pygments 2.15.1
pyparsing 3.2.0
python-dateutil 2.9.0.post0
python-dotenv 1.0.1
pytz 2024.2
PyYAML 6.0.2
pyzmq 25.1.2
ray 2.38.0
referencing 0.35.1
regex 2024.9.11
requests 2.32.3
rich 13.9.3
rpds-py 0.20.0
safetensors 0.4.5
sentencepiece 0.2.0
sentry-sdk 2.17.0
setproctitle 1.3.3
setuptools 75.1.0
shtab 1.7.1
six 1.16.0
smmap 5.0.1
sniffio 1.3.1
stack-data 0.2.0
starlette 0.41.2
sympy 1.13.1
tiktoken 0.7.0
tokenizers 0.20.1
torch 2.4.0
torchvision 0.19.0
tornado 6.4.1
tqdm 4.66.5
traitlets 5.14.3
transformers 4.46.0
triton 3.0.0
typing_extensions 4.11.0
tyro 0.8.14
tzdata 2024.2
urllib3 2.2.3
uvicorn 0.32.0
uvloop 0.21.0
vllm 0.6.3.post1
wandb 0.18.5
watchfiles 0.24.0
wcwidth 0.2.5
websockets 13.1
wheel 0.44.0
xformers 0.0.27.post2
xxhash 3.5.0
yarl 1.16.0
zipp 3.20.2
```
Reproduction
I'm running into this error when using PPOv2.
As this is based on my own modified version of trl, I realise you might not be of much help. However, other users do seem to have run into this error before.
To help me debug, can someone explain why this error occurs and in which situations it arises?
I was able to run PPOv2 until my most recent changes, and I'm trying to figure out which change is causing this. Any help would be greatly appreciated!
Expected behavior
No error.