pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] Unexpected behavior of SumSegmentTree Resulting in Invalid Slices in PrioritizedSliceSampler.sample() #2230

Closed: wertyuilife2 closed this issue 3 months ago

wertyuilife2 commented 3 months ago

In PrioritizedSampler.sample(), _sum_tree.scan_lower_bound() sometimes returns an out-of-bounds index (greater than or equal to len(storage)).

The test code is as follows:

import numpy as np
import torch
from torchrl._torchrl import SumSegmentTreeFp32

def test_sum_tree():
    torch.manual_seed(0)
    np.random.seed(0)
    sum_tree = SumSegmentTreeFp32(500)

    # repeat many times to make sure the bug is triggered
    for _ in range(1000):
        # update priorities
        index = torch.arange(0, 100, dtype=torch.long, device=torch.device("cpu"))
        priority = torch.rand(100, device=torch.device("cpu"))
        # the +1e-8 offset is not essential; the bug also happens without it
        sum_tree[index] = priority + 1e-8

        # sample: draw a lot of mass values to make sure the bug is triggered
        p_sum = sum_tree.query(0, 100)
        mass = np.random.uniform(0.0, p_sum, size=1000000)
        scanned_index = sum_tree.scan_lower_bound(mass)
        if scanned_index.max() >= 100:
            print(
                "Unexpected index! p_sum: {}, mass.max(): {}, scanned_index.max(): {}".format(
                    p_sum, mass.max(), scanned_index.max()
                )
            )

test_sum_tree()

In PrioritizedSampler.sample(), this out-of-bounds index is handled by index.clamp_max_(len(storage) - 1), but the clamping still causes incorrect behavior in PrioritizedSliceSampler.sample().

This occurs because, although the sample at index = len(storage) - 1 is covered by preceding_stop_idx, it can still be sampled as a start index, leading to cross-trajectory sampling.
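
To make the interaction concrete, here is a minimal sketch with plain torch and made-up numbers (it is not the torchrl internals): it shows how clamping an out-of-bounds scan result maps it onto exactly the position that preceding_stop_idx was supposed to exclude.

import torch

# hypothetical numbers for illustration; not the torchrl implementation
len_storage = 100
priorities = torch.rand(len_storage)

# positions that must not start a slice (what preceding_stop_idx marks),
# e.g. the end of a trajectory / the last index of the buffer
masked = torch.tensor([len_storage - 1])
priorities[masked] = 0.0  # zero priority: should never be drawn

# scan_lower_bound occasionally returns an out-of-bounds index ...
scanned_index = torch.tensor([len_storage])
# ... which the sampler clamps back into range
start_index = scanned_index.clamp_max(len_storage - 1)

# the clamp lands on the masked position, so a zero-priority index
# becomes a slice start and the slice crosses a trajectory boundary
assert start_index.item() == len_storage - 1
assert priorities[start_index].item() == 0.0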


vmoens commented 3 months ago

This occurs because, although the sample at index = len(storage) - 1 is covered by preceding_stop_idx, it can still be sampled as a start index, leading to cross-trajectory sampling.

Is this also the case with the patch in #2228? How can the start be equal to len(storage)-1 if this is masked?

wertyuilife2 commented 3 months ago

It's a new issue caused by _sum_tree; I think it can't be solved with #2228.

In PrioritizedSampler.sample() we have:

def sample(...):
    ...
    index = self._sum_tree.scan_lower_bound(mass)
    ...
    index.clamp_max_(len(storage) - 1)
    weight = torch.as_tensor(self._sum_tree[index])
    ...
    return index, {"_weight": weight}

Because of some unknown bug in _sum_tree, we occasionally get an out-of-bounds index; it is then clamped to len(storage) - 1, which is why we end up with a wrong start index.
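
As an illustration only (this is a hypothetical workaround sketch, not a proposed patch, and the helper name is made up), one could re-draw the offending mass values instead of clamping, so that an out-of-bounds scan result never lands on a valid index:

import numpy as np
import torch

# hypothetical helper, not part of torchrl: re-draw the mass values whose
# scanned index is out of bounds instead of clamping them onto len(storage) - 1
def scan_within_bounds(sum_tree, p_sum, num_samples, len_storage, max_retries=10):
    mass = np.random.uniform(0.0, p_sum, size=num_samples)
    index = torch.as_tensor(sum_tree.scan_lower_bound(mass), dtype=torch.long)
    for _ in range(max_retries):
        oob = index >= len_storage
        if not oob.any():
            break
        # re-draw and re-scan only the offending samples
        new_mass = np.random.uniform(0.0, p_sum, size=int(oob.sum()))
        index[oob] = torch.as_tensor(sum_tree.scan_lower_bound(new_mass), dtype=torch.long)
    # keep the existing safety clamp in case retries are exhausted
    return index.clamp_max_(len_storage - 1)

This only works around the symptom on the sampler side; the underlying question of why the scan overshoots remains.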

vmoens commented 3 months ago

cc @xiaomengy, any idea why we get out-of-bounds samples from the PRB?