Closed: misterguick closed this issue 4 months ago
On it!
I successfully ran the following script, a very mild modification of yours, with torchrl on the main branch (same for tensordict):
import brax.envs
from torchrl.envs import BraxWrapper, BraxEnv
import torch
from tensordict.nn import TensorDictModule
from torch import nn
device = torch.device("cuda:0")
# base_env = brax.envs.get_environment("halfcheetah")
# This can be uncommented too
# env = BraxWrapper(base_env, device=device, requires_grad=True)
env = BraxEnv("halfcheetah", device=device, requires_grad=True, batch_size=[10])
env.set_seed(0)
# Reset the environment to obtain the initial state
td = env.reset()
# Define a simple policy network
policy = TensorDictModule(nn.Linear(17, 1).to(device), in_keys=["observation"], out_keys=["action"])
# Perform a rollout using the policy
td = env.rollout(10, policy)
print('rollout', td)
# Backpropagate through the rollout and optimize the policy
td["next", "reward"].mean().backward(retain_graph=False)
print("grad norm", torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0))
Versions
torch 2.4.0.dev20240526+cu121
torchrl 0.4.0+259f20d /home/vmoens/rl
tensordict 0.4.0+ca23038 /home/vmoens/tensordict
brax 0.10.4
jax 0.4.28
jax-cuda12-pjrt 0.4.28
jax-cuda12-plugin 0.4.28
jax-jumpy 1.0.0
jaxlib 0.4.28
jaxopt 0.8.3
Happy to reopen if the issue persists!
Thank you for your reply!
The code you provided tests with a batch size of 10. My problem occurs when batch_size is left at its default or set to [1] (each produces a different error): your code works with [10], but no longer does with the default or with [1].
Thank you again in advance!
I can reproduce it and it looks like a rabbit hole! Bear with me while I try to fix it
Hi! Any update on this?
Describe the bug
With Brax, using the default batch size or a batch size of one breaks.
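For context, the distinction at play can be sketched in plain Python (the helper below is hypothetical, not part of TorchRL): TorchRL's default batch_size is an empty shape, so the environment's tensors carry no leading batch dimension, whereas batch_size=[1] adds a leading dimension of size 1 — two genuinely different tensor layouts, which is why they can fail along different code paths.

```python
# Illustrative sketch only (hypothetical helper, not TorchRL code):
# how batch_size maps to the leading dimensions of an env's tensors.
def obs_shape(batch_size, obs_dim=17):
    """Leading batch dims followed by the observation dim (17 for halfcheetah)."""
    return tuple(batch_size) + (obs_dim,)

print(obs_shape([]))    # default batch_size -> (17,)
print(obs_shape([1]))   # batch_size=[1]     -> (1, 17)
print(obs_shape([10]))  # batch_size=[10]    -> (10, 17)
```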
This issue might be related to
To Reproduce
With default batch size
With batch size = [1]
Expected behavior
The script should run without error; it is essentially the example from the docs with the device set to GPU (https://pytorch.org/rl/stable/reference/generated/torchrl.envs.BraxWrapper.html).
System info
brax 0.10.4
jax 0.4.28
jax-cuda12-pjrt 0.4.28
jax-cuda12-plugin 0.4.28
jaxlib 0.4.28+cuda12.cudnn89
jaxopt 0.8.3
jaxtyping 0.2.29
torch 2.3.0
torch-activation 0.2.1
torch-optimizer 0.3.0
torchaudio 2.1.0.dev20230817+cu121
torchmetrics 1.1.2
torchrl 0.4.0
torchvision 0.16.0.dev20230817+cu121
torchrl 0.4.0, numpy 1.26.4, Python 3.9.18 (packaged by conda-forge, Dec 23 2023, GCC 12.3.0), Linux
Reason and Possible fixes
Workaround: running on CPU, or using batch_size=[n] with n > 1, avoids the error.