pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

BRAX: Possible versioning issue #2183

Closed: misterguick closed this issue 3 months ago.

misterguick commented 3 months ago

Describe the bug

I ran into two problems (which might be related), plus a third one I would like to draw your attention to:

  1. torchrl's brax.py appears to target an older (or different) Brax API: it references brax.envs.env.Env, which does not exist (or no longer exists) in the installed Brax. It was not at all clear to me which version of which package I should be using.
  2. https://github.com/pytorch/rl/issues/2184
  3. https://github.com/google/brax/issues/488

I worked around the first problem by changing the isinstance check in brax.py to: if not isinstance(env, brax.envs.Env):

To Reproduce

Steps to reproduce the behavior.

import brax.envs
from torchrl.envs import BraxWrapper, BraxEnv
import torch

device = torch.device("cuda:0")
base_env = brax.envs.get_environment("halfcheetah")
env = BraxWrapper(base_env, device=device, requires_grad=True)
AttributeError                            Traceback (most recent call last)
Cell In [1], line 7
      5 device = torch.device("cuda:0")
      6 base_env = brax.envs.get_environment("halfcheetah")
----> 7 env = BraxWrapper(base_env, device=device, requires_grad=True)

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/torchrl/envs/common.py:175, in _EnvPostInit.__call__(cls, *args, **kwargs)
    173 auto_reset = kwargs.pop("auto_reset", False)
    174 auto_reset_replace = kwargs.pop("auto_reset_replace", True)
--> 175 instance: EnvBase = super().__call__(*args, **kwargs)
    176 # we create the done spec by adding a done/terminated entry if one is missing
    177 instance._create_done_specs()

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/torchrl/envs/libs/brax.py:203, in BraxWrapper.__init__(self, env, categorical_action_encoding, **kwargs)
    201 self._seed_calls_reset = None
    202 self._categorical_action_encoding = categorical_action_encoding
--> 203 super().__init__(**kwargs)

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/torchrl/envs/common.py:3017, in _EnvWrapper.__init__(self, device, batch_size, allow_done_after_reset, *args, **kwargs)
   3014 self.wrapper_frame_skip = frame_skip
   3016 self._constructor_kwargs = kwargs
-> 3017 self._check_kwargs(kwargs)
   3018 self._env = self._build_env(**kwargs)  # writes the self._env attribute
   3019 self._make_specs(self._env)  # writes the self._env attribute

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/torchrl/envs/libs/brax.py:211, in BraxWrapper._check_kwargs(self, kwargs)
    209     raise TypeError("Could not find environment key 'env' in kwargs.")
    210 env = kwargs["env"]
--> 211 if not isinstance(env, brax.envs.env.Env):
    212     raise TypeError("env is not of type 'brax.envs.env.Env'.")

AttributeError: module 'brax.envs' has no attribute 'env'

Expected behavior

/

Screenshots

/

System info

Name                Version                      Build   Channel
brax                0.10.4                       pypi_0  pypi
jax                 0.4.28                       pypi_0  pypi
jax-cuda12-pjrt     0.4.28                       pypi_0  pypi
jax-cuda12-plugin   0.4.28                       pypi_0  pypi
jaxlib              0.4.28+cuda12.cudnn89        pypi_0  pypi
jaxopt              0.8.3                        pypi_0  pypi
jaxtyping           0.2.29                       pypi_0  pypi
torch               2.3.0                        pypi_0  pypi
torch-activation    0.2.1                        pypi_0  pypi
torch-optimizer     0.3.0                        pypi_0  pypi
torchaudio          2.1.0.dev20230817+cu121      pypi_0  pypi
torchmetrics        1.1.2                        pypi_0  pypi
torchrl             0.4.0                        pypi_0  pypi
torchvision         0.16.0.dev20230817+cu121     pypi_0  pypi

torchrl 0.4.0, numpy 1.26.4, Python 3.9.18 | packaged by conda-forge | (main, Dec 23 2023, 16:33:10) [GCC 12.3.0], platform: linux
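For reports like this one, a small stdlib-only helper can collect the installed versions of the packages involved in one go. This is a sketch, not part of TorchRL or its issue template; the function name report_versions and the package list are illustrative assumptions, and the names are PyPI distribution names (as listed above), not import names.

```python
from importlib import metadata


def report_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in the current environment
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Illustrative list of packages relevant to this issue; extend as needed
    wanted = ["brax", "jax", "jaxlib", "torch", "torchrl", "numpy"]
    for name, version in report_versions(wanted).items():
        print(f"{name:10s} {version or 'not installed'}")
```

Pasting this output alongside the traceback makes it easier to tell which Brax API layout a given TorchRL release was run against.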


Reason and Possible fixes

Change line 211 of torchrl/envs/libs/brax.py (the isinstance check in BraxWrapper._check_kwargs) to:

if not isinstance(env, brax.envs.Env):
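A more version-tolerant variant of this check could use feature detection instead of a single hard-coded path: look for Env at brax.envs first (the location that works with brax 0.10.4 in this report), then fall back to the brax.envs.env.Env location that torchrl 0.4.0 expects. This is a minimal sketch under that assumption, not TorchRL's actual fix; resolve_brax_env_class is a hypothetical helper name, and the demo uses stand-in namespaces so it runs without Brax installed.

```python
from types import SimpleNamespace


def resolve_brax_env_class(envs_module):
    """Return the Env base class from a brax.envs-like module.

    Hypothetical helper: checks the current layout (brax.envs.Env) first,
    then the older layout (brax.envs.env.Env) referenced by torchrl 0.4.0.
    """
    env_cls = getattr(envs_module, "Env", None)
    if env_cls is not None:
        return env_cls
    # Older layout: Env lives in the brax.envs.env submodule
    legacy = getattr(envs_module, "env", None)
    if legacy is not None and getattr(legacy, "Env", None) is not None:
        return legacy.Env
    raise AttributeError("Could not find Env in brax.envs or brax.envs.env")


if __name__ == "__main__":
    # Stand-ins for the two module layouts, so the demo runs without Brax
    class NewEnv:
        pass

    class OldEnv:
        pass

    new_layout = SimpleNamespace(Env=NewEnv)
    old_layout = SimpleNamespace(env=SimpleNamespace(Env=OldEnv))
    print(resolve_brax_env_class(new_layout) is NewEnv)  # True
    print(resolve_brax_env_class(old_layout) is OldEnv)  # True

    # With a real install, the check in _check_kwargs would become roughly:
    #   import brax.envs
    #   if not isinstance(env, resolve_brax_env_class(brax.envs)): raise TypeError(...)
```

Note that the TypeError message on line 212 also names brax.envs.env.Env, so it would need the same update.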


vmoens commented 3 months ago

We should version things better.

Indeed, it seems Brax removed Env from brax.envs.env; it now lives in brax.envs. Let me fix this.