ValueError: Received incompatible devices for jitted computation. Got ARG_SHARDING with device ids [0] on platform GPU and ARG_SHARDING with device ids [0] on platform CPU
Describe the bug
The examples in https://github.com/pytorch/rl/blob/main/torchrl/envs/libs/brax.py are failing
To Reproduce
Steps to reproduce the behavior.
On a machine where CUDA is available, confirm with:
Run any of the examples from brax.py, e.g.:
This results in the ValueError shown at the top of this report.
Expected behavior
No errors should appear.
System info
Describe the characteristic of your environment:
0.5.0+f840a1a 2.0.1 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0] linux
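To check which accelerator each framework picks up in this environment, a minimal diagnostic like the following could accompany the version line above. It is a sketch under stated assumptions: it relies only on torch.cuda.is_available(), jax.default_backend() and jax.devices(), and device_report is a hypothetical helper name, not part of torchrl or brax. Because the error mentions arguments on both the GPU and the CPU platform, comparing what torch and JAX report may help narrow down where the CPU-placed array comes from.

```python
# Hypothetical diagnostic helper (not part of torchrl or brax): reports which
# devices torch and JAX can see in the current Python environment.
def device_report() -> dict:
    report = {"torch_cuda": None, "jax_backend": None, "jax_devices": None}
    try:
        import torch

        # True if torch was built with CUDA support and a GPU is visible.
        report["torch_cuda"] = torch.cuda.is_available()
    except ImportError:
        pass  # torch not installed: leave the entry as None
    try:
        import jax

        # Default backend JAX dispatches jitted computations to (e.g. "gpu" or "cpu").
        report["jax_backend"] = jax.default_backend()
        # Devices belonging to that default backend.
        report["jax_devices"] = [str(d) for d in jax.devices()]
    except ImportError:
        pass  # jax not installed: leave the entries as None
    return report


if __name__ == "__main__":
    for key, value in device_report().items():
        print(f"{key}: {value}")
```

If torch reports CUDA as available while JAX reports a CPU backend (or the reverse), the two libraries are likely placing data on different platforms by default.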
Installed with requirements.txt:
and calls:
Checklist