pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] Segmentation Fault in PrioritizedSliceSampler.sample() #2206

Closed. wertyuilife2 closed this issue 1 month ago

wertyuilife2 commented 1 month ago

Describe the bug

This issue comes from the original issue #2205.

In PrioritizedSliceSampler.sample(), preceding_stop_idx needs to be moved to the CPU before executing self._sum_tree[preceding_stop_idx] = 0.0. If preceding_stop_idx is on the GPU, the program crashes with a segmentation fault.
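
For context, a minimal sketch of the kind of fix this implies, paraphrasing the relevant lines of sample() (the surrounding code here is an assumption, only the device move is the point):

# inside PrioritizedSliceSampler.sample() (paraphrased, not the exact library code)
# _sum_tree is backed by a C++ segment tree that expects CPU integer indices,
# so indexing it with a CUDA tensor corrupts memory and segfaults
if preceding_stop_idx.device.type != "cpu":
    preceding_stop_idx = preceding_stop_idx.cpu()
self._sum_tree[preceding_stop_idx] = 0.0  # zero out priorities at trajectory boundaries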

To Reproduce

The code below causes a segmentation fault.

import torch
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
from tensordict import TensorDict

def test_sampler():
    torch.manual_seed(0)

    sampler = PrioritizedSliceSampler(
        max_capacity=20,
        num_slices=2,
        traj_key="trajectory",
        strict_length=True,
        alpha=1.0,
        beta=1.0,
    )
    trajectory = torch.tensor([3, 3, 0, 1, 1, 1, 2, 2, 2, 3])
    td = TensorDict({"trajectory": trajectory, "steps": torch.arange(10)}, [10])
    rb = ReplayBuffer(
        sampler=sampler,
        # the CUDA storage is what triggers the bug: the sampler's stop indices end up on GPU
        storage=LazyTensorStorage(20, device=torch.device("cuda")),
        batch_size=6,
    )

    rb.extend(td)
    for i in range(10):
        # repeated sampling hits the sum-tree write with GPU indices and segfaults
        traj = rb.sample()["trajectory"]
        print("[loop {}] sampled trajectory: {}".format(i, traj))

test_sampler()
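
As a side note (an inference from the report, not something it states), keeping the storage on CPU avoids the crash in this reproduction, since the sampler then never produces GPU indices:

# workaround sketch: with a CPU storage the stop indices stay on CPU,
# so the sum-tree write never receives a CUDA tensor
rb = ReplayBuffer(
    sampler=sampler,
    storage=LazyTensorStorage(20, device=torch.device("cpu")),
    batch_size=6,
)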


vmoens commented 1 month ago

Should be solved by #2202

vmoens commented 1 month ago

Not solved yet - bear with me