luchris429 / purejaxrl

Really Fast End-to-End Jax RL Implementations
Apache License 2.0

RNNs hidden resets #23

Closed: esraaelelimy closed this issue 7 months ago

esraaelelimy commented 7 months ago

In the rnn_ppo implementations, the RNN uses the done signal at time t to reset the hidden state. Shouldn't it use done_{t-1} instead? My understanding is that we reset the hidden state at the beginning of an episode, and to know whether an observation o_t is the start of an episode, we should check done_{t-1}, not done_t.
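To make the question concrete: under a convention where each step returns the terminal observation together with done_t, the flag marking o_t as an episode start is done_{t-1}, i.e. the done sequence shifted right by one step. Below is a minimal pure-Python sketch, assuming the rollout begins right after a reset (so the first observation counts as a start). `episode_start_flags` is a hypothetical helper for illustration, not part of purejaxrl.

```python
from typing import List


def episode_start_flags(dones: List[bool], first_is_start: bool = True) -> List[bool]:
    """Return start_t = done_{t-1} for a non-auto-resetting convention.

    Assumption: dones[t] means "the episode ended after the step that
    produced observation t", so observation t+1 begins a new episode.
    The first observation is treated as a start (rollout begins after a reset).
    """
    starts = []
    prev_done = first_is_start
    for d in dones:
        starts.append(prev_done)  # o_t starts an episode iff done_{t-1}
        prev_done = d
    return starts


if __name__ == "__main__":
    print(episode_start_flags([False, False, True, False, True, False]))
```

With this convention, the hidden state would be reset before processing o_t whenever `starts[t]` is true, which is exactly a one-step lag relative to the raw done signal.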

esraaelelimy commented 7 months ago

Actually, I think the implementation already does this; it just wasn't clear at first. In the Gymnax environment implementations, when an episode terminates, the returned observation is the first observation of the new episode, not the terminal observation. So a step returns (Observation_0, done_T, ...), not (Observation_T, done_T, ...). Hence, using the done signal returned alongside the current observation to reset the hidden state is correct.
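A minimal JAX sketch of the reset-then-step ordering this implies, assuming an auto-resetting environment like the Gymnax behaviour described above, where `done_seq[t]` arrives together with `obs_seq[t]` and `obs_seq[t]` is already the first observation of the next episode when that flag is set. The toy tanh cell, `init_carry`, `rnn_cell`, `scan_rnn`, and the zero initial state are hypothetical stand-ins chosen for illustration, not purejaxrl's actual RNN code.

```python
import jax
import jax.numpy as jnp

HIDDEN = 4  # hypothetical hidden size


def init_carry(batch_size: int) -> jnp.ndarray:
    # Hypothetical zero initial state; a real model may use a different initializer.
    return jnp.zeros((batch_size, HIDDEN))


def rnn_cell(params, h, x):
    # Toy Elman cell h' = tanh(h W + x U), standing in for a GRU/LSTM.
    return jnp.tanh(h @ params["W"] + x @ params["U"])


def scan_rnn(params, h0, obs_seq, done_seq):
    """obs_seq: [T, B, D]; done_seq: [T, B].

    Assumes done_seq[t] was returned together with obs_seq[t] by an
    auto-resetting env, so done_seq[t] == True means obs_seq[t] is the
    first observation of a new episode.
    """
    def step(h, inputs):
        x, done = inputs
        # Reset BEFORE consuming x: x already belongs to the new episode.
        h = jnp.where(done[:, None], init_carry(x.shape[0]), h)
        h = rnn_cell(params, h, x)
        return h, h

    # Returns (final hidden state, stacked hidden states of shape [T, B, HIDDEN]).
    return jax.lax.scan(step, h0, (obs_seq, done_seq))


# Usage: 5 steps, 2 parallel envs, 3-dim observations; env 0 finishes an
# episode so that the observation at t=3 is a fresh episode start.
key_w, key_u, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "W": jax.random.normal(key_w, (HIDDEN, HIDDEN)),
    "U": jax.random.normal(key_u, (3, HIDDEN)),
}
T, B = 5, 2
obs = jax.random.normal(key_x, (T, B, 3))
dones = jnp.zeros((T, B), dtype=bool).at[3, 0].set(True)
h_final, hs = scan_rnn(params, init_carry(B), obs, dones)
```

Because the reset happens before the cell consumes the observation, `hs[3, 0]` depends only on `obs[3, 0]` and none of env 0's earlier history, which is the behaviour the thread concludes is correct for Gymnax-style auto-reset.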