pytorch/rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] Unintended Cross-Trajectory Sampling in PrioritizedSliceSampler.sample() #2208

Closed: wertyuilife2 closed this issue 1 month ago

wertyuilife2 commented 1 month ago

Describe the bug

This issue comes from the original issue #2205.

As per the code comments, the preceding_stop_idx variable in PrioritizedSliceSampler.sample() is supposed to build a list of indices that we do not want to sample: all the steps within a seq_length distance from the end of a trajectory, with the end of the trajectory (stop_idx) included. However, it does not do this correctly.
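
For intuition, here is a minimal, self-contained sketch of what such an exclusion set could look like for a flat trajectory-id tensor. This is not the torchrl implementation; invalid_slice_starts, traj_ids and seq_length are illustrative names, with seq_length corresponding to batch_size // num_slices.

import torch

def invalid_slice_starts(traj_ids: torch.Tensor, seq_length: int) -> torch.Tensor:
    """Sketch: start indices from which a slice of seq_length steps
    would run past the end (stop_idx) of its trajectory."""
    # stop indices: the last step of each trajectory segment, plus the last
    # step currently stored in the buffer
    change = traj_ids[1:] != traj_ids[:-1]
    stop_idx = torch.cat(
        [change.nonzero().squeeze(-1), torch.tensor([traj_ids.numel() - 1])]
    )
    # a start i is invalid when i + seq_length - 1 > stop,
    # i.e. i is in {stop - seq_length + 2, ..., stop}
    bad = [torch.arange(max(s - seq_length + 2, 0), s + 1) for s in stop_idx.tolist()]
    return torch.cat(bad).unique()

# trajectory used in the repro below, with seq_length = 6 // 2 = 3
print(invalid_slice_starts(torch.tensor([3, 3, 0, 1, 1, 1, 2, 2, 2, 3]), 3))
# tensor([0, 1, 2, 4, 5, 7, 8, 9])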

To Reproduce

The following code demonstrates this issue:

import torch
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
from tensordict import TensorDict

def test_sampler():
    torch.manual_seed(0)

    sampler = PrioritizedSliceSampler(
        max_capacity=20,
        num_slices=2,
        traj_key="trajectory",
        strict_length=True,
        alpha=1.0,
        beta=1.0,
    )
    trajectory = torch.tensor([3, 3, 0, 1, 1, 1, 2, 2, 2, 3])
    td = TensorDict({"trajectory": trajectory, "steps": torch.arange(10)}, [10])
    rb = ReplayBuffer(
        sampler=sampler,
        storage=LazyTensorStorage(20, device=torch.device("cuda")),
        batch_size=6,
    )

    rb.extend(td)
    for i in range(10):
        # preceding_stop_idx in sample(): [5, 4, 8, 7], which should be [5, 4, 8, 7, 9, 0, 1, 2] or
        # [5, 4, 8, 7, 0, 1, 2], depending on whether you want to ignore the spanning trajectories.
        traj = rb.sample()["trajectory"]
        print("[loop {}]sampled trajectory: {}".format(i, traj))

test_sampler()

This causes PrioritizedSliceSampler.sample() to return slices that cross trajectory boundaries, which is not the expected behavior; SliceSampler handles this correctly.
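
One way to make the failure explicit is to check that every sampled slice contains a single trajectory id. This is a hedged sketch continuing the repro above; it assumes that, with strict_length=True, the batch is the concatenation of num_slices contiguous slices of equal length:

# each sampled batch of 6 steps should consist of 2 contiguous slices of
# 3 steps, each drawn from a single trajectory
sample = rb.sample()
slices = sample["trajectory"].reshape(2, -1)  # one row per slice
for s in slices:
    # fails on the buggy version whenever a slice crosses a trajectory boundary
    assert s.unique().numel() == 1, f"slice mixes trajectories: {s.tolist()}"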


vmoens commented 1 month ago

Should be solved by #2202

vmoens commented 1 month ago

Works under #2202, both with strict_length=True and False

wertyuilife2 commented 1 month ago

@vmoens this issue has not been resolved, because I made a significant mistake in the test code (my bad)!

In the original test code, sampler.max_capacity and the storage size should be 10 instead of 20, so that the buffer reaches full capacity. Therefore, the correct test code is as follows:

import torch
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import SliceSampler, PrioritizedSliceSampler
from tensordict import TensorDict

def test_sampler():
    torch.manual_seed(0)

    sampler = PrioritizedSliceSampler(
        max_capacity=10,
        num_slices=2,
        traj_key="trajectory",
        strict_length=True,
        alpha=1.0,
        beta=1.0,
    )
    # sampler = SliceSampler(
    #     num_slices=2,
    #     traj_key='trajectory',
    #     strict_length=True,
    #     span=True
    # )
    trajectory = torch.tensor([3, 0, 1, 1, 1, 2, 2, 2, 3, 3])
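    # note: trajectory 3 intentionally occupies indices 0, 8 and 9, i.e. it
    # touches both the beginning and the end of the full storage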
    td = TensorDict({"trajectory": trajectory, "steps": torch.arange(10)}, [10])

    rb = ReplayBuffer(
        sampler=sampler,
        storage=LazyTensorStorage(10, device=torch.device("cuda")),
        batch_size=6,
    )

    rb.extend(td)
    for i in range(10):
        # preceding_stop_idx in sample(): [1 2 3 5 6 8 9]
        traj = rb.sample()["trajectory"]
        print("[loop {}]sampled trajectory: {}".format(i, traj))

test_sampler()

With the new test code, two issues arise (tested on torchrl-nightly==2024.6.13):

  1. PrioritizedSliceSampler still samples across trajectories.
  2. SliceSampler with span=True throws an error.
vmoens commented 1 month ago

#2228 will fix this, but be mindful that I don't think we can really sample traj 3 in this example.

When we populate the buffer, the cursor starts at 0. When we're done writing, the cursor should be at 1 (that's the end of traj 3, which overlaps the end and the beginning of the storage). So the cursor is not where intended, and we tell the sampler that the cursor points to a place where the signal is truncated. For the buffer, the first [3] is therefore not part of the same trajectory as the ones at the end of the buffer. If you want to test this properly, you can do this:

import torch
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import SliceSampler, PrioritizedSliceSampler
from tensordict import TensorDict

def test_sampler():
    torch.manual_seed(0)

    sampler = PrioritizedSliceSampler(
        max_capacity=10,
        num_slices=2,
        traj_key="trajectory",
        # end_key="done",
        strict_length=True,
        alpha=1.0,
        beta=1.0,
    )
    trajectory0 = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
    done0 = torch.tensor([False, True, False, False, True, False, False, True, False, False])
    td0 = TensorDict({"trajectory": trajectory0, "steps": torch.arange(10), "done": done0}, [10])
    trajectory1 = torch.tensor([3])
    done1 = torch.tensor([False])
    td1 = TensorDict({"trajectory": trajectory1, "steps": torch.tensor([10]), "done": done1}, [1])

    rb = ReplayBuffer(
        sampler=sampler,
        storage=LazyTensorStorage(10, device=torch.device("cpu")),
        batch_size=6,
    )

    rb.extend(td0)
    rb.extend(td1)
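    # td1 overwrites index 0, leaving the write cursor at index 1: trajectory 3
    # now spans storage indices 8 -> 9 -> 0 and ends right before the cursor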
    for i in range(10):
        # preceding_stop_idx in sample(): [1 2 3 5 6 8 9]
        s, info = rb.sample(return_info=True)
        traj = s["trajectory"]
        print("[loop {}] sampled trajectory: {}".format(i, traj))
        print("[loop {}] index {}".format(i, info["index"]))
        assert len(traj.unique())<=2

test_sampler()
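
For reference, a minimal sketch of the cursor arithmetic described above, assuming the writer advances a circular cursor modulo the storage size (the numbers refer to the two examples in this thread):

storage_size = 10

# original repro: exactly 10 steps are written, the cursor wraps back to 0,
# so index 9 is marked as truncated and the leading [3] at index 0 is treated
# as a separate fragment of trajectory 3
print(10 % storage_size)  # 0

# test above: 10 + 1 steps are written, the cursor lands on index 1, so the
# truncation sits at index 0 and trajectory 3 runs contiguously over
# indices 8 -> 9 -> 0
print((10 + 1) % storage_size)  # 1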
wertyuilife2 commented 1 month ago

Oh I'm okay with that. I think missing a single trajectory has little impact on most RL algorithms.