Closed: wertyuilife2 closed this issue 3 months ago
This occurs because, although the sample at `index = len(storage) - 1` is covered by `preceding_stop_idx`, it can still be sampled as a start index, leading to cross-trajectory sampling.
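The masking failure can be illustrated with a toy layout (hypothetical trajectory ids and slice length, just to show the mechanism, not torchrl's actual bookkeeping):

```python
# Toy storage: trajectory id of each slot (hypothetical example data)
traj = [0, 0, 0, 1, 1]
slice_len = 2

# Start indices that would cross a trajectory boundary or run off the
# end of storage must be masked out; the last index always qualifies:
bad_starts = {
    i for i in range(len(traj))
    if i + slice_len > len(traj) or traj[i] != traj[i + slice_len - 1]
}
assert len(traj) - 1 in bad_starts

# But clamping an out-of-bounds sum-tree result, the way
# index.clamp_max_(len(storage) - 1) does, lands on exactly that index:
out_of_bounds = len(traj)                  # what scan_lower_bound may return
start = min(out_of_bounds, len(traj) - 1)  # the clamp
assert start in bad_starts                 # masked index used as a start anyway
```

So the clamp silently funnels every overflow onto the one start index the slice sampler most needs to exclude.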
Is this also the case with the patch in #2228? How can the start be equal to `len(storage) - 1` if this is masked?
It's a new issue due to `_sum_tree`; I think it can't be solved with #2228.
In `PrioritizedSampler.sample()` we have:

```python
def sample(...):
    ...
    index = self._sum_tree.scan_lower_bound(mass)
    ...
    index.clamp_max_(len(storage) - 1)
    weight = torch.as_tensor(self._sum_tree[index])
    ...
    return index, {"_weight": weight}
```
Due to some unknown bug in `_sum_tree`, we randomly get an index greater than `len(storage)`; it is then clamped to `len(storage) - 1`, which is why we get the wrong start index.
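To see how a prefix search over floating-point priorities can fall off the end of the tree, here is a minimal pure-Python stand-in for `scan_lower_bound` (a linear sketch illustrating the failure mode, not torchrl's actual sum-tree implementation):

```python
def scan_lower_bound(priorities, mass):
    """Smallest index whose cumulative priority exceeds `mass`.
    Linear stand-in for a sum-tree prefix search (hypothetical,
    not torchrl's real implementation)."""
    cumulative = 0.0
    for i, p in enumerate(priorities):
        cumulative += p
        if mass < cumulative:
            return i
    # mass >= total priority: the search walks past the last leaf
    return len(priorities)

priorities = [0.1, 0.2, 0.3]
total = sum(priorities)  # 0.6000000000000001 in binary floating point
# A sampled mass equal to (or rounded up to) the total is never strictly
# below any cumulative sum, so the returned index is out of bounds:
idx = scan_lower_bound(priorities, total)
assert idx == len(priorities)  # i.e. len(storage), later clamped to len - 1
```

Any rounding that pushes the sampled mass to (or past) the stored total yields `len(storage)`, and the clamp then maps it onto the last sample.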
cc @xiaomengy any idea why we get out-of-bound samples from the PRB?
In `PrioritizedSampler.sample()`, `_sum_tree.scan_lower_bound()` sometimes generates an index greater than `len(storage)`. The test code is as follows:

In `PrioritizedSampler.sample()`, this unexpected behavior is handled by `index.clamp_max_(len(storage) - 1)`, but it still causes unexpected behavior in `PrioritizedSliceSampler.sample()`. This occurs because, although the sample at `index = len(storage) - 1` is covered by `preceding_stop_idx`, it can still be sampled as a start index, leading to cross-trajectory sampling.