pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] examples of brax in brax.py are failing on CUDA enabled machine #2319

Status: Closed (closed by Jendker 2 months ago)

Jendker commented 2 months ago

Describe the bug

The examples in https://github.com/pytorch/rl/blob/main/torchrl/envs/libs/brax.py fail with a device-mismatch error on a machine where CUDA is available.

To Reproduce

On a machine with a CUDA-capable GPU, confirm that PyTorch detects CUDA:

>>> import torch
>>> torch.cuda.is_available()
True

Run any example from brax.py, for example:

from torchrl.envs import BraxEnv
env = BraxEnv("ant")
env.set_seed(0)
td = env.reset()
td["action"] = env.action_spec.rand()
td = env.step(td)
print(td)

This results in the following error:

ValueError: Received incompatible devices for jitted computation. Got ARG_SHARDING with device ids [0] on platform GPU and ARG_SHARDING with device ids [0] on platform CPU
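
The message says that one argument to the jitted function is placed on the GPU and another on the CPU. The following is a minimal diagnostic sketch, not part of the original report, that prints which devices PyTorch and JAX each see. `describe_devices` is a hypothetical helper name introduced here; it relies only on `torch.cuda.is_available()`, `jax.default_backend()` and `jax.devices()`, and reports `None` for a library that is not installed.

```python
import importlib.util


def describe_devices():
    """Return a dict describing the devices visible to PyTorch and JAX.

    Hypothetical helper for illustration; a library that is not
    installed is reported as None instead of raising.
    """
    report = {
        "torch_cuda_available": None,
        "jax_backend": None,
        "jax_devices": None,
    }

    if importlib.util.find_spec("torch") is not None:
        import torch

        report["torch_cuda_available"] = torch.cuda.is_available()

    if importlib.util.find_spec("jax") is not None:
        import jax

        # default_backend() is the platform JAX uses for new arrays
        # unless told otherwise, e.g. "cpu" or "gpu".
        report["jax_backend"] = jax.default_backend()
        report["jax_devices"] = [str(d) for d in jax.devices()]

    return report


if __name__ == "__main__":
    for key, value in describe_devices().items():
        print(f"{key}: {value}")
```

If JAX reports a `gpu` backend while some inputs are created on the CPU, that would be consistent with the error above. Setting the `JAX_PLATFORMS=cpu` environment variable before starting Python is one way to check whether the example runs when JAX is restricted to the CPU; this is a diagnostic step under that assumption, not a fix.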

Expected behavior

The example should run without errors.

System info

Environment details:

import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)

0.5.0+f840a1a 2.0.1 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0] linux
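
The output above covers torchrl, numpy and Python only. Since the error appears to come from a JAX jitted computation, the installed versions of jax, jaxlib and brax may also be relevant. A small sketch that collects them with the standard library's `importlib.metadata` follows; `PACKAGES` and `installed_versions` are illustrative names, not part of torchrl.

```python
from importlib import metadata

# Distribution names chosen for illustration: the packages most likely
# to matter for this report. Adjust the tuple as needed.
PACKAGES = (
    "torch",
    "torchrl",
    "tensordict",
    "jax",
    "jaxlib",
    "brax",
    "mujoco-mjx",
    "numpy",
)


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```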

Installed with the following requirements.txt:

torch==2.3.1
torchvision==0.18.1
mujoco
mujoco-mjx
gymnasium
matplotlib
wandb
matplotlib
joblib
hydra-core
ipython
brax
jax[cuda12]==0.4.28

followed by these commands:

$ cd path/to/root
$ git clone https://github.com/pytorch/tensordict
$ git clone https://github.com/pytorch/rl
$ cd tensordict
$ python setup.py develop
$ cd ../rl
$ python setup.py develop

Checklist