Open cvoelcker opened 2 years ago
Hey there @cvoelcker! (I'm not a Brax maintainer, btw; I just wrote the jax2torch conversion functions.)
> I think this is some strange memory allocation problem that happens when torch (internally?) duplicates the tensor to compute the ensembled predictions, but that is pure conjecture.
The Jax->Torch conversion is done via their respective dlpack interfaces, which essentially make it possible to read the same tensor from either Jax or Torch, in place, on the GPU.
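To make the mechanism concrete, here is a minimal sketch of what a zero-copy Jax<->Torch bridge via dlpack looks like (an illustrative assumption about the shape of the helpers, not Brax's actual code):

```python
# Minimal sketch of zero-copy Jax<->Torch conversion via dlpack.
# The helper names jax_to_torch / torch_to_jax are hypothetical.
import jax
import jax.numpy as jnp
import torch

def jax_to_torch(x):
    # torch.from_dlpack accepts any object implementing __dlpack__,
    # including Jax arrays, and wraps the buffer without copying.
    return torch.from_dlpack(x)

def torch_to_jax(t):
    # Likewise, jax.dlpack.from_dlpack wraps a torch Tensor in place.
    return jax.dlpack.from_dlpack(t)

x = jnp.arange(4.0)
t = jax_to_torch(x)
print(t)  # tensor([0., 1., 2., 3.])
```

Because both sides alias the same device buffer, an in-place write (or a reallocation) on one side can silently corrupt what the other side reads, which is the failure mode being conjectured here.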
The conversions are done:
This seems a bit tough to debug without any code though. Would you mind sharing some of the code that does the "ensembled predictions"? I've never personally seen this kind of error before. My intuition tells me there is probably some kind of in-place modification of the tensor happening on the torch side, and that this might be causing this weird issue. Either that, or the memory for the tensor is pinned, and there's some funny allocation business going on. I'm no expert though, so more code/context would definitely help!
Hi, sorry for the late reply. I tried to isolate the error without having to share my whole codebase, but to no avail; it is very difficult to reproduce in a simplified codebase. The ensembled predictions are made using the EnsembleLinearLayer from mbrl-lib. The action-selection code seems to be the culprit; it is in bayesian_daml/sac/sunrise_sac (see attachment).
```python
def select_action(self, state, evaluate=False, num_candidates=5, idx=None):
    candidate_actions = []
    for i in range(num_candidates):
        action, _, _ = self.policy.sample(state)
        candidate_actions.append(action)
    candidate_actions = torch.cat(candidate_actions, dim=0)
    q1, q2 = self.critic_target(
        state.expand(len(candidate_actions), -1),
        candidate_actions.squeeze(),
    )
    q = 0.5 * (q1 + q2)
    ucb_var = torch.std(q, dim=0)
    q = torch.mean(q, dim=0)
    max_action_idx = torch.argmax(q + 5.0 * ucb_var)
    action = candidate_actions[max_action_idx].detach()
    return action
```
Wrapping action in torch.clone fixes the error, interestingly enough, so the issue might lie with the action rather than the state.
Both `policy` and `critic_target` are `nn.Module`s using `EnsembleLinearLayer`.
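For context, an ensembled linear layer applies one weight matrix per ensemble member as a single batched matmul. The following is a minimal sketch in the spirit of mbrl-lib's EnsembleLinearLayer (an illustrative toy, not the actual mbrl-lib implementation):

```python
# Toy ensembled linear layer: one (in_size, out_size) weight per member,
# applied to all members at once via a batched matmul.
import torch
import torch.nn as nn

class TinyEnsembleLinear(nn.Module):
    def __init__(self, num_members, in_size, out_size):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_members, in_size, out_size) * 0.1)
        self.bias = nn.Parameter(torch.zeros(num_members, 1, out_size))

    def forward(self, x):
        # x: (num_members, batch, in_size) -> (num_members, batch, out_size)
        return x.matmul(self.weight) + self.bias

layer = TinyEnsembleLinear(num_members=3, in_size=4, out_size=2)
out = layer(torch.randn(3, 5, 4))
print(out.shape)  # torch.Size([3, 5, 2])
```

Note that the member dimension means inputs get expanded/broadcast across members, which is exactly the kind of internal duplication the conjecture above points at.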
To reproduce, run `python bayesian_daml/sac/main.py env_name=brax-hopper algorithm=sunrise_sac start_steps=0`
> Wrapping action in torch.clone fixes the error interestingly enough,
Then it sounds like what I was suspecting might be correct.
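A plain-torch illustration of why the clone workaround plausibly helps: indexing `candidate_actions[max_action_idx]` returns a view that aliases the ensemble's underlying buffer, whereas `.clone()` gives the exported action its own contiguous storage (a sketch of the aliasing behavior, not Brax internals):

```python
# Indexing yields a view sharing storage with the full buffer;
# clone() allocates fresh memory, so Jax can no longer alias the
# ensemble's internal tensor through dlpack.
import torch

candidate_actions = torch.randn(5, 3)
selected_view = candidate_actions[2]          # view: shares storage
selected_copy = candidate_actions[2].clone()  # fresh, contiguous storage

base_ptr = candidate_actions.untyped_storage().data_ptr()
print(selected_view.untyped_storage().data_ptr() == base_ptr)  # True
print(selected_copy.untyped_storage().data_ptr() == base_ptr)  # False
```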
Hi, I have an extremely strange bug that is slightly hard to reproduce :) I am trying to integrate brax with SAC and want to use an ensembled version of the algorithm. When I use the ensemble, the brax environments (tested so far on halfcheetah, walker2d, grasp, and hopper) raise the following error. The original SAC version does not! In addition, the ant environment works for some reason.
In both cases I am wrapping the torch code in torch.no_grad() and detaching the tensors. The shapes going into brax are the same in either case. I think this is some strange memory allocation problem that happens when torch (internally?) duplicates the tensor to compute the ensembled predictions, but that is pure conjecture.
I would be very happy about any hints.