It seems that the method is_valid_transition of OutOfGraphReplayBuffer does not check whether the stacked images come from another, truncated trajectory, in which case the index is invalid.
It only checks if:
Here is a simple example of where this can be problematic:
>>> [[0], [1], [0], [2]] # there is no valid index to sample.
>>> (array([[[1, 0]]], dtype=uint8), array([0], dtype=int32), array([0.], dtype=float32), array([[[0, 2]]], dtype=uint8), array([2], dtype=int32), array([2.], dtype=float32), array([0], dtype=uint8), array([2], dtype=int32))
Here, index 2 is considered valid even though it is not: the state array([[[1, 0]]]) combines an observation from the previous trajectory, [1], with an observation from the new trajectory, [0].
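To make the failure mode concrete, here is a minimal pure-Python toy model of frame stacking. It is not Dopamine's implementation: `stack_indices`, `terminal_only_check`, and the `episode_end_indices` set are hypothetical names chosen for illustration, and the truncation after index 1 is assumed from the example above.

```python
# Toy model of a frame-stacking replay buffer; not Dopamine's implementation.
observations = [0, 1, 0, 2]                # same contents as the example above
terminals = [False, False, False, False]   # no terminal flag was ever set
episode_end_indices = {1}                  # assumption: truncation after index 1
STACK_SIZE = 2


def stack_indices(index, stack_size=STACK_SIZE):
    """Indices of the frames stacked into the state stored at `index`."""
    return list(range(index - stack_size + 1, index + 1))


def terminal_only_check(index):
    """Accept the stack unless a terminal flag appears before its last frame."""
    earlier_frames = stack_indices(index)[:-1]
    return not any(terminals[i] for i in earlier_frames)


# The state at index 2 stacks frames 1 and 2, which straddle the truncation.
state = [observations[i] for i in stack_indices(2)]
crosses_boundary = any(i in episode_end_indices for i in stack_indices(2)[:-1])

print(state)                    # [1, 0]
print(crosses_boundary)         # True
print(terminal_only_check(2))   # True
```

Because truncation sets no terminal flag, a check that looks only at terminal flags accepts index 2 even though its stack mixes two trajectories.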
To fix this bug, the check at https://github.com/google/dopamine/blob/ce36aab6528b26a699f5f1cefd330fdaf23a5d72/dopamine/replay_memory/circular_replay_buffer.py#L467 could be changed to:
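One possible shape for such a check is sketched below in plain Python. It is an illustration under stated assumptions, not the exact change to the linked line: `is_valid_stack` is a hypothetical helper, and `episode_end_indices` is assumed to hold every index at which an episode ended, whether through a terminal state or through truncation.

```python
def is_valid_stack(index, stack_size, terminals, episode_end_indices):
    """Return True only if all stacked frames belong to one trajectory.

    Hypothetical helper: `episode_end_indices` is assumed to contain every
    index at which an episode ended, by terminal state or by truncation.
    """
    first = index - stack_size + 1
    if first < 0:
        # Assumption of this sketch: reject stacks that lack earlier frames.
        return False
    # Every stacked frame except the last one must not close an episode.
    for i in range(first, index):
        if terminals[i] or i in episode_end_indices:
            return False  # the stack would straddle an episode boundary
    return True


terminals = [False, False, False, False]
episode_end_indices = {1}  # truncation after the second observation

print(is_valid_stack(2, 2, terminals, episode_end_indices))  # False
```

This covers only the frame-stack condition. Any other conditions the method already checks would stay as they are.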