wertyuilife2 opened this issue 5 months ago
Hi @wertyuilife2, thanks for reporting these, and for the detailed explanations and possible fixes!
I'm currently fixing `PrioritizedSliceSampler`, which, as you may have realized, has some issues at the moment.
I think I can integrate 1-3 in https://github.com/pytorch/rl/pull/2202
For 4, this seems to be an additional feature, not a bug (as you say, TorchRL doesn't currently encourage modifying a storage in-place). Not saying we don't want to support that, but we'd have to carefully look into it!
For 5, I'd be open to modifying the logic around `_max_priority`: initialize it to 0 and update it to the highest priority when the first samples come in. The issue here is that if we put the first batch in and don't call `update_priority`, there will be a segfault because the sum tree is full of 0s. That is, I believe, the only reason we use a non-null default priority.
Typically, the usage is

```python
buffer.extend(data)
s = buffer.sample()
loss_dict = loss(s)
buffer.update_priority(s["index"], s["priority"])  # or from the loss_dict
```
but if the initial priority is 0, the initial sample will fail. Your `_upper_priority` trick could solve that.
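A rough sketch of what that fallback could look like (hypothetical helper name and signature, not the current torchrl API):

```python
# Rough sketch of the fallback discussed above (hypothetical helper, not the
# current torchrl API): never hand the sum tree an all-zero priority, but
# otherwise follow the maximum priority currently present in the buffer.
def _default_priority(current_max_priority: float, upper_priority: float = 1.0) -> float:
    if current_max_priority <= 0:  # no update_priority() call has happened yet
        return upper_priority      # keeps the sum tree non-degenerate
    return current_max_priority
```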
I agree with 6 too: the priority should be updated just once.
For 7, we can group that in the same stack as #2185, which also points at some thread-safety issues. For these to be patched, we'd need a good minimal reproducible example!
Thanks for the quick response! I will test 1-3 for sure.
For 5, I completely understand your point. However, my key issue is that `_max_priority` tracks the highest priority ever recorded in history, whereas PER would use the highest priority in the current buffer.
For example, let's focus on the priorities of the first three samples in the buffer during an RL training process:
```
train loop 1:
index:         [0, 1, 2, ...]
priority:      [0.9, 0.7, 0.8, ...]
_max_priority: 0.9
...
train loop 100:
index:         [0, 1, 2, ...]
priority:      [0.01, 0.05, 0.02, ...]
_max_priority: 0.9
```
In the early stages of training, priorities are generally high, but as the network updates, the priorities gradually decrease. Because `_max_priority` tracks the historical maximum, new samples in the later stages of training are assigned excessively high default priorities. Since a sample's PER weight is calculated as `min_priority / priority`, this results in a value close to 0. Consequently, each new sample's weight in its first update is very small, almost negligible, so it effectively does not participate in the update.
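For concreteness, a tiny numeric illustration of that effect (illustrative numbers, beta exponent omitted):

```python
# Illustrative numbers only: a stale _max_priority of 0.9 is used as the default
# priority of a new sample while the smallest priority in the buffer is 0.01.
min_priority = 0.01
default_priority = 0.9
weight = min_priority / default_priority
print(weight)  # ~0.011: the new sample barely contributes to the weighted loss
```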
For 7, I will try to test that and will provide feedback if I find anything.
> For 5, I completely understand your point. However, my key issue is that `_max_priority` tracks the highest priority ever recorded in history, whereas PER would use the highest priority in the current buffer.
Oh yeah, I can see that. Then we should either record the index along with the priority and erase it whenever that index is overwritten, or compute the maximum on the fly every time (which would be quite expensive!).
In fact, `_min_tree` is used to store the minimum priority in the current buffer, which is then used in `sample()` to calculate the PER weight. Therefore, we can simply create a `_neg_min_tree` to store the maximum priority in the current buffer in a similar manner.
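To illustrate the negation trick with a stand-in structure (a plain dict instead of a segment tree; not the torchrl implementation):

```python
# Stand-in for a min-structure over negated priorities: the minimum of the
# stored values is minus the maximum priority currently in the buffer.
neg_min = {}  # index -> -priority, kept in sync with the buffer contents

def write(index: int, priority: float) -> None:
    neg_min[index] = -priority  # overwriting an index drops its old priority

def current_max_priority() -> float:
    return -min(neg_min.values())  # a MinSegmentTree answers this in O(log n)

write(0, 0.9)
write(1, 0.7)
write(0, 0.01)                 # index 0 overwritten by newer data
print(current_max_priority())  # 0.7: reflects the current buffer, not history
```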
Can I ask you to open 1 issue for each of these? It will be easier for me to track them!
Of course, I will do it later!
I have been using `torchrl` for my work recently and have encountered several bugs or unexpected behaviors. I noticed that @vmoens has been addressing some fixes, but to ensure nothing is overlooked, I am listing the issues I encountered here. These issues have not been resolved in `torchrl-nightly==2024.6.3`.

Given that this is a collection of issues, this post might be a bit lengthy. Please let me know if you would prefer me to split it into multiple issues!
**1.** In `PrioritizedSliceSampler.sample()`, `preceding_stop_idx` needs to be moved to the CPU before executing `self._sum_tree[preceding_stop_idx] = 0.0`. If `preceding_stop_idx` is on the GPU, the program results in a segmentation fault. This can be reproduced with the example code of the third issue.
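A minimal sketch of the fix (with a stand-in tensor for the sum tree; the device move is the point):

```python
import torch

# Stand-in for self._sum_tree: indexing it with CUDA indices is what crashes,
# so the indices are moved to CPU before the assignment.
sum_tree = torch.ones(30)
preceding_stop_idx = torch.tensor([9, 19, 29])     # may live on "cuda" in practice
preceding_stop_idx = preceding_stop_idx.to("cpu")  # move before indexing the sum tree
sum_tree[preceding_stop_idx] = 0.0
```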
**2.** In commit c2e1c05, at samplers.py lines 1767 and 1213, `index` and `stop_idx` might not be on the same device, with `stop_idx` potentially being on the GPU. This line should be modified so that the two tensors are first brought onto the same device.
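A sketch of the device alignment meant here (assuming `index` provides the reference device; the exact line in samplers.py may differ):

```python
import torch

def align_stop_idx(index: torch.Tensor, stop_idx: torch.Tensor) -> torch.Tensor:
    # stop_idx may arrive on the GPU while index is on the CPU (or vice versa);
    # bring it onto index's device before the two are used together.
    return stop_idx.to(index.device)

index = torch.arange(10)         # e.g. CPU
stop_idx = torch.tensor([4, 9])  # may be CUDA in practice
stop_idx = align_stop_idx(index, stop_idx)
```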
**3.** As per the comments, the `preceding_stop_idx` variable in `PrioritizedSliceSampler.sample()` attempts to

> build a list of indexes that we don't want to sample: all the steps at a seq_length distance from the end of the trajectory, with the end of the trajectory (stop_idx) included.

However, it does not do this correctly: `PrioritizedSliceSampler.sample()` ends up sampling across trajectories, which is not the expected behavior, unlike `SliceSampler`, which handles this correctly.
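To make the quoted comment concrete, here is a standalone sketch of the exclusion set it describes (hypothetical indices, not the torchrl code):

```python
import torch

# For every trajectory end, the end itself and the steps within slice_len of it
# should not be eligible as slice starts, otherwise a slice can cross trajectories.
stop_idx = torch.tensor([9, 19, 29])  # trajectory ends in a flat buffer of 30 steps
slice_len = 4
offsets = torch.arange(slice_len)     # 0, 1, ..., slice_len - 1
preceding_stop_idx = (stop_idx.unsqueeze(1) - offsets).flatten()
preceding_stop_idx = preceding_stop_idx[preceding_stop_idx >= 0].unique()
print(preceding_stop_idx)  # tensor([ 6,  7,  8,  9, 16, 17, 18, 19, 26, 27, 28, 29])
```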
**4.** My work requires modifying the contents of the buffer. Specifically, I need to sample an item, modify it, and put it back in the buffer. However, torchrl currently does not seem to encourage modifying buffer contents. When calling `buffer._storage.set(index, data)` to put my modified data back into the buffer, it implicitly changes `_storage._len`, which can cause the sampler to sample empty samples. I resolved this by directly modifying `buffer._storage._storage` while holding the `buffer._replay_lock` (see the sketch below). It took me two days to discover that `TensorStorage.set()` implicitly changes `_len`, and I believe this method should behave more intuitively. I am not sure if other `Storage` classes have similar issues, but `TensorStorage` definitely does.
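A sketch of that workaround (it relies on private attributes and assumes the backing `_storage` container supports index assignment, so treat it as fragile):

```python
def write_back_in_place(buffer, index, modified_item):
    # Bypass TensorStorage.set() so that _storage._len is left untouched;
    # hold the replay lock to avoid racing with concurrent sampling.
    with buffer._replay_lock:
        buffer._storage._storage[index] = modified_item
```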
**5.** [Current Implementation and Issues] The current implementation maintains `_max_priority`, which represents the maximum priority of all samples seen historically, not just those in the current buffer. Early in RL training, outliers can cause `_max_priority` to remain high, making it unrepresentative. Additionally, `_max_priority` is initialized to 1, while most RL algorithms use the Bellman error as the priority, which can often be much smaller (close to 0). Consequently, `_max_priority` may never be updated. New samples are thus given a priority of 1, which essentially means their PER weight is close to 0. This means they are sampled immediately but contribute little to the weighted loss, reducing sample efficiency.

[Proposed Solution] Maintain a `_neg_min_tree = MinSegmentTree()` to track the maximum priority in the current buffer. With this, and adding `self._upper_priority = 1`, part of the `PrioritizedSampler` methods can be updated along the lines sketched below.
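A sketch of the direction proposed here, with hypothetical method bodies (the range-query API on the segment tree is assumed, and `priority` is assumed to be a tensor):

```python
import torch

class PrioritizedSamplerSketch:
    """Hypothetical sketch, not the torchrl class: `_neg_min_tree` stores
    negated priorities, so a range-min query over the filled part of the
    buffer returns minus the maximum priority currently stored."""

    def default_priority(self, storage) -> float:
        if len(storage) == 0:
            return self._upper_priority  # nothing written yet
        max_priority = -self._neg_min_tree.query(0, len(storage))  # assumed range query
        if max_priority <= 0:
            return self._upper_priority  # no update_priority() yet: keep the sum tree non-zero
        return max_priority

    def update_priority(self, index, priority) -> None:
        priority = torch.pow(priority + self._eps, self._alpha)
        self._sum_tree[index] = priority
        self._min_tree[index] = priority
        self._neg_min_tree[index] = -priority  # tracks the per-buffer maximum
```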
This change implies that the `default_priority` function will need to take `storage` as an additional parameter, which will eventually affect several methods like `Sampler(ABC).extend()`, `Sampler(ABC).add()`, and `Sampler(ABC).mark_update()`, but I believe this is reasonable, akin to how `Sampler.sample()` already takes `storage` as a parameter.
**6.** When `ReplayBuffer._add()` is called, the following sequence occurs:

1. `_writer.add()` -> `_storage.__setitem__()` -> `buffer.mark_update()` -> `_sampler.mark_update()` -> `_sampler.update_priority()`
2. `_sampler.add()` -> `_sampler._add_or_extend()`

Both `_sampler._add_or_extend()` and `_sampler.update_priority()` update the priority, with `update_priority()` even applying additional transformations (e.g., `torch.pow(priority + self._eps, self._alpha)`). This behavior is also present in `ReplayBuffer._extend()`.

This behavior is not reasonable. I believe the `mark_update` mechanism is somewhat redundant: we do not need to ensure that `_attached_entities` are updated when changing `storage` content. Any additional updates required after directly modifying `_storage` should be the responsibility of the user. `mark_update` can lead to redundant calls and even cause conflicts.
**7.** Although I haven't confirmed or tested this, there may be a threading risk. The `ReplayBuffer` initiates a separate thread for prefetching, which will call `PrioritizedSliceSampler.sample()`.

In `PrioritizedSliceSampler.sample()`, the sum tree is first zeroed at `preceding_stop_idx`, and the slice starts are then sampled from it. If the main thread calls `update_priority()` between these two steps, it might update `_sum_tree`, causing the `self._sum_tree[preceding_stop_idx] = 0.0` exclusion to fail, so that the end of a trajectory can then be sampled as the slice start.

I'm uncertain of the role of `buffer._futures_lock`, but it doesn't seem to prevent this conflict.

Overall, I hope this comprehensive overview helps in addressing these issues. Please let me know if you need further details or if I should break this down into separate issues!