NM512 / dreamerv3-torch

Implementation of Dreamer v3 in pytorch.
MIT License

[Question] How does episode sampling handle environment resets? #37

Closed · ahan98 closed this issue 1 year ago

ahan98 commented 1 year ago

Hi, I'm confused about how episodes are sampled from the replay buffer, since episodes can have different lengths, and different episodes may come from different environments because the environment resets after a terminal state.

I still don't fully understand the sampling procedure, but from what I can tell from sample_episodes(), episodes that end prematurely are padded with transitions from other episodes until batch_size sequences of length batch_length have been assembled.

For example, suppose batch_size=1 and batch_length=10, and the first episode you sample has only 3 transitions, e.g., (s_1, s_2), (s_2, s_3), (s_3, s_4). After the agent reaches the terminal state s_4, the environment resets, and you obtain another episode of length 10, say, s'_1, ..., s'_10. Could we then train on a sequence such as (s_1, s_2), ..., (s_3, s_4), (s'_1, s'_2), ..., (s'_7, s'_8), i.e., 3 transitions from the first episode plus 7 from the second, for 10 in total? That is, is it okay to combine sequences from different episodes, even though the episodes may have been played in completely different environments?
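
To make my mental model concrete, here is a rough sketch of the chaining I think is happening. All names here (`sample_sequence`, the episode dict keys, the `rng` argument) are hypothetical and not taken from the actual code:

```python
import numpy as np

def sample_sequence(episodes, batch_length, rng):
    # Hypothetical sketch of my understanding, not the actual
    # sample_episodes(): chain slices of randomly chosen episodes
    # until a fixed-length training sequence is filled.
    # Each episode is a dict of equal-length arrays, e.g. with keys
    # "obs", "action", "is_first".
    chunks, remaining = [], batch_length
    while remaining > 0:
        ep = episodes[rng.integers(len(episodes))]
        ep_len = len(ep["is_first"])
        start = rng.integers(ep_len)           # random start offset
        take = min(remaining, ep_len - start)  # take what still fits
        chunks.append({k: v[start:start + take] for k, v in ep.items()})
        remaining -= take
    # Concatenate slices from (possibly) different episodes.
    return {k: np.concatenate([c[k] for c in chunks]) for k in chunks[0]}
```

If something like this is what happens, then a single training sequence can straddle an episode boundary, which is exactly the case I'm asking about.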

Thanks for your time and for an amazing port of Dreamer!

NM512 commented 1 year ago

Hi,

Thank you for your inquiry about the sampling procedure from the replay buffer.

Your concern about combining sequences from different episodes is valid, but I'd like to assure you that it's handled by my current method. By using the "is_first" flag, we reset the hidden_state in the world model during training, allowing sequences from different episodes to be combined without issue. Please refer here.
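
As a rough illustration (hypothetical names, not the exact code in this repo), the reset amounts to masking the recurrent state wherever "is_first" is set:

```python
import torch

def mask_reset(state, is_first, initial_state):
    # Hypothetical illustration: wherever a time step is flagged as
    # the first step of a new episode, replace the carried-over
    # recurrent state with a fresh initial state, so information
    # cannot leak across episode boundaries.
    # state, initial_state: (batch, dim); is_first: (batch,) of 0/1.
    mask = is_first.float().unsqueeze(-1)        # (batch, 1)
    return state * (1.0 - mask) + initial_state * mask
```

Applying this at every step of the sequence before the world model update means the step right after a boundary behaves just like the start of a fresh episode.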

So yes, it's okay to combine sequences as you described, and the model recognizes the boundaries between episodes, even if played in different environments.

I hope this answers your question. Feel free to reach out if you need further clarification.

Best, NM512