apple1113 opened 9 months ago
I think there is some incompatibility between TensorFlow and PyTorch 2.2 (CUDA 11.8). I did a clean install of TensorFlow and TensorFlow Probability, and DreamerV3 worked; as soon as I also installed PyTorch 2.2, it broke.
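The dependency list in this issue contains both -cu11 and -cu12 variants of several NVIDIA wheels (for example nvidia-cudnn-cu11 and nvidia-cudnn-cu12). This is consistent with a PyTorch cu118 build being installed next to CUDA 12 wheels. The following is a minimal, standard-library-only sketch that lists NVIDIA libraries installed for more than one CUDA major version. The helper names are hypothetical. Finding such a pair is a hint, not proof, that the two frameworks end up loading conflicting cuDNN builds.

```python
import re
from collections import defaultdict
from importlib.metadata import distributions
from typing import Dict, Iterable, List, Set

# Hypothetical helper: matches NVIDIA pip wheels such as "nvidia-cudnn-cu11".
_NVIDIA_WHEEL = re.compile(r"^nvidia-(?P<lib>.+)-cu(?P<major>\d+)$")


def find_mixed_cuda_wheels(names: Iterable[str]) -> Dict[str, List[str]]:
    """Return NVIDIA libraries that are installed for more than one CUDA major.

    Keys are library names such as "cudnn"; values are the CUDA majors found,
    sorted numerically.
    """
    majors: Dict[str, Set[str]] = defaultdict(set)
    for name in names:
        # Normalize "nvidia_cudnn_cu12" / "NVIDIA-cudnn-cu12" to one spelling.
        match = _NVIDIA_WHEEL.match(name.strip().lower().replace("_", "-"))
        if match:
            majors[match.group("lib")].add(match.group("major"))
    return {
        lib: sorted(found, key=int)
        for lib, found in sorted(majors.items())
        if len(found) > 1
    }


def installed_distribution_names() -> List[str]:
    """Names of all distributions visible to the current interpreter."""
    names = (dist.metadata.get("Name") for dist in distributions())
    return [name for name in names if name]


if __name__ == "__main__":
    mixed = find_mixed_cuda_wheels(installed_distribution_names())
    for lib, found in mixed.items():
        print(f"{lib}: wheels installed for CUDA {', '.join(found)}")
```

Run it inside the failing environment. On the package list posted in this issue, it would flag cudnn and cublas, among others, because both the cu11 and the cu12 wheels are present.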
@apple1113, I had a similar problem. Your code works for me on Ubuntu with Python 3.9; these are my library versions:
gymnasium 0.28.1
nvidia-cuda-runtime-cu12 12.1.105
ray 2.9.2
tensorboard 2.15.2
tensorflow 2.15.0.post1
tensorflow-probability 0.23.0
torch 2.2.0
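The working environment above lists torch 2.2.0 alongside nvidia-cuda-runtime-cu12. That suggests a CUDA 12 build of PyTorch, whereas the failing environment listed under Versions / Dependencies uses torch 2.1.2+cu118. The following is a minimal sketch for comparing a local install against this known-good set using the standard library's importlib.metadata. KNOWN_GOOD, installed_version and diff_versions are hypothetical names. The comparison is an exact string match, so a local suffix such as +cu118 counts as a difference.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, Optional, Tuple

# Versions reported as working in the comment above (Ubuntu, Python 3.9).
KNOWN_GOOD: Dict[str, str] = {
    "gymnasium": "0.28.1",
    "nvidia-cuda-runtime-cu12": "12.1.105",
    "ray": "2.9.2",
    "tensorboard": "2.15.2",
    "tensorflow": "2.15.0.post1",
    "tensorflow-probability": "0.23.0",
    "torch": "2.2.0",
}


def installed_version(name: str) -> Optional[str]:
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return version(name)
    except PackageNotFoundError:
        return None


def diff_versions(
    expected: Dict[str, str],
    lookup: Callable[[str], Optional[str]] = installed_version,
) -> Dict[str, Tuple[str, Optional[str]]]:
    """Map each package whose installed version differs to (expected, installed)."""
    mismatches: Dict[str, Tuple[str, Optional[str]]] = {}
    for name, want in expected.items():
        have = lookup(name)
        if have != want:  # exact match: "2.1.2+cu118" != "2.2.0"
            mismatches[name] = (want, have)
    return mismatches


if __name__ == "__main__":
    for name, (want, have) in diff_versions(KNOWN_GOOD).items():
        print(f"{name}: expected {want}, found {have or 'not installed'}")
```

Exact matching is deliberate here. The CUDA variant in the local version suffix is the detail this thread suspects, and loosening the comparison to major.minor would hide it.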
What happened + What you expected to happen
Hi, I am trying out DreamerV3, but I run into a strange bug that only happens when I set num_gpus_per_learner_worker=1:
cuDNN launch failure : input shape ([1,1,256,1]) [[{{node mlp/layer_normalization_3/FusedBatchNormV3}}]] [Op:__inference_forward_train_24736]
Call arguments received by layer 'dreamer_model' (type DreamerModel):
  • inputs=None
  • observations=tf.Tensor(shape=(1, 64, 4), dtype=float32)
  • actions=tf.Tensor(shape=(1, 64, 2), dtype=float32)
  • is_first=tf.Tensor(shape=(1, 64), dtype=bool)
  • start_is_terminated_BxT=tf.Tensor(shape=(64,), dtype=bool)
2024-02-07 19:25:07,133 ERROR -- Trials did not complete: [DreamerV3_CartPole-v1_87b63_00000]
I have done my research and have tensorflow_probability installed. The error does not occur when I run on CPU only. Any ideas why?
Versions / Dependencies
absl-py 2.1.0
accelerate 0.26.1
aiohttp 3.9.3
aiohttp-cors 0.7.0
aiorwlock 1.4.0
aiosignal 1.3.1
annotated-types 0.6.0
anyio 4.2.0
argon2-cffi 23.1.0
argon2-cffi-bindings 21.2.0
arrow 1.3.0
asttokens 2.4.1
astunparse 1.6.3
async-lru 2.0.4
async-timeout 4.0.3
attrs 23.2.0
Babel 2.14.0
beautifulsoup4 4.12.3
bleach 6.1.0
blessed 1.20.0
blinker 1.4
cachetools 5.3.2
certifi 2024.2.2
cffi 1.16.0
charset-normalizer 3.3.2
click 8.1.7
cloudpickle 3.0.0
colorful 0.5.6
comm 0.2.1
command-not-found 0.3
contextlib2 21.6.0
cryptography 3.4.8
dbus-python 1.2.18
debugpy 1.8.0
decorator 5.1.1
defusedxml 0.7.1
distlib 0.3.8
distro 1.7.0
distro-info 1.1+ubuntu0.1
dm-tree 0.1.8
exceptiongroup 1.2.0
executing 2.0.1
Farama-Notifications 0.0.4
fastapi 0.108.0
fastjsonschema 2.19.1
filelock 3.13.1
flatbuffers 23.5.26
fqdn 1.5.1
frozenlist 1.4.1
fsspec 2024.2.0
gast 0.5.4
google-api-core 2.16.2
google-auth 2.27.0
google-auth-oauthlib 1.2.0
google-pasta 0.2.0
googleapis-common-protos 1.62.0
gpustat 1.1.1
grpcio 1.60.1
gymnasium 0.29.1
h11 0.14.0
h5py 3.10.0
httpcore 1.0.2
httplib2 0.20.2
httptools 0.6.1
httpx 0.26.0
huggingface-hub 0.20.3
idna 3.6
importlib-metadata 4.6.4
ipykernel 6.29.0
ipython 8.21.0
ipython_genutils 0.2.0
ipywidgets 8.1.1
isoduration 20.11.0
jedi 0.19.1
jeepney 0.7.1
Jinja2 3.1.3
json5 0.9.14
jsonpointer 2.4
jsonschema 4.21.1
jsonschema-specifications 2023.12.1
jupyter 1.0.0
jupyter_client 8.6.0
jupyter-console 6.6.3
jupyter_core 5.7.1
jupyter-events 0.9.0
jupyter-lsp 2.2.2
jupyter_server 2.12.5
jupyter_server_terminals 0.5.2
jupyterlab 4.1.0
jupyterlab_pygments 0.3.0
jupyterlab_server 2.25.2
jupyterlab-widgets 3.0.9
keras 2.15.0
keyring 23.5.0
launchpadlib 1.10.16
lazr.restfulclient 0.14.4
lazr.uri 1.0.6
libclang 16.0.6
lz4 4.3.3
Markdown 3.5.2
MarkupSafe 2.1.5
matplotlib-inline 0.1.6
mistune 3.0.2
ml-collections 0.1.1
ml-dtypes 0.2.0
more-itertools 8.10.0
mpmath 1.3.0
msgpack 1.0.7
multidict 6.0.5
nbclient 0.9.0
nbconvert 7.14.2
nbformat 5.9.2
nest-asyncio 1.6.0
netifaces 0.11.0
networkx 3.2.1
notebook 7.0.7
notebook_shim 0.2.3
numpy 1.26.4
nvidia-cublas-cu11
nvidia-cublas-cu12
nvidia-cuda-cupti-cu11 11.8.87
nvidia-cuda-cupti-cu12 12.2.142
nvidia-cuda-nvcc-cu12 12.2.140
nvidia-cuda-nvrtc-cu11 11.8.89
nvidia-cuda-nvrtc-cu12 12.2.140
nvidia-cuda-runtime-cu11 11.8.89
nvidia-cuda-runtime-cu12 12.2.140
nvidia-cudnn-cu11
nvidia-cudnn-cu12
nvidia-cufft-cu11
nvidia-cufft-cu12
nvidia-curand-cu11
nvidia-curand-cu12
nvidia-cusolver-cu11
nvidia-cusolver-cu12
nvidia-cusparse-cu11
nvidia-cusparse-cu12
nvidia-ml-py 12.535.133
nvidia-nccl-cu11 2.19.3
nvidia-nccl-cu12 2.16.5
nvidia-nvjitlink-cu12 12.2.140
nvidia-nvtx-cu11 11.8.86
oauthlib 3.2.0
opencensus 0.11.4
opencensus-context 0.1.3
opencv-python
opt-einsum 3.3.0
overrides 7.7.0
packaging 23.2
pandas 2.2.0
pandocfilters 1.5.1
parso 0.8.3
pexpect 4.9.0
pillow 10.2.0
pip 24.0
platformdirs 4.2.0
prometheus-client 0.19.0
prompt-toolkit 3.0.43
protobuf 4.23.4
psutil 5.9.8
ptyprocess 0.7.0
pure-eval 0.2.2
py-spy 0.3.14
pyarrow 15.0.0
pyasn1 0.5.1
pyasn1-modules 0.3.0
pycparser 2.21
pydantic 2.6.1
pydantic_core 2.16.2
pyglet 1.5.28
Pygments 2.17.2
PyGObject 3.42.1
PyJWT 2.3.0
pyparsing 2.4.7
python-apt 2.4.0+ubuntu2
python-dateutil 2.8.2
python-dotenv 1.0.1
python-json-logger 2.0.7
pytz 2024.1
PyYAML 6.0.1
pyzmq 25.1.2
qtconsole 5.5.1
QtPy 2.4.1
ray 2.9.2
referencing 0.33.0
regex 2023.12.25
requests 2.31.0
requests-oauthlib 1.3.1
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
rpds-py 0.17.1
rsa 4.9
safetensors 0.4.2
scipy 1.12.0
SecretStorage 3.3.1
Send2Trash 1.8.2
setuptools 59.6.0
six 1.16.0
smart-open 6.4.0
sniffio 1.3.0
soupsieve 2.5
stable-retro 0.9.3 /home/ka/github/stable-retro
stack-data 0.6.3
starlette 0.32.0.post1
sympy 1.12
systemd-python 234
tensorboard 2.15.1
tensorboard-data-server 0.7.2
tensorboardX
tensorflow 2.15.0.post1
tensorflow-estimator 2.15.0
tensorflow-io-gcs-filesystem 0.35.0
tensorflow-probability 0.23.0
termcolor 2.4.0
terminado 0.18.0
tinycss2 1.2.1
tokenizers 0.15.1
tomli 2.0.1
torch 2.1.2+cu118
torchaudio 2.1.2+cu118
torchvision 0.16.2+cu118
tornado 6.4
tqdm 4.66.1
traitlets 5.14.1
transformers 4.37.2
triton 2.1.0
types-python-dateutil
typing_extensions 4.9.0
tzdata 2023.4
ubuntu-advantage-tools 8001
ufw 0.36.1
unattended-upgrades 0.1
uri-template 1.3.0
urllib3 2.2.0
uvicorn 0.27.0.post1
uvloop 0.19.0
virtualenv 20.25.0
wadllib 1.3.6
watchfiles 0.21.0
wcwidth 0.2.13
webcolors 1.13
webencodings 0.5.1
websocket-client 1.7.0
websockets 12.0
Werkzeug 3.0.1
wheel 0.37.1
widgetsnbextension 4.0.9
wrapt 1.14.1
yarl 1.9.4
zipp 1.0.0
Reproduction script
from ray import tune
from ray.tune import Tuner
from ray.rllib.algorithms.dreamerv3 import DreamerV3Config
from ray.air import RunConfig, CheckpointConfig

config = (
    DreamerV3Config()
    .environment("CartPole-v1")
    .training(model_size="XS", training_ratio=1024)
    # The cuDNN error only appears with num_gpus_per_learner_worker=1.
    .resources(num_gpus=1, num_gpus_per_learner_worker=1, num_learner_workers=0)
)

tuner = Tuner(
    "DreamerV3",
    run_config=RunConfig(
        stop={"training_iteration": 4000},
        checkpoint_config=CheckpointConfig(checkpoint_at_end=True),
    ),
    param_space=config,
)

# Start the experiment.
result = tuner.fit()
Issue Severity
High: It blocks me from completing my task.