ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

Dreamer V3 bug only when using GPU #43044

Open apple1113 opened 9 months ago

apple1113 commented 9 months ago

What happened + What you expected to happen

Hi, I am trying out DreamerV3, but I hit a strange bug that only occurs when I set num_gpus_per_learner_worker=1:

cuDNN launch failure : input shape ([1,1,256,1]) [[{{node mlp/layer_normalization_3/FusedBatchNormV3}}]] [Op:__inference_forward_train_24736]

Call arguments received by layer 'dreamer_model' (type DreamerModel):
  • inputs=None
  • observations=tf.Tensor(shape=(1, 64, 4), dtype=float32)
  • actions=tf.Tensor(shape=(1, 64, 2), dtype=float32)
  • is_first=tf.Tensor(shape=(1, 64), dtype=bool)
  • start_is_terminated_BxT=tf.Tensor(shape=(64,), dtype=bool)

2024-02-07 19:25:07,133 ERROR tune.py:1038 -- Trials did not complete: [DreamerV3_CartPole-v1_87b63_00000]

I have done my research and have tensorflow_probability installed. The error does not occur when I use the CPU only. Any ideas why?

Versions / Dependencies

absl-py 2.1.0 accelerate 0.26.1 aiohttp 3.9.3 aiohttp-cors 0.7.0 aiorwlock 1.4.0 aiosignal 1.3.1 annotated-types 0.6.0 anyio 4.2.0 argon2-cffi 23.1.0 argon2-cffi-bindings 21.2.0 arrow 1.3.0 asttokens 2.4.1 astunparse 1.6.3 async-lru 2.0.4 async-timeout 4.0.3 attrs 23.2.0 Babel 2.14.0 beautifulsoup4 4.12.3 bleach 6.1.0 blessed 1.20.0 blinker 1.4 cachetools 5.3.2 certifi 2024.2.2 cffi 1.16.0 charset-normalizer 3.3.2 click 8.1.7 cloudpickle 3.0.0 colorful 0.5.6 comm 0.2.1 command-not-found 0.3 contextlib2 21.6.0 cryptography 3.4.8 dbus-python 1.2.18 debugpy 1.8.0 decorator 5.1.1 defusedxml 0.7.1 distlib 0.3.8 distro 1.7.0 distro-info 1.1+ubuntu0.1 dm-tree 0.1.8 exceptiongroup 1.2.0 executing 2.0.1 Farama-Notifications 0.0.4 fastapi 0.108.0 fastjsonschema 2.19.1 filelock 3.13.1 flatbuffers 23.5.26 fqdn 1.5.1 frozenlist 1.4.1 fsspec 2024.2.0 gast 0.5.4 google-api-core 2.16.2 google-auth 2.27.0 google-auth-oauthlib 1.2.0 google-pasta 0.2.0 googleapis-common-protos 1.62.0 gpustat 1.1.1 grpcio 1.60.1 gymnasium 0.29.1 h11 0.14.0 h5py 3.10.0 httpcore 1.0.2 httplib2 0.20.2 httptools 0.6.1 httpx 0.26.0 huggingface-hub 0.20.3 idna 3.6 importlib-metadata 4.6.4 ipykernel 6.29.0 ipython 8.21.0 ipython_genutils 0.2.0 ipywidgets 8.1.1 isoduration 20.11.0 jedi 0.19.1 jeepney 0.7.1 Jinja2 3.1.3 json5 0.9.14 jsonpointer 2.4 jsonschema 4.21.1 jsonschema-specifications 2023.12.1 jupyter 1.0.0 jupyter_client 8.6.0 jupyter-console 6.6.3 jupyter_core 5.7.1 jupyter-events 0.9.0 jupyter-lsp 2.2.2 jupyter_server 2.12.5 jupyter_server_terminals 0.5.2 jupyterlab 4.1.0 jupyterlab_pygments 0.3.0 jupyterlab_server 2.25.2 jupyterlab-widgets 3.0.9 keras 2.15.0 keyring 23.5.0 launchpadlib 1.10.16 lazr.restfulclient 0.14.4 lazr.uri 1.0.6 libclang 16.0.6 lz4 4.3.3 Markdown 3.5.2 MarkupSafe 2.1.5 matplotlib-inline 0.1.6 mistune 3.0.2 ml-collections 0.1.1 ml-dtypes 0.2.0 more-itertools 8.10.0 mpmath 1.3.0 msgpack 1.0.7 multidict 6.0.5 nbclient 0.9.0 nbconvert 7.14.2 nbformat 5.9.2 nest-asyncio 1.6.0 
netifaces 0.11.0 networkx 3.2.1 notebook 7.0.7 notebook_shim 0.2.3 numpy 1.26.4 nvidia-cublas-cu11 11.11.3.6 nvidia-cublas-cu12 12.2.5.6 nvidia-cuda-cupti-cu11 11.8.87 nvidia-cuda-cupti-cu12 12.2.142 nvidia-cuda-nvcc-cu12 12.2.140 nvidia-cuda-nvrtc-cu11 11.8.89 nvidia-cuda-nvrtc-cu12 12.2.140 nvidia-cuda-runtime-cu11 11.8.89 nvidia-cuda-runtime-cu12 12.2.140 nvidia-cudnn-cu11 8.7.0.84 nvidia-cudnn-cu12 8.9.4.25 nvidia-cufft-cu11 10.9.0.58 nvidia-cufft-cu12 11.0.8.103 nvidia-curand-cu11 10.3.0.86 nvidia-curand-cu12 10.3.3.141 nvidia-cusolver-cu11 11.4.1.48 nvidia-cusolver-cu12 11.5.2.141 nvidia-cusparse-cu11 11.7.5.86 nvidia-cusparse-cu12 12.1.2.141 nvidia-ml-py 12.535.133 nvidia-nccl-cu11 2.19.3 nvidia-nccl-cu12 2.16.5 nvidia-nvjitlink-cu12 12.2.140 nvidia-nvtx-cu11 11.8.86 oauthlib 3.2.0 opencensus 0.11.4 opencensus-context 0.1.3 opencv-python 4.9.0.80 opt-einsum 3.3.0 overrides 7.7.0 packaging 23.2 pandas 2.2.0 pandocfilters 1.5.1 parso 0.8.3 pexpect 4.9.0 pillow 10.2.0 pip 24.0 platformdirs 4.2.0 prometheus-client 0.19.0 prompt-toolkit 3.0.43 protobuf 4.23.4 psutil 5.9.8 ptyprocess 0.7.0 pure-eval 0.2.2 py-spy 0.3.14 pyarrow 15.0.0 pyasn1 0.5.1 pyasn1-modules 0.3.0 pycparser 2.21 pydantic 2.6.1 pydantic_core 2.16.2 pyglet 1.5.28 Pygments 2.17.2 PyGObject 3.42.1 PyJWT 2.3.0 pyparsing 2.4.7 python-apt 2.4.0+ubuntu2 python-dateutil 2.8.2 python-dotenv 1.0.1 python-json-logger 2.0.7 pytz 2024.1 PyYAML 6.0.1 pyzmq 25.1.2 qtconsole 5.5.1 QtPy 2.4.1 ray 2.9.2 referencing 0.33.0 regex 2023.12.25 requests 2.31.0 requests-oauthlib 1.3.1 rfc3339-validator 0.1.4 rfc3986-validator 0.1.1 rpds-py 0.17.1 rsa 4.9 safetensors 0.4.2 scipy 1.12.0 SecretStorage 3.3.1 Send2Trash 1.8.2 setuptools 59.6.0 six 1.16.0 smart-open 6.4.0 sniffio 1.3.0 soupsieve 2.5 stable-retro 0.9.3 /home/ka/github/stable-retro stack-data 0.6.3 starlette 0.32.0.post1 sympy 1.12 systemd-python 234 tensorboard 2.15.1 tensorboard-data-server 0.7.2 tensorboardX 2.6.2.2 tensorflow 2.15.0.post1 
tensorflow-estimator 2.15.0 tensorflow-io-gcs-filesystem 0.35.0 tensorflow-probability 0.23.0 termcolor 2.4.0 terminado 0.18.0 tinycss2 1.2.1 tokenizers 0.15.1 tomli 2.0.1 torch 2.1.2+cu118 torchaudio 2.1.2+cu118 torchvision 0.16.2+cu118 tornado 6.4 tqdm 4.66.1 traitlets 5.14.1 transformers 4.37.2 triton 2.1.0 types-python-dateutil 2.8.19.20240106 typing_extensions 4.9.0 tzdata 2023.4 ubuntu-advantage-tools 8001 ufw 0.36.1 unattended-upgrades 0.1 uri-template 1.3.0 urllib3 2.2.0 uvicorn 0.27.0.post1 uvloop 0.19.0 virtualenv 20.25.0 wadllib 1.3.6 watchfiles 0.21.0 wcwidth 0.2.13 webcolors 1.13 webencodings 0.5.1 websocket-client 1.7.0 websockets 12.0 Werkzeug 3.0.1 wheel 0.37.1 widgetsnbextension 4.0.9 wrapt 1.14.1 yarl 1.9.4 zipp 1.0.0

Reproduction script

from ray import tune
from ray.tune import Tuner
from ray.rllib.algorithms.dreamerv3 import DreamerV3Config
from ray.air import RunConfig, CheckpointConfig

config = (
    DreamerV3Config()
    .environment("CartPole-v1")
    .training(model_size="XS", training_ratio=1024)
    .resources(num_gpus=1, num_gpus_per_learner_worker=1, num_learner_workers=0)
)

tuner = Tuner(
    "DreamerV3",
    run_config=RunConfig(
        stop={"training_iteration": 4000},
        checkpoint_config=CheckpointConfig(checkpoint_at_end=True),
    ),
    param_space=config,
)

result = tuner.fit()

Issue Severity

High: It blocks me from completing my task.

apple1113 commented 9 months ago

I think there is some incompatibility between TensorFlow and PyTorch 2.2 (CUDA 11.8). I did a clean install with only tensorflow and tensorflow-probability, and DreamerV3 worked, but as soon as I also installed PyTorch 2.2, it broke.
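
The version list in the original report contains both nvidia-*-cu11 and nvidia-*-cu12 wheels (for example nvidia-cudnn-cu11 8.7.0.84 next to nvidia-cudnn-cu12 8.9.4.25). Whether that mix is what triggers the cuDNN launch failure is not confirmed in this thread, but it is easy to check what an environment actually contains. The following is a minimal sketch of such a check. The helper names group_nvidia_wheels and installed_packages are made up for illustration, and the script only reports what is installed; it does not diagnose the error itself.

```python
import re
from collections import defaultdict
from importlib import metadata


def group_nvidia_wheels(packages):
    """Group (name, version) pairs for nvidia-*-cuNN wheels by CUDA major version.

    Hypothetical helper: packages without a -cuNN suffix (e.g. nvidia-ml-py)
    are ignored.
    """
    groups = defaultdict(dict)
    for name, version in packages:
        # Normalize so that nvidia_cudnn_cu12 and nvidia-cudnn-cu12 match alike.
        normalized = name.lower().replace("_", "-")
        match = re.fullmatch(r"(nvidia-[a-z0-9-]+)-cu(\d+)", normalized)
        if match:
            groups["cu" + match.group(2)][match.group(1)] = version
    return dict(groups)


def installed_packages():
    """Return (name, version) pairs for distributions in the current environment."""
    pairs = []
    for dist in metadata.distributions():
        name = dist.metadata["Name"]
        if name:  # skip distributions with broken metadata
            pairs.append((name, dist.version))
    return pairs


if __name__ == "__main__":
    groups = group_nvidia_wheels(installed_packages())
    for cuda, libs in sorted(groups.items()):
        print(cuda)
        for lib, version in sorted(libs.items()):
            print(f"  {lib} {version}")
    if len(groups) > 1:
        # Assumption: a mix is worth a closer look, not proof of a conflict.
        print("More than one CUDA major version is installed side by side.")
```

Running this in the failing environment and in a clean TensorFlow-only environment would show whether installing PyTorch changed the set of CUDA libraries, which is one way to test the suspicion above.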

aralcimcim commented 9 months ago

@apple1113, I had a similar problem. Your code works for me on Ubuntu with Python 3.9. These are my library versions:

gymnasium 0.28.1
nvidia-cuda-runtime-cu12 12.1.105
ray 2.9.2
tensorboard 2.15.2
tensorboardX 2.6.2.2
tensorflow 2.15.0.post1
tensorflow-probability 0.23.0
torch 2.2.0