google / dopamine

Dopamine is a research framework for fast prototyping of reinforcement learning algorithms.
https://github.com/google/dopamine
Apache License 2.0

Bug for truncated episodes in replaybuffer #213

Open theovincent opened 8 months ago

theovincent commented 8 months ago

It seems that the method is_valid_transition of OutOfGraphReplayBuffer does not check whether the stacked observations come from a different, truncated trajectory. When they do, the index should be treated as invalid.

It only checks if:

Here is a minimal example where this causes a problem:

```python
import numpy as np
from dopamine.replay_memory.circular_replay_buffer import OutOfGraphReplayBuffer

replay_buffer = OutOfGraphReplayBuffer(observation_shape=(1,), stack_size=2, replay_capacity=10, batch_size=1)

# First episode: a single step that is truncated (episode_end=True, terminal=False).
replay_buffer.add(np.array([1]), 1, 1, False, episode_end=True)
# Second episode starts here.
replay_buffer.add(np.array([2]), 2, 2, False)

print(replay_buffer._store["observation"][:4])
# [[0], [1], [0], [2]]  -> no index here should be valid to sample.
print(replay_buffer.sample_transition_batch())
# (array([[[1, 0]]], dtype=uint8), array([0], dtype=int32), array([0.], dtype=float32), array([[[0, 2]]], dtype=uint8), array([2], dtype=int32), array([2.], dtype=float32), array([0], dtype=uint8), array([2], dtype=int32))
```

Here, index 2 is treated as valid even though it is not: the state array([[[1, 0]]]) combines an observation from the previous trajectory ([1]) with a frame from the new trajectory ([0]).
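The following NumPy-only sketch rebuilds the first four slots from the printed store and shows which slots each stacked state reads. The episode_end flags are inferred from the two add calls, wrap-around at the buffer boundary is ignored, and the helper names (stack_window, stacked_state, stack_crosses_episode_end) are hypothetical, not Dopamine APIs.

```python
import numpy as np

# Slots 0-3 of the buffer from the example (observation_shape=(1,)), flattened
# to scalars. Slots 0 and 2 hold the zero frames shown in the printed store.
observations = np.array([0, 1, 0, 2], dtype=np.uint8)
# Assumed episode_end flags for the same slots: only slot 1 was added
# with episode_end=True.
episode_end = np.array([False, True, False, False])


def stack_window(index, stack_size):
    """Slots read to build the stacked state at `index`, oldest first."""
    return list(range(index - stack_size + 1, index + 1))


def stacked_state(index, stack_size):
    """Stacked state at `index`: the last `stack_size` frames, oldest first."""
    return observations[stack_window(index, stack_size)]


def stack_crosses_episode_end(index, stack_size):
    """True if an episode ends in a slot before `index` inside the window."""
    earlier_slots = stack_window(index, stack_size)[:-1]
    return bool(episode_end[earlier_slots].any())


for index in (2, 3):
    print(index, stacked_state(index, 2), stack_crosses_episode_end(index, 2))
```

For index 2 the window is slots 1 and 2, so the state mixes the truncated step [1] with the next episode's first frame; for index 3 the window stays inside the new episode.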

To fix this bug, the loop at https://github.com/google/dopamine/blob/ce36aab6528b26a699f5f1cefd330fdaf23a5d72/dopamine/replay_memory/circular_replay_buffer.py#L467 could be changed to:

```python
for i in modulo_range(index - self._stack_size + 1, self._update_horizon, self._replay_capacity):
```
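A standalone sketch of such a check, outside the class, is below. It assumes that Dopamine's modulo_range(start, length, modulo) yields `length` consecutive indices starting at `start`, wrapped modulo the capacity; the local modulo_range here is a stand-in written under that assumption, and passes_truncation_check is a hypothetical name. It rejects an index when a step with episode_end set but terminal unset falls anywhere in the window the transition reads.

```python
import numpy as np


def modulo_range(start, length, modulo):
    """Local stand-in, assumed to mirror Dopamine's helper: yields `length`
    consecutive indices starting at `start`, wrapped into [0, modulo)."""
    for i in range(length):
        yield (start + i) % modulo


def passes_truncation_check(index, stack_size, update_horizon, capacity,
                            episode_end, terminal):
    """Hypothetical check: False if a truncated step (episode_end set,
    terminal not set) lies in the window read by the transition at `index`."""
    # Window = the stack_size - 1 frames stacked before `index`, plus `index`
    # itself and the following update_horizon - 1 steps of the n-step return.
    start = index - stack_size + 1
    length = stack_size - 1 + update_horizon
    for i in modulo_range(start, length, capacity):
        if episode_end[i] and not terminal[i]:
            return False
    return True


# Flags matching the example: capacity 10, only slot 1 is a truncated end.
capacity = 10
episode_end = np.zeros(capacity, dtype=bool)
episode_end[1] = True
terminal = np.zeros(capacity, dtype=bool)

for index in (1, 2, 3):
    print(index, passes_truncation_check(index, 2, 1, capacity,
                                         episode_end, terminal))
```

Both index 1 (its next state would come from the new episode) and index 2 (its stacked state crosses the boundary) are rejected, while index 3 passes this check; the other conditions in is_valid_transition still apply. The window length is stack_size - 1 + update_horizon because a range of only update_horizon slots starting at index - stack_size + 1 would no longer include index itself whenever stack_size > update_horizon.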