pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl

[Feature Request] A Method to Modify ReplayBuffer In Place #2209

Closed by wertyuilife2 1 month ago

wertyuilife2 commented 1 month ago

Motivation

This issue comes from the original issue #2205.

My work requires modifying the contents of the buffer: I need to sample an item, modify it, and put it back in the buffer. However, torchrl currently does not seem to encourage modifying buffer contents. Calling buffer._storage.set(index, data) to put the modified data back implicitly changes _storage._len, which can later cause the sampler to return uninitialized (empty) entries. The following code demonstrates the issue:

import torch
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import SliceSampler
from tensordict import TensorDict

def test_sampler():
    torch.manual_seed(0)

    sampler = SliceSampler(
        num_slices=2,
        traj_key="trajectory",
        strict_length=True,
    )
    trajectory = torch.tensor([4, 4, 1, 2, 2, 2, 3, 3, 3, 4])
    td = TensorDict({"trajectory": trajectory, "steps": torch.arange(10)}, [10])
    rb = ReplayBuffer(
        sampler=sampler,
        storage=LazyTensorStorage(20),
        batch_size=6,
    )
    rb.extend(td)

    for i in range(10):
        data, info = rb.sample(return_info=True)
        print(f"[loop {i}] sampled trajectory: {data['trajectory']}")

        # I want to modify data and put it back:
        # data = modify_func(data)
        rb._storage.set(info["index"], data)

        # len(storage) grows on every rb._storage.set() call, causing
        # sampling of undefined data (trajectory 0) in future loops.
        print(f"[loop {i}] len(storage): {len(rb._storage)}")

test_sampler()

I resolved this by directly modifying buffer._storage._storage while holding buffer._replay_lock. It took me two days to discover that TensorStorage.set() implicitly changes _len; I believe this method should behave more intuitively. I am not sure whether other Storage classes have similar issues, but TensorStorage definitely does.
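For reference, a minimal sketch of that workaround, assuming the LazyTensorStorage setup from the repro above (modify_func is a placeholder for the actual modification):

# Write the modified sample straight into the storage TensorDict under
# the replay lock, bypassing TensorStorage.set() so _len is untouched.
# modify_func is a placeholder, not a torchrl function.
data, info = rb.sample(return_info=True)
data = modify_func(data)
with rb._replay_lock:
    rb._storage._storage[info["index"]] = data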

Solution

Provide a method that can modify ReplayBuffer in place, like ReplayBuffer.set(index, data).
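A hypothetical sketch of what such a method could look like (not existing torchrl API; it assumes a TensorStorage-backed buffer and simply guards an in-place write with the replay lock):

# Hypothetical ReplayBuffer.set(index, data), for illustration only:
# an in-place overwrite that deliberately bypasses Storage.set(), so
# _len (and hence what the sampler considers valid data) is unchanged.
def set(self, index, data):
    with self._replay_lock:
        self._storage._storage[index] = data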

Additional context

See discussion in the original issue #2205.


albertbou92 commented 1 month ago

It does not solve much, but rb._storage._storage is the same as rb[:], which is a bit more intuitive and clean (maybe you could call update() on rb[:][info["index"]]).
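For illustration, a sketch of that equivalence under the repro's setup (data and info reuse the names from the example above):

# rb[:] returns the stored TensorDict, backed by the same tensors as
# rb._storage._storage, so an indexed assignment writes back in place.
# (Indexing first with info["index"] and then calling update() may act
# on a copy, since fancy indexing copies the underlying tensors.)
view = rb[:]
view[info["index"]] = data  # in-place write; _len is unchanged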

wertyuilife2 commented 1 month ago

@albertbou92 yes, I use rb._storage._storage['key'][index_list] in practice.
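For concreteness, that pattern looks like this ("steps" is a key from the repro above; index_list and new_vals are placeholders):

# Per-key in-place write into the storage TensorDict. index_list and
# new_vals stand for the sampled indices and the modified values.
rb._storage._storage["steps"][index_list] = new_vals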

wertyuilife2 commented 1 month ago

@albertbou92, oh, I misread you. I agree that rb[:] is the cleaner way to put it. But I don't think it's something anyone will naturally figure out (and it should be used while holding buffer._replay_lock). Also, TensorStorage.set() looks like a usable method, but it doesn't actually behave as expected, which may confuse others who use it.

albertbou92 commented 1 month ago

right, makes sense!