google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Memory error on some environments when using ensemble methods #201

Open cvoelcker opened 2 years ago

cvoelcker commented 2 years ago

Hi, I have an extremely strange bug that is slightly hard to reproduce :) I am trying to integrate brax with SAC and want to use an ensemble version of the algorithm. When I use the ensemble, the brax environments (tested so far on halfcheetah, walker2d, grasp and hopper) raise the following error. The original SAC version doe not! In addition, the ant environment for some reason works as well.

I am wrapping the torch code in both cases with torch.no_grad() and detaching the tensor. The shapes going into brax are the same in either case. I think this is some strange memory allocation problem that happens when torch (internally?) duplicates the tensor to compute the ensembled predictions, but that is pure conjecture.
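Roughly, the call into the environment looks like this in both cases (a simplified sketch of my setup, with agent as a placeholder name, not the actual code):

with torch.no_grad():
    action = agent.select_action(state)  # ensemble (or plain SAC) policy forward pass
action = action.detach()                 # make sure no autograd history is attached
next_state, reward, done, _ = env.step(action)  # Brax env behind the to_torch wrapper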

Would be very happy about any hints.

Traceback (most recent call last):
  File "sac/main.py", line 80, in main
    next_state, reward, done, _ = env.step(action)  # Step
  File "sac/env/wrappers.py", line 104, in step
    return self.env.step(action)
  File "brax/envs/to_torch.py", line 59, in step
    obs, reward, done, info = super().step(action)
  File "lib/python3.10/site-packages/gym/core.py", line 289, in step
    return self.env.step(action)
  File "python3.10/site-packages/gym/wrappers/order_enforcing.py", line 11, in step
    observation, reward, done, info = self.env.step(action)
  File "python3.10/site-packages/brax/envs/wrappers.py", line 283, in step
    self._state, obs, reward, done, info = self._step(self._state, action)
ValueError: INTERNAL: Address of buffer 11 must be a multiple of 10, but was 0x7f588c3d0848
lebrice commented 2 years ago

Hey there @cvoelcker! (I'm not a Brax maintainer, btw; I just wrote the jax2torch conversion functions.)

I think this is some strange memory allocation problem that happens when torch (internally?) duplicates the Tensor to compute the ensembled predictions, but that is pure conjecture.

The Jax->Torch conversion is done via their respective dlpack interfaces, which essentially makes it possible to read the same tensor from either Jax or Torch, in-place, on the GPU.

The conversions happen inside the to_torch wrapper (see brax/envs/to_torch.py in the traceback above).
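As a minimal sketch (the standard dlpack round trip, not necessarily the exact Brax code), the two directions look roughly like this:

import jax
import jax.dlpack
import torch
import torch.utils.dlpack

def jax_to_torch(x):
    # Zero-copy: the torch tensor aliases the same device buffer as the jax array.
    return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))

def torch_to_jax(x):
    # Zero-copy in the other direction: whatever pointer/offset the torch tensor
    # has is exactly what XLA ends up reading.
    return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x))

Because no copy is made, the memory layout of the torch tensor (contiguity, offset into a larger allocation, pinning) is passed straight through, which is presumably how an oddly laid-out view can trigger the "Address of buffer ... must be a multiple of ..." error on the XLA side.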

This seems a bit tough to debug without any code though. Would you mind sharing some of the code that does the "ensembled predictions"? I've never personally seen this kind of error before. My intuition tells me there is probably some kind of in-place modification of the tensor happening on the torch side, and that this might be causing this weird issue. Either that, or the memory for the tensor is pinned, and there's some funny allocation business going on. I'm no expert though, so more code/context would definitely help!

cvoelcker commented 2 years ago

Hi, sorry for the late reply. I tried to isolate the error without having to share my whole codebase, but to no avail; it is very difficult to reproduce in a simplified codebase. The ensembled predictions are made using the EnsembleLinearLayer from mbrl-lib. The action_selection code seems to be the culprit; it is in bayesian_daml/sac/sunrise_sac (see attachment).

def select_action(self, state, evaluate=False, num_candidates=5, idx=None):
    # Sample several candidate actions from the ensemble policy.
    candidate_actions = []
    for i in range(num_candidates):
        action, _, _ = self.policy.sample(state)
        candidate_actions.append(action)
    candidate_actions = torch.cat(candidate_actions, dim=0)
    # Score each candidate with the (ensembled) target critics.
    q1, q2 = self.critic_target(
        state.expand(len(candidate_actions), -1),
        candidate_actions.squeeze(),
    )
    q = 0.5 * (q1 + q2)
    # UCB-style selection: mean plus scaled std across the ensemble.
    ucb_var = torch.std(q, dim=0)
    q = torch.mean(q, dim=0)
    max_action_idx = torch.argmax(q + 5.0 * ucb_var)
    action = candidate_actions[max_action_idx].detach()
    return action

Wrapping action in torch.clone fixes the error, interestingly enough, so it might be an issue with the action, not the state.
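A minimal sketch of that workaround (assuming the select_action code above): clone the selected action so that Brax receives a freshly allocated, contiguous buffer instead of a view into candidate_actions.

    action = candidate_actions[max_action_idx].detach().clone()
    return action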

Both policy and critic_target are nn.Modules using EnsembleLinearLayer.
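For reference, an ensemble linear layer of this kind typically keeps one weight matrix per ensemble member and applies them with a batched matmul; a rough, hypothetical sketch (not the actual mbrl-lib implementation) would be:

import torch
from torch import nn

class SimpleEnsembleLinear(nn.Module):
    def __init__(self, num_members, in_features, out_features):
        super().__init__()
        # One (in_features, out_features) weight matrix per ensemble member.
        self.weight = nn.Parameter(0.01 * torch.randn(num_members, in_features, out_features))
        self.bias = nn.Parameter(torch.zeros(num_members, 1, out_features))

    def forward(self, x):
        # x: (num_members, batch, in_features) -> (num_members, batch, out_features)
        return x.matmul(self.weight) + self.bias

This is presumably why q1 and q2 carry a leading ensemble dimension, and why select_action reduces over dim=0 with torch.std and torch.mean.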

To reproduce, run python bayesian_daml/sac/main.py env_name=brax-hopper algorithm=sunrise_sac start_steps=0

sunrise.zip

lebrice commented 2 years ago

Wrapping action in torch.clone fixes the error, interestingly enough,

Then it sounds like what I was suspecting might be correct.