pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl

[BUG] Potential Thread Safety Issue with PrioritizedSliceSampler and Prefetching #2212

Open wertyuilife2 opened 1 month ago

wertyuilife2 commented 1 month ago

Describe the bug

This issue comes from the original issue #2205.

Although I haven't confirmed or tested this, there may be a threading risk. When prefetching is enabled, the ReplayBuffer spawns a separate thread that calls PrioritizedSliceSampler.sample().
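For concreteness, the setup under discussion is a prioritized slice sampler combined with a buffer that prefetches on a background thread. A minimal sketch of such a configuration (constructor arguments recalled from the torchrl docs; names and defaults may differ across versions):

```python
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler

# With prefetch > 0, the buffer samples ahead on a background thread,
# which is where sample() can race with update_priority() called from
# the main thread.
sampler = PrioritizedSliceSampler(max_capacity=100_000, alpha=0.7, beta=0.9, num_slices=8)
rb = ReplayBuffer(
    storage=LazyTensorStorage(100_000),
    sampler=sampler,
    batch_size=256,
    prefetch=2,
)
```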

In PrioritizedSliceSampler.sample(), we have:

# force to not sample index at the end of a trajectory
self._sum_tree[preceding_stop_idx] = 0.0
# and no need to update self._min_tree
starts, info = PrioritizedSampler.sample(self, storage=storage, batch_size=batch_size // seq_length)

If the main thread calls update_priority() between these two lines, it can overwrite the _sum_tree entries that were just zeroed, so self._sum_tree[preceding_stop_idx] = 0.0 no longer takes effect and the subsequent PrioritizedSampler.sample() call may pick the end of a trajectory as a slice start.

I'm uncertain of the role of buffer._futures_lock, but it doesn't seem to prevent this conflict.
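To illustrate the interleaving, and the kind of synchronization that would rule it out, here is a toy sketch. This is not torchrl code: a plain dict stands in for _sum_tree and tree_lock is a hypothetical lock. The point is only that if the zero-then-sample sequence and update_priority() were serialized by the same lock, the priority write could not land between the two lines quoted above.

```python
import threading

# Toy stand-in for the sampler state; not torchrl's implementation.
sum_tree = {0: 0.5, 1: 0.5, 2: 0.5}   # index -> priority
tree_lock = threading.Lock()           # hypothetical lock shared by both threads

def prefetch_sample(preceding_stop_idx):
    # Prefetch thread: zero out trajectory-end indices, then sample,
    # all while holding the lock so the zeros are still in place.
    with tree_lock:
        for idx in preceding_stop_idx:
            sum_tree[idx] = 0.0        # "force to not sample index at the end of a trajectory"
        # ... PrioritizedSampler.sample() would run here ...
        return [i for i, p in sum_tree.items() if p > 0.0]

def update_priority(indices, priorities):
    # Main thread: without the lock, this write could land between the
    # zeroing and the sample call and re-enable a trajectory-end index.
    with tree_lock:
        for i, p in zip(indices, priorities):
            sum_tree[i] = p
```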

Additional context

See discussion in the original issue #2205.


cc @xiaomengy