Stable-Baselines-Team / stable-baselines3-contrib

Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code
https://sb3-contrib.readthedocs.io
MIT License

[Question] RecurrentPPO: Reset LSTM states early? #239

Open phisad opened 2 months ago

phisad commented 2 months ago

❓ Question

Hi and thanks for the great work!

I am using RecurrentPPO in a current project, and it strikes me that on L294 the self._last_lstm_states added to the buffer are actually the ones from the last terminal state (and not all zeros) when an environment is reset on L252. Is my understanding correct?

If so, would it not be better to check for an episode start one line earlier, before L242, and set the states to zero for those environments, instead of handling this in _process_sequence of RecurrentActorCriticPolicy (L198) on every forward pass? Roughly along the lines of the sketch below.
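
To make the suggestion concrete, here is a minimal sketch of the zeroing I have in mind (assuming self._last_lstm_states holds (hidden, cell) tensors of shape (n_layers, n_envs, hidden_size) and dones are the episode-end flags returned by the env step; names are illustrative, not the actual buffer code):

```python
import torch as th

def reset_finished_episode_states(lstm_states, dones, device):
    # lstm_states: tuple of (hidden, cell), each of shape (n_layers, n_envs, hidden_size)
    # dones: array of shape (n_envs,), 1.0 where the episode just ended
    # Zero only the entries of environments that were just reset, so the states
    # written to the rollout buffer start fresh for the new episode.
    keep = th.as_tensor(1.0 - dones, dtype=th.float32, device=device).view(1, -1, 1)
    return tuple(s * keep for s in lstm_states)
```

The idea would be to apply this right after the env step, so that the states stored for a freshly reset environment are already zeros when they are added to the buffer.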


araffin commented 2 months ago

Hello, that's a good suggestion =) Would you mind giving it a try and checking that you obtain exactly the same results? If so, please open a PR ;)

That would hopefully simplify things and make them much faster.

phisad commented 2 months ago

Alright, thanks for the confirmation. ^^

I'll give it a try and make sure that the tests in https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/tests/test_lstm.py run through without any errors (and maybe even a bit faster).

araffin commented 2 months ago

Thinking about this issue again, I'm afraid we still need https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/25b43266e08ebe258061ac69688d94144799de75/sb3_contrib/common/recurrent/policies.py#L203-L204 to reset the states manually when starting a new episode? (at least when updating the network, i.e. when calling train())

Or can we pass all the hidden states to PyTorch in a single call?
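
For clarity, the per-step masking I'm referring to works roughly like this (a simplified sketch of the idea behind _process_sequence, not the actual implementation): a single lstm(features, states) call carries the hidden state across every timestep, so an episode boundary inside a training sequence would leak state unless we step through time and zero the states manually.

```python
import torch as th
import torch.nn as nn

def process_sequence(lstm: nn.LSTM, features, episode_starts, lstm_states):
    # features: (seq_len, batch, input_dim); episode_starts: (seq_len, batch)
    # lstm_states: (hidden, cell), each of shape (n_layers, batch, hidden_size)
    # Step through time and zero the states wherever a new episode begins,
    # instead of letting one LSTM call carry state across the boundary.
    outputs = []
    for feat, start in zip(features, episode_starts):
        mask = (1.0 - start).view(1, -1, 1)
        lstm_states = (lstm_states[0] * mask, lstm_states[1] * mask)
        out, lstm_states = lstm(feat.unsqueeze(0), lstm_states)
        outputs.append(out)
    return th.cat(outputs), lstm_states
```

As far as I know, torch.nn.LSTM does not accept per-timestep initial states, which is why the time loop (and the slowdown) is there in the first place.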