leggedrobotics / rsl_rl

Fast and simple implementation of RL algorithms, designed to run fully on GPU.

Fixed the reset function of Memory in actor_critic_recurrent #35

Open thkkk opened 1 month ago

thkkk commented 1 month ago

Original code of class Memory(torch.nn.Module) in rsl_rl/modules/actor_critic_recurrent.py:

    def reset(self, dones=None):
        # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
        for hidden_state in self.hidden_states:
            hidden_state[..., dones, :] = 0.0

When I train a PPO policy with num_envs=1 using ActorCriticRecurrent, I run into a bug:

../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [0,0,0], thread: [127,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
Traceback (most recent call last):
  File "rsl_rl/rsl_rl/runners/on_policy_runner.py", line 124, in learn
    self.alg.process_env_step(rewards, dones, infos)
  File "rsl_rl/rsl_rl/algorithms/ppo.py", line 95, in process_env_step
    self.actor_critic.reset(dones)
  File "rsl_rl/rsl_rl/modules/actor_critic_recurrent.py", line 54, in reset
    self.memory_a.reset(dones)
  File "rsl_rl/rsl_rl/modules/actor_critic_recurrent.py", line 102, in reset
    hidden_state[..., dones, :] = 0.0  # dones_envs_id
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
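
For reference, the failure can be reproduced outside the runner with plain tensors. The sketch below (assumed shapes only, not the rsl_rl code, run on CPU so the error is synchronous) shows that an integer dones tensor is interpreted as environment indices, so with num_envs=1 the value 1 points at an env slot that does not exist:

    import torch

    hidden_state = torch.zeros(2, 1, 256)   # (num_layers, num_envs=1, hidden_size)
    dones = torch.tensor([1])               # env 0 terminated, stored as an integer flag

    try:
        # 1 is treated as an env index into a size-1 env dimension, not as a done flag
        hidden_state[..., dones, :] = 0.0
    except IndexError as e:
        print(e)                            # out-of-bounds index, the CPU analogue of the CUDA assert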

I modified the code and found that dones.max() >= hidden_state.size(-2):

    def reset(self, dones=None):
        # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
        # dones: (num_envs,), hidden_states: (num_layers, num_envs, hidden_size)
        print(f"dones: {dones},  hidden: {self.hidden_states[0].shape}")
        for hidden_state in self.hidden_states:
            assert dones.max() < hidden_state.size(-2), f"dones {dones} index out of range {hidden_state.shape}"
            hidden_state[..., dones, :] = 0.0

The logs (num_envs=1) are below:

dones: tensor([0], device='cuda:0'),  hidden: torch.Size([2, 1, 256])
dones: tensor([0], device='cuda:0'),  hidden: torch.Size([2, 1, 256])
dones: tensor([0], device='cuda:0'),  hidden: torch.Size([2, 1, 256])
dones: tensor([0], device='cuda:0'),  hidden: torch.Size([2, 1, 256])
dones: tensor([1], device='cuda:0'),  hidden: torch.Size([2, 1, 256])

rsl_rl/rsl_rl/modules/actor_critic_recurrent.py", line 101, in reset
    assert dones.max() < hidden_state.size(-2), f"dones {dones} index out of range {hidden_state.shape}"
AssertionError: dones tensor([1], device='cuda:0') index out of range torch.Size([2, 1, 256])

The logs (num_envs=4) are below. This case does not raise an error, but the hidden_state indexing is still wrong.

dones: tensor([0, 0, 0, 0], device='cuda:0'),  hidden: torch.Size([2, 4, 256])
dones: tensor([0, 0, 0, 0], device='cuda:0'),  hidden: torch.Size([2, 4, 256])
dones: tensor([0, 0, 0, 0], device='cuda:0'),  hidden: torch.Size([2, 4, 256])
dones: tensor([0, 0, 0, 0], device='cuda:0'),  hidden: torch.Size([2, 4, 256])
dones: tensor([0, 0, 0, 1], device='cuda:0'),  hidden: torch.Size([2, 4, 256])
dones: tensor([0, 0, 0, 1], device='cuda:0'),  hidden: torch.Size([2, 4, 256])
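
To make the num_envs=4 case concrete, here is a small sketch (assumed shapes, with a reduced hidden size for readability, not the rsl_rl code) of which slots actually get zeroed when the done flags are used directly as indices:

    import torch

    hidden_state = torch.ones(2, 4, 3)      # (num_layers, num_envs=4, small hidden_size)
    dones = torch.tensor([0, 0, 0, 1])      # env 3 terminated

    hidden_state[..., dones, :] = 0.0       # indices 0, 0, 0, 1 are selected, not env 3
    print(hidden_state[0])                  # envs 0 and 1 are zeroed; env 3 keeps its memory

So the reset silently clears the wrong environments instead of failing.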

It can be seen that the elements of dones are flags indicating whether each environment has terminated, not environment ids. What we need to reset are the ids of the terminated environments, so the correct code first finds the envs whose dones are True:

    def reset(self, dones=None):
        # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
        # dones: (num_envs,), hidden_states: (num_layers, num_envs, hidden_size)
        # Convert the per-env done flags into the ids of the terminated envs.
        dones_envs_id = torch.where(dones)[0] if dones is not None else None
        for hidden_state in self.hidden_states:
            hidden_state[..., dones_envs_id, :] = 0.0

I don't know what the intended behavior is when dones is None, so by default the memory of all environments is set to 0 in that case.
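
A small usage sketch of the fixed logic (assumed (num_layers, num_envs, hidden_size) layout, written as a standalone function rather than the Memory method): only the terminated environments are zeroed, and calling it without dones clears every environment, matching the default described above.

    import torch

    def reset(hidden_states, dones=None):
        # Translate per-env done flags into the ids of terminated envs.
        dones_envs_id = torch.where(dones)[0] if dones is not None else None
        for hidden_state in hidden_states:
            # With dones_envs_id=None the index adds a broadcast axis and zeroes every env.
            hidden_state[..., dones_envs_id, :] = 0.0

    hidden = [torch.ones(2, 4, 3)]                      # e.g. a GRU hidden state, num_envs=4
    reset(hidden, dones=torch.tensor([0, 0, 0, 1]))
    print(hidden[0][0])                                 # only env 3 is zeroed

    reset(hidden)                                       # no dones: all envs are reset
    print(hidden[0][0])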