Open phisad opened 2 months ago
Hello, that's a good suggestion =) Would you mind giving it a try and checking that you obtain the exact same results? If so, please open a PR ;)
That would hopefully simplify things and make them much faster.
Alright, thanks for the confirmation. ^^
I'll give it a try and make sure that these tests run through https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/tests/test_lstm.py without any errors (and maybe even a bit faster).
Thinking again about that issue, I'm afraid we still need https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/25b43266e08ebe258061ac69688d94144799de75/sb3_contrib/common/recurrent/policies.py#L203-L204 to reset states manually when starting a new episode (at least when updating the network, when calling `train()`)? Or can we pass all hidden states to PyTorch?
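To make the concern about manual resets concrete, here is a minimal sketch of a per-step reset inside a sequence. It is not the sb3-contrib code: `rnn_step` is a hypothetical tanh cell standing in for an LSTM step, `run_with_resets` is an illustrative name, and NumPy is used instead of PyTorch to keep it dependency-light. It only shows that masking the state at each episode start gives the same outputs as running the new episode from a zero state.

```python
import numpy as np


def rnn_step(x, h, w_x, w_h):
    """Toy tanh recurrent cell (hypothetical stand-in for one LSTM step)."""
    return np.tanh(x @ w_x + h @ w_h)


def run_with_resets(xs, episode_starts, h0, w_x, w_h):
    """Step through a sequence, zeroing the state wherever an episode starts.

    xs:             (seq_len, n_envs, in_dim)
    episode_starts: (seq_len, n_envs), 1.0 where a new episode begins
    h0:             (n_envs, hidden_dim), state carried over from before
    """
    h = h0
    outputs = []
    for x, start in zip(xs, episode_starts):
        # Reset the state of every env that begins a new episode at this step.
        h = (1.0 - start)[:, None] * h
        h = rnn_step(x, h, w_x, w_h)
        outputs.append(h)
    return np.stack(outputs), h


rng = np.random.default_rng(0)
seq_len, n_envs, in_dim, hidden_dim = 6, 2, 3, 4
w_x = rng.normal(size=(in_dim, hidden_dim))
w_h = rng.normal(size=(hidden_dim, hidden_dim))
xs = rng.normal(size=(seq_len, n_envs, in_dim))
h0 = rng.normal(size=(n_envs, hidden_dim))  # non-zero carried-over state

episode_starts = np.zeros((seq_len, n_envs))
episode_starts[3, 0] = 1.0  # env 0 starts a new episode at step 3

outputs, last_h = run_with_resets(xs, episode_starts, h0, w_x, w_h)
```

If an episode boundary can fall in the middle of a sequence that is fed to the network in one call, some reset of this kind (or a split of the sequence at the boundary) seems necessary regardless of what is stored in the buffer.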
❓ Question
Hi and thanks for the great work!
I am using RecurrentPPO in a current project, and it strikes me that on L294 the `self._last_lstm_states` added to the buffer are actually the ones from the last terminal state (and not all zeros) when an environment is reset on L252. Is my understanding correct? If so, would it not be better to check for an episode start one line before L242 and set the states to zero for those environments, instead of handling this in `_process_sequence` of `RecurrentActorCriticPolicy` (L198) on each forward pass?
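A minimal sketch of the proposal, zeroing the stored states for environments that just started a new episode before they are used or added to the buffer. The helper name `zero_states_on_episode_start` and the state layout `(n_layers, n_envs, hidden_dim)` are assumptions for illustration, not the actual sb3-contrib API, and NumPy arrays are used in place of PyTorch tensors.

```python
import numpy as np


def zero_states_on_episode_start(lstm_states, episode_starts):
    """Return (hidden, cell) with states zeroed for envs starting an episode.

    lstm_states:    tuple of two arrays, each (n_layers, n_envs, hidden_dim)
    episode_starts: bool array of shape (n_envs,)
    """
    hidden, cell = lstm_states
    # 1.0 for envs that continue an episode, 0.0 for envs that were just reset.
    keep = (1.0 - episode_starts.astype(hidden.dtype))[None, :, None]
    return hidden * keep, cell * keep


n_layers, n_envs, hidden_dim = 1, 3, 4
hidden = np.ones((n_layers, n_envs, hidden_dim), dtype=np.float32)
cell = np.full((n_layers, n_envs, hidden_dim), 2.0, dtype=np.float32)
episode_starts = np.array([False, True, False])  # env 1 was just reset

new_hidden, new_cell = zero_states_on_episode_start((hidden, cell), episode_starts)
```

With this, the states stored for env 1 would be all zeros rather than the ones left over from its previous terminal state.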