pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] Device Error in PrioritizedSliceSampler.sample() #2207

Closed wertyuilife2 closed 1 month ago

wertyuilife2 commented 1 month ago

Describe the bug

This issue was split out from the original issue #2205.

In commit c2e1c05, at lines 1767 and 1213 of samplers.py, index and stop_idx may not be on the same device: stop_idx can end up on the GPU. The comparison on these lines should move stop_idx to the device of index before comparing:

index[:, 0].unsqueeze(0) == stop_idx[:, 0].unsqueeze(1).to(index.device)
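
The proposed comparison can be tried in isolation, outside of torchrl. The sketch below is only an illustration: the helper name match_first_column and the toy tensors are hypothetical and are not taken from samplers.py. It assumes index lives on the CPU and places stop_idx on the GPU when one is available, falling back to the CPU otherwise so the snippet still runs.

```python
import torch


def match_first_column(index: torch.Tensor, stop_idx: torch.Tensor) -> torch.Tensor:
    """Compare the first column of `index` with the first column of `stop_idx`.

    Hypothetical helper for illustration only. Returns a boolean mask of
    shape (len(stop_idx), len(index)) located on `index.device`.
    """
    # Move stop_idx to the device of index before comparing; an elementwise
    # comparison between tensors on different devices raises a RuntimeError.
    stop_col = stop_idx[:, 0].unsqueeze(1).to(index.device)
    return index[:, 0].unsqueeze(0) == stop_col


# Toy data (assumption): index stays on the CPU, stop_idx goes to the GPU if present.
storage_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
index = torch.tensor([[0, 0], [2, 0], [5, 0]])
stop_idx = torch.tensor([[2, 0], [5, 0]], device=storage_device)

mask = match_first_column(index, stop_idx)
print(mask)
```

Without the .to(index.device) call, the same comparison would fail whenever the two tensors sit on different devices, which matches the error described in this report.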

To Reproduce

The code below raises a device error when run on a machine with a CUDA device.

import torch
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
from tensordict import TensorDict

def test_sampler():
    torch.manual_seed(0)

    sampler = PrioritizedSliceSampler(
        max_capacity=20,
        num_slices=2,
        traj_key="trajectory",
        strict_length=True,
        alpha=1.0,
        beta=1.0,
    )
    trajectory = torch.tensor([3, 3, 0, 1, 1, 1, 2, 2, 2, 3])
    td = TensorDict({"trajectory": trajectory, "steps": torch.arange(10)}, [10])
    rb = ReplayBuffer(
        sampler=sampler,
        # CUDA storage: this is the setup in which the device error occurs.
        storage=LazyTensorStorage(20, device=torch.device("cuda")),
        batch_size=6,
    )

    rb.extend(td)
    for i in range(10):
        traj = rb.sample()["trajectory"]
        print("[loop {}] sampled trajectory: {}".format(i, traj))


test_sampler()

Checklist