pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License
2.33k stars 307 forks source link

[BUG] BRAX: batch size issue #2184

Closed misterguick closed 4 months ago

misterguick commented 5 months ago

Describe the bug

With Brax, using the default batch size or a batch size of [1] raises an error.

  1. Using the default batch size gives one error
  2. Using batch size = [1] gives another error

This issue might be related to

  1. https://github.com/pytorch/rl/issues/2183
  2. https://github.com/google/brax/issues/488

To Reproduce

With default batch size

import brax.envs
from torchrl.envs import BraxWrapper, BraxEnv
import torch
from tensordict.nn import TensorDictModule
from tensordict.tensordict import TensorDict
from torch import nn

device = torch.device("cuda:0")
base_env = brax.envs.get_environment("halfcheetah")
env = BraxWrapper(base_env, device=device, requires_grad=True)

env.set_seed(0)

# Reset the environment to obtain the initial state
td = env.reset()

# Define a simple policy network
policy = TensorDictModule(nn.Linear(17, 1).to(device), in_keys=["observation"], out_keys=["action"])

# Perform a rollout using the policy
td = env.rollout(10, policy)

# Backpropagate through the rollout and optimize the policy
td["next", "reward"].mean().backward(retain_graph=False)
  link = jax.tree_map(lambda x: x[1:].copy(), link)
/users/psi/bdebes/miniconda3/envs/remote/lib/python3.9/site-packages/brax/scan.py:50: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
  return jax.tree_map(take, obj)
/users/psi/bdebes/miniconda3/envs/remote/lib/python3.9/site-packages/brax/generalized/dynamics.py:83: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
  motion = jax.tree_map(lambda *x: jp.column_stack(x), *jds).reshape((-1, 3))
/users/psi/bdebes/miniconda3/envs/remote/lib/python3.9/site-packages/brax/generalized/constraint.py:175: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
  return jax.tree_map(jp.concatenate, jax.vmap(row_fn)(c))
/users/psi/bdebes/miniconda3/envs/remote/lib/python3.9/site-packages/brax/scan.py:50: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
  return jax.tree_map(take, obj)
/users/psi/bdebes/miniconda3/envs/remote/lib/python3.9/site-packages/brax/scan.py:50: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
  return jax.tree_map(take, obj)
/users/psi/bdebes/miniconda3/envs/remote/lib/python3.9/site-packages/brax/generalized/dynamics.py:83: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
  motion = jax.tree_map(lambda *x: jp.column_stack(x), *jds).reshape((-1, 3))
/users/psi/bdebes/miniconda3/envs/remote/lib/python3.9/site-packages/brax/generalized/constraint.py:175: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
  return jax.tree_map(jp.concatenate, jax.vmap(row_fn)(c))
/users/psi/bdebes/miniconda3/envs/remote/lib/python3.9/site-packages/brax/scan.py:50: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
  return jax.tree_map(take, obj)
E0531 01:17:51.721701 2316756 pjrt_stream_executor_client.cc:2826] Execution of replica 0 failed: INTERNAL: Address of buffer 1 must be a multiple of 10, but was 0x14a62da00824
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In [4], line 24
     21 td = env.rollout(10, policy)
     23 # Backpropagate through the rollout and optimize the policy
---> 24 td["next", "reward"].mean().backward(retain_graph=False)

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/torch/_tensor.py:525, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    515 if has_torch_function_unary(self):
    516     return handle_torch_function(
    517         Tensor.backward,
    518         (self,),
   (...)
    523         inputs=inputs,
    524     )
--> 525 torch.autograd.backward(
    526     self, gradient, retain_graph, create_graph, inputs=inputs
    527 )

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/torch/autograd/__init__.py:267, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    262     retain_graph = create_graph
    264 # The reason we repeat the same comment below is that
    265 # some Python versions print out the first line of a multi-line function
    266 # calls in the traceback and some print out the last line
--> 267 _engine_run_backward(
    268     tensors,
    269     grad_tensors_,
    270     retain_graph,
    271     create_graph,
    272     inputs,
    273     allow_unreachable=True,
    274     accumulate_grad=True,
    275 )

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/torch/autograd/graph.py:744, in _engine_run_backward(t_outputs, *args, **kwargs)
    742     unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    743 try:
--> 744     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    745         t_outputs, *args, **kwargs
    746     )  # Calls into the C++ engine to run the backward pass
    747 finally:
    748     if attach_logging_hooks:

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/torch/autograd/function.py:301, in BackwardCFunction.apply(self, *args)
    295     raise RuntimeError(
    296         "Implementing both 'backward' and 'vjp' for a custom "
    297         "Function is not allowed. You should only implement one "
    298         "of them."
    299     )
    300 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 301 return user_fn(self, *args)

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/torchrl/envs/libs/brax.py:648, in _BraxEnvStep.backward(ctx, _, grad_next_obs, grad_next_reward, *grad_next_qp_values)
    635 grad_next_state_td = TensorDict(
    636     source={
    637         "pipeline_state": pipeline_state,
   (...)
    645     batch_size=ctx.env.batch_size,
    646 )
    647 # convert tensors to ndarrays
--> 648 grad_next_state_obj = _tensordict_to_object(
    649     grad_next_state_td, ctx.env._state_example
    650 )
    652 # flatten batch size
    653 grad_next_state_flat = _tree_flatten(grad_next_state_obj, ctx.env.batch_size)

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/torchrl/envs/libs/jax_utils.py:129, in _tensordict_to_object(tensordict, object_example)
    127             value = value.to(torch.uint8)
    128         value = jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(value.contiguous()))
--> 129         t[name] = value.reshape(example.shape).view(example.dtype)
    130 return type(object_example)(**t)

    [... skipping hidden 1 frame]

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/jax/_src/lax/lax.py:892, in reshape(operand, new_sizes, dimensions)
    889 else:
    890   dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
--> 892   return reshape_p.bind(
    893     operand, *dyn_shape, new_sizes=tuple(static_new_sizes),
    894     dimensions=None if dims is None or same_dims else dims)

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/jax/_src/core.py:387, in Primitive.bind(self, *args, **params)
    384 def bind(self, *args, **params):
    385   assert (not config.enable_checks.value or
    386           all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 387   return self.bind_with_trace(find_top_trace(args), args, params)

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/jax/_src/core.py:391, in Primitive.bind_with_trace(self, trace, args, params)
    389 def bind_with_trace(self, trace, args, params):
    390   with pop_level(trace.level):
--> 391     out = trace.process_primitive(self, map(trace.full_raise, args), params)
    392   return map(full_lower, out) if self.multiple_results else full_lower(out)

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/jax/_src/core.py:879, in EvalTrace.process_primitive(self, primitive, tracers, params)
    877   return call_impl_with_key_reuse_checks(primitive, primitive.impl, *tracers, **params)
    878 else:
--> 879   return primitive.impl(*tracers, **params)

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/jax/_src/dispatch.py:86, in apply_primitive(prim, *args, **params)
     84 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
     85 try:
---> 86   outs = fun(*args)
     87 finally:
     88   lib.jax_jit.swap_thread_local_state_disable_jit(prev)

ValueError: INTERNAL: Address of buffer 1 must be a multiple of 10, but was 0x14a62da00824

With batch size = [1]

---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In [1], line 21
     18 policy = TensorDictModule(nn.Linear(17, 1).to(device), in_keys=["observation"], out_keys=["action"])
     20 # Perform a rollout using the policy
---> 21 td = env.rollout(10, policy)
     23 # Backpropagate through the rollout and optimize the policy
     24 td["next", "reward"].mean().backward(retain_graph=False)

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/torchrl/envs/common.py:2562, in EnvBase.rollout(self, max_steps, policy, callback, auto_reset, auto_cast_to_device, break_when_any_done, return_contiguous, tensordict, set_truncated, out)
   2552 kwargs = {
   2553     "tensordict": tensordict,
   2554     "auto_cast_to_device": auto_cast_to_device,
   (...)
   2559     "callback": callback,
   2560 }
   2561 if break_when_any_done:
-> 2562     tensordicts = self._rollout_stop_early(**kwargs)
   2563 else:
   2564     tensordicts = self._rollout_nonstop(**kwargs)

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/torchrl/envs/common.py:2638, in EnvBase._rollout_stop_early(self, tensordict, auto_cast_to_device, max_steps, policy, policy_device, env_device, callback)
   2636     else:
   2637         tensordict.clear_device_()
-> 2638 tensordict = self.step(tensordict)
   2639 td_append = tensordict.copy()
   2640 tensordicts.append(td_append)

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/torchrl/envs/common.py:1461, in EnvBase.step(self, tensordict)
   1458 self._assert_tensordict_shape(tensordict)
   1459 next_preset = tensordict.get("next", None)
-> 1461 next_tensordict = self._step(tensordict)
   1462 next_tensordict = self._step_proc_data(next_tensordict)
   1463 if next_preset is not None:
   1464     # tensordict could already have a "next" key
   1465     # this could be done more efficiently by not excluding but just passing
   1466     # the necessary keys

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/torchrl/envs/libs/brax.py:410, in BraxWrapper._step(self, tensordict)
    404 def _step(
    405     self,
    406     tensordict: TensorDictBase,
    407 ) -> TensorDictBase:
    409     if self.requires_grad:
--> 410         out = self._step_with_grad(tensordict)
    411     else:
    412         out = self._step_without_grad(tensordict)

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/torchrl/envs/libs/brax.py:373, in BraxWrapper._step_with_grad(self, tensordict)
    370 qp_keys, qp_values = zip(*state.get("pipeline_state").items())
    372 # call env step with autograd function
--> 373 next_state_nograd, next_obs, next_reward, *next_qp_values = _BraxEnvStep.apply(
    374     self, state, action, *qp_values
    375 )
    377 # extract done values: we assume a shape identical to reward
    378 next_done = next_state_nograd.get("done").view(*self.reward_spec.shape)

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/torch/autograd/function.py:598, in Function.apply(cls, *args, **kwargs)
    595 if not torch._C._are_functorch_transforms_active():
    596     # See NOTE: [functorch vjp and autograd interaction]
    597     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 598     return super().apply(*args, **kwargs)  # type: ignore[misc]
    600 if not is_setup_ctx_defined:
    601     raise RuntimeError(
    602         "In order to use an autograd.Function with functorch transforms "
    603         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    604         "staticmethod. For more details, please see "
    605         "https://pytorch.org/docs/master/notes/extending.func.html"
    606     )

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/torchrl/envs/libs/brax.py:582, in _BraxEnvStep.forward(ctx, env, state_td, action_tensor, *qp_values)
    579 import jax
    581 # convert tensors to ndarrays
--> 582 state_obj = _tensordict_to_object(state_td, env._state_example)
    583 action_nd = _tensor_to_ndarray(action_tensor)
    585 # flatten batch size

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/torchrl/envs/libs/jax_utils.py:119, in _tensordict_to_object(tensordict, object_example)
    117 value = tensordict.get(name, None)
    118 if isinstance(value, TensorDictBase):
--> 119     t[name] = _tensordict_to_object(value, example)
    120 elif value is None:
    121     if isinstance(example, dict):

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/torchrl/envs/libs/jax_utils.py:119, in _tensordict_to_object(tensordict, object_example)
    117 value = tensordict.get(name, None)
    118 if isinstance(value, TensorDictBase):
--> 119     t[name] = _tensordict_to_object(value, example)
    120 elif value is None:
    121     if isinstance(example, dict):

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/torchrl/envs/libs/jax_utils.py:128, in _tensordict_to_object(tensordict, object_example)
    126         if value.dtype is torch.bool:
    127             value = value.to(torch.uint8)
--> 128         value = jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(value.contiguous()))
    129         t[name] = value.reshape(example.shape).view(example.dtype)
    130 return type(object_example)(**t)

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/jax/_src/dlpack.py:278, in from_dlpack(external_array, device, copy)
    275   return _from_dlpack(external_array, device, copy)
    277 # Legacy path
--> 278 return _legacy_from_dlpack(external_array, device, copy)

File ~/miniconda3/envs/remote/lib/python3.9/site-packages/jax/_src/dlpack.py:195, in _legacy_from_dlpack(dlpack, device, copy)
    192     except RuntimeError:
    193       pass
--> 195 _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
    196     dlpack, cpu_backend, gpu_backend)) # type: ignore
    197 dlpack_device, = _arr.devices()
    198 return _place_array(_arr, device, dlpack_device, copy)

XlaRuntimeError: UNIMPLEMENTED: from_dlpack got array with non-default layout with minor-to-major dimensions (2,0,1), expected (2,1,0)

Expected behavior

The script should run without errors: it is basically the example from the docs, with the device set to a GPU (https://pytorch.org/rl/stable/reference/generated/torchrl.envs.BraxWrapper.html)

Screenshots

/

System info

brax 0.10.4 pypi_0 pypi jax 0.4.28 pypi_0 pypi jax-cuda12-pjrt 0.4.28 pypi_0 pypi jax-cuda12-plugin 0.4.28 pypi_0 pypi jaxlib 0.4.28+cuda12.cudnn89 pypi_0 pypi jaxopt 0.8.3 pypi_0 pypi jaxtyping 0.2.29 pypi_0 pypi torch 2.3.0 pypi_0 pypi torch-activation 0.2.1 pypi_0 pypi torch-optimizer 0.3.0 pypi_0 pypi torchaudio 2.1.0.dev20230817+cu121 pypi_0 pypi torchmetrics 1.1.2 pypi_0 pypi torchrl 0.4.0 pypi_0 pypi torchvision 0.16.0.dev20230817+cu121 pypi_0 pypi

0.4.0 1.26.4 3.9.18 | packaged by conda-forge | (main, Dec 23 2023, 16:33:10) [GCC 12.3.0] linux

Additional context

/

Reason and Possible fixes

Workaround: running on CPU, or using a batch size [n] with n > 1, avoids the errors.

Checklist

vmoens commented 5 months ago

On it!

vmoens commented 5 months ago

I successfully ran this script, which is a very mild modification of yours, with torchrl on the main branch (same for tensordict)

import brax.envs
from torchrl.envs import BraxWrapper, BraxEnv
import torch
from tensordict.nn import TensorDictModule
from tensordict.tensordict import TensorDict
from torch import nn

device = torch.device("cuda:0")
# base_env = brax.envs.get_environment("halfcheetah")

# This can be uncommented too
# env = BraxWrapper(base_env, device=device, requires_grad=True)
env = BraxEnv("halfcheetah", device=device, requires_grad=True, batch_size=[10])

env.set_seed(0)

# Reset the environment to obtain the initial state
td = env.reset()

# Define a simple policy network
policy = TensorDictModule(nn.Linear(17, 1).to(device), in_keys=["observation"], out_keys=["action"])

# Perform a rollout using the policy
td = env.rollout(10, policy)
print('rollout', td)

# Backpropagate through the rollout and optimize the policy
td["next", "reward"].mean().backward(retain_graph=False)

print("grad norm", torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0))

Versions

torch                      2.4.0.dev20240526+cu121
torchrl                    0.4.0+259f20d            /home/vmoens/rl
tensordict                 0.4.0+ca23038            /home/vmoens/tensordict
brax                       0.10.4
jax                        0.4.28
jax-cuda12-pjrt            0.4.28
jax-cuda12-plugin          0.4.28
jax-jumpy                  1.0.0
jaxlib                     0.4.28
jaxopt                     0.8.3

Happy to reopen if the issue persists!

misterguick commented 5 months ago

Thank you for your reply !

The code you provided uses a batch size of 10. My problem occurs when the batch size is left at its default or set to 1 (each case raises a different error). Your code works as written, but it fails once batch_size is changed to the default or to 1.

Thank you again in advance !

vmoens commented 5 months ago

I can reproduce it and it looks like a rabbit hole! Bear with me while I try to fix it

misterguick commented 4 months ago

Hi ! Any update on this ?