Status: Closed (wertyuilife2 closed this issue 1 month ago)
Describe the bug

This issue follows up on the original issue #2205.
In commit c2e1c05, at lines 1213 and 1767 of samplers.py, `index` and `stop_idx` may be on different devices, with `stop_idx` potentially on the GPU. Both lines should be changed to:
```python
index[:, 0].unsqueeze(0) == stop_idx[:, 0].unsqueeze(1).to(index.device)
```
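To show what the corrected comparison computes, here is a minimal standalone sketch. It assumes 2-D integer tensors whose column 0 holds the values being matched; the helper name `match_stop_indices` and the toy values are hypothetical and not taken from samplers.py. Broadcasting a `[1, N]` view against a `[M, 1]` view yields an `[M, N]` boolean mask, and `.to(index.device)` moves `stop_idx` onto the device of `index` first, so the comparison no longer mixes CPU and CUDA tensors.

```python
import torch


def match_stop_indices(index: torch.Tensor, stop_idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring the proposed fix.

    Returns an [M, N] boolean mask where entry (i, j) is True when
    stop_idx[i, 0] == index[j, 0]. stop_idx is moved to index's device
    before the comparison, so the result lives on index.device.
    """
    return index[:, 0].unsqueeze(0) == stop_idx[:, 0].unsqueeze(1).to(index.device)


# Toy data (made up for illustration): index on the CPU.
index = torch.tensor([[2, 0], [5, 0], [9, 0]])
stop_idx = torch.tensor([[5, 0], [9, 0]])

mask = match_stop_indices(index, stop_idx)
print(mask.tolist())  # [[False, True, False], [False, False, True]]

if torch.cuda.is_available():
    # stop_idx on the GPU, index on the CPU: without .to(index.device)
    # this mixed-device comparison would raise a RuntimeError.
    mask_mixed = match_stop_indices(index, stop_idx.cuda())
    assert torch.equal(mask_mixed, mask)
```

The `.to(...)` call is applied to the smaller `stop_idx` slice rather than to `index`, so the result stays on whatever device `index` already uses.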
To Reproduce

The code below raises a device-mismatch error.
```python
import torch
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
from tensordict import TensorDict


def test_sampler():
    torch.manual_seed(0)
    sampler = PrioritizedSliceSampler(
        max_capacity=20,
        num_slices=2,
        traj_key="trajectory",
        strict_length=True,
        alpha=1.0,
        beta=1.0,
    )
    trajectory = torch.tensor([3, 3, 0, 1, 1, 1, 2, 2, 2, 3])
    td = TensorDict({"trajectory": trajectory, "steps": torch.arange(10)}, [10])
    rb = ReplayBuffer(
        sampler=sampler,
        # Storage on CUDA: this setup triggers the device error (requires a GPU).
        storage=LazyTensorStorage(20, device=torch.device("cuda")),
        batch_size=6,
    )
    rb.extend(td)
    for i in range(10):
        traj = rb.sample()["trajectory"]
        print("[loop {}]sampled trajectory: {}".format(i, traj))


test_sampler()
```