pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] Multiple Issues in Samplers and Buffers Affecting Stability and Expected Behavior #2205

Open wertyuilife2 opened 5 months ago

wertyuilife2 commented 5 months ago

I have been using torchrl for my work recently and have encountered several bugs or unexpected behaviors. I noticed that @vmoens has been addressing some fixes, but to ensure nothing is overlooked, I am listing the issues I encountered here. These issues have not been resolved in torchrl-nightly==2024.6.3.

Given that this is a collection of issues, this post might be a bit lengthy. Please let me know if you would prefer me to split it into multiple issues!

Issue 1: In PrioritizedSliceSampler.sample(), preceding_stop_idx needs to be moved to the CPU before executing self._sum_tree[preceding_stop_idx] = 0.0. If preceding_stop_idx is on the GPU, the program hits a segmentation fault. This can be reproduced with the example code under the third issue.
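A minimal sketch of the fix inside PrioritizedSliceSampler.sample(), assuming preceding_stop_idx may live on a CUDA device (the segment tree is indexed on the CPU):

# move the index tensor to the CPU before writing into the sum tree
preceding_stop_idx = preceding_stop_idx.to(torch.device("cpu"))
self._sum_tree[preceding_stop_idx] = 0.0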

Issue 2: In commit c2e1c05, at samplers.py lines 1767 and 1213, index and stop_idx might not be on the same device, with stop_idx potentially being on the GPU. These lines should be modified as follows:

index[:, 0].unsqueeze(0) == stop_idx[:, 0].unsqueeze(1).to(index.device)

Issue 3: As per the comments, the preceding_stop_idx variable in PrioritizedSliceSampler.sample() is meant to build a list of indices that we don't want to sample as slice starts: all the steps within a seq_length distance from the end of a trajectory, with the end of the trajectory (stop_idx) included. However, it does not do this correctly. The following code demonstrates the issue:

import torch
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
from tensordict import TensorDict

def test_sampler():
    torch.manual_seed(0)

    sampler = PrioritizedSliceSampler(
        max_capacity=20,
        num_slices=2,
        traj_key="trajectory",
        strict_length=True,
        alpha=1.0,
        beta=1.0,
    )
    trajectory = torch.tensor([3, 3, 0, 1, 1, 1, 2, 2, 2, 3])
    td = TensorDict({"trajectory": trajectory, "steps": torch.arange(10)}, [10])
    rb = ReplayBuffer(
        sampler=sampler,
        storage=LazyTensorStorage(20, device=torch.device("cuda")),  # GPU storage; this is also what makes issue 1 reproducible here
        batch_size=6,
    )

    rb.extend(td)
    for i in range(10):
        # preceding_stop_idx in sample(): [5, 4, 8, 7], which should be [5, 4, 8, 7, 9, 0, 1, 2] or
        # [5, 4, 8, 7, 0, 1, 2], depending on whether you want to ignore the spanning trajectories.
        traj = rb.sample()["trajectory"]
        print("[loop {}]sampled trajectory: {}".format(i, traj))

test_sampler()

This causes PrioritizedSliceSampler.sample() to sample slices that span across trajectories, which is not the expected behavior; SliceSampler handles this correctly.
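For illustration, here is a rough sketch (not the library's implementation) of one way the exclusion set could be built, assuming stop_idx holds the index of the last step of each stored trajectory. On the example above it reproduces the first set given in the comment, i.e. it keeps index 9 (the spanning trajectory) as a stop:

import torch

def build_preceding_stop_idx(stop_idx: torch.Tensor, seq_length: int) -> torch.Tensor:
    # invalid slice starts are stop_idx itself and the seq_length - 2 steps before it:
    # a slice of length seq_length starting there would cross a trajectory boundary
    offsets = torch.arange(seq_length - 1, device=stop_idx.device)
    starts = stop_idx.unsqueeze(1) - offsets.unsqueeze(0)
    starts = starts[starts >= 0]  # drop indices that fall before the start of the buffer
    return starts.unique()

# stop_idx = [1, 2, 5, 8, 9] with seq_length = 3 gives {0, 1, 2, 4, 5, 7, 8, 9}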

Issue 4: My work requires modifying the contents of the buffer; specifically, I need to sample an item, modify it, and put it back in the buffer. However, torchrl currently does not seem to encourage modifying buffer contents. When calling buffer._storage.set(index, data) to put my modified data back into the buffer, it implicitly changes _storage._len, which can cause the sampler to sample empty (undefined) entries. The following code demonstrates this issue:

import torch
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import SliceSampler
from tensordict import TensorDict

def test_sampler():
    torch.manual_seed(0)

    sampler = SliceSampler(
        num_slices=2,
        traj_key="trajectory",
        strict_length=True,
    )
    trajectory = torch.tensor([4, 4, 1, 2, 2, 2, 3, 3, 3, 4])
    td = TensorDict({"trajectory": trajectory, "steps": torch.arange(10)}, [10])
    rb = ReplayBuffer(
        sampler=sampler,
        storage=LazyTensorStorage(20),
        batch_size=6,
    )
    rb.extend(td)

    for i in range(10):
        data, info = rb.sample(return_info=True)
        print("[loop {}]sampled trajectory: {}".format(i, data["trajectory"]))

        # I want to modify data and put it back
        # data = modify_func(data)
        rb._storage.set(info["index"], data)

        # The len(storage) increases due to rb._storage.set(),
        # causing sampling of undefined data (trajectory 0) in future loops.
        print("[loop {}]len(storage): {}".format(i, len(rb._storage)))

test_sampler()

I resolved this by directly modifying buffer._storage._storage while holding buffer._replay_lock. It took me two days to discover that TensorStorage.set() implicitly changes _len, and I believe this method should behave more intuitively. I am not sure whether other Storage classes have similar issues, but TensorStorage definitely does.
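For reference, a minimal sketch of that workaround, assuming a LazyTensorStorage whose underlying _storage is a TensorDict supporting indexed assignment:

with rb._replay_lock:
    # write directly into the underlying TensorDict so that _len is left untouched
    rb._storage._storage[info["index"]] = data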

Issue 5: [Current Implementation and Issues] The current implementation maintains _max_priority, which represents the maximum priority over all samples seen historically, not just those currently in the buffer. Early in RL training, outliers can keep _max_priority high, making it unrepresentative. Additionally, _max_priority is initialized to 1, while most RL algorithms use the Bellman error as the priority, which is often much smaller (close to 0); consequently, _max_priority may never be updated. New samples are thus given a default priority of 1, which means their importance-sampling (PER) weight is close to 0: they are sampled almost immediately but contribute little to the weighted loss, reducing sample efficiency.

[Proposed Solution] Maintain a _neg_min_tree = MinSegmentTree() that stores negated priorities, so it tracks the maximum priority in the current buffer, and add self._upper_priority = 1. With these, part of the PrioritizedSampler methods can be updated as follows:

def default_priority(self, storage):
    max_priority = min(-self._neg_min_tree.query(0, len(storage)), self._upper_priority)
    if max_priority == 0:
        return self._upper_priority**self._alpha
    return (max_priority + self._eps) ** self._alpha

def mark_update(self, index, storage):
    self._neg_min_tree[index] = 0  # update negmintree before querying it in default_priority
    self.update_priority(index, self.default_priority(storage))

@torch.no_grad()
def _add_or_extend(self, index, storage):
    self._neg_min_tree[index] = 0  # update negmintree before querying it in default_priority
    priority = self.default_priority(storage)

    # ... existing code omitted ...

    self._sum_tree[index] = priority
    self._min_tree[index] = priority

@torch.no_grad()
def update_priority(self, index, priority):

    priority = torch.as_tensor(priority, device=torch.device("cpu")).detach()
    index = torch.as_tensor(index, dtype=torch.long, device=torch.device("cpu"))

    # ... existing code omitted ...

    priority = torch.pow(priority + self._eps, self._alpha).clamp_max(
        self._upper_priority
    )
    self._sum_tree[index] = priority
    self._min_tree[index] = priority
    self._neg_min_tree[index] = -priority

This change implies that the default_priority function will need to take storage as an additional parameter, which eventually affects several methods such as Sampler(ABC).extend(), Sampler(ABC).add(), and Sampler(ABC).mark_update(), but I believe this is reasonable, akin to how Sampler.sample() already takes storage as a parameter.

Issue 6: When ReplayBuffer._add() is called, the following sequence occurs:
(1) _writer.add() -> _storage.__setitem__() -> buffer.mark_update() -> _sampler.mark_update() -> _sampler.update_priority()
(2) _sampler.add() -> _sampler._add_or_extend()

Both _sampler._add_or_extend() and _sampler.update_priority() update the priority, with update_priority() even applying additional transformations (e.g., torch.pow(priority + self._eps, self._alpha)). This behavior is also present in ReplayBuffer._extend().

This behavior is not reasonable. I believe the mark_update mechanism is somewhat redundant. We do not need to ensure that _attached_entities are updated when changing storage content. Any additional updates required after directly modifying _storage should be the responsibility of the user. mark_update can lead to redundant calls and even cause conflicts.

Issue 7: Although I haven't confirmed or tested this, there may be a threading risk. The ReplayBuffer starts a separate thread for prefetching, which calls PrioritizedSliceSampler.sample().

In PrioritizedSliceSampler.sample(), we have:

# force to not sample index at the end of a trajectory
self._sum_tree[preceding_stop_idx] = 0.0
# and no need to update self._min_tree
starts, info = PrioritizedSampler.sample(self, storage=storage, batch_size=batch_size // seq_length)

If the main thread calls update_priority() between these two lines, it may overwrite entries of _sum_tree, so the effect of self._sum_tree[preceding_stop_idx] = 0.0 is lost and the end of a trajectory can be sampled as a slice start.

I'm uncertain of the role of buffer._futures_lock, but it doesn't seem to prevent this conflict.
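As a sketch of one possible mitigation (this is not the library's current behavior, and the lock name below is an assumption), the zeroing of the sum tree and the subsequent sampling could be guarded by the same lock that update_priority() acquires:

# hypothetical: assumes update_priority() acquires self._priority_lock as well
with self._priority_lock:
    # force to not sample index at the end of a trajectory
    self._sum_tree[preceding_stop_idx] = 0.0
    starts, info = PrioritizedSampler.sample(
        self, storage=storage, batch_size=batch_size // seq_length
    )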


Overall, I hope this comprehensive overview helps in addressing these issues. Please let me know if you need further details or if I should break this down into separate issues!


vmoens commented 5 months ago

Hi @wertyuilife2, thanks for reporting these and for the detailed explanations and possible fixes!

I'm currently fixing PrioritizedSliceSampler, which, as you may have realized, has some issues at the moment.

I think I can integrate 1-3 in https://github.com/pytorch/rl/pull/2202

For 4, this seems to be an additional feature, not a bug (as you say, TorchRL doesn't currently encourage modifying a storage in-place). Not saying we don't want to support that but we'd have to carefully look into it!

For 5, I'd be open to changing the logic so that _max_priority starts at 0 and is updated to the highest priority when the first samples come in. The issue is that when we put the first batch in and don't call update_priority, there will be a segfault because the sum_tree is full of 0s. I believe this is the only reason we use a non-null default priority. Typically, the usage is

buffer.extend(data)
s = buffer.sample()
loss_dict = loss(s)
buffer.update_priority(s["index"], s["priority"]) # or from the loss_dict

but if the initial priority is 0, the initial sample will fail. Your _upper_priority trick could solve that.

I agree with 6 too; the priority should only be updated once.

For 7, we can group that in the same stack as #2185, which also points at some thread-safety issues. For these to be patched we'd need a good, minimal reproducible example!

wertyuilife2 commented 5 months ago

Thanks for the quick response! I will test 1-3 for sure.

For 5, I completely understand your point. However, my key issue is that _max_priority tracks the highest priority ever recorded in history, whereas PER would use the highest priority in the current buffer. For example, let's focus on the priorities of the first three samples in the buffer during an RL training run:

train loop 1:
index: [0,1,2,...]
priority: [0.9, 0.7, 0.8,...] 
_max_priority: 0.9
...
train loop 100:
index: [0,1,2,...]
priority: [0.01, 0.05, 0.02,...] 
_max_priority: 0.9

In the early stages of training, priorities are generally high, but as the network updates, they gradually decrease. Because _max_priority tracks the historical maximum, new samples in the later stages of training are assigned excessively high default priorities. Since a sample's PER weight is calculated as min_priority / priority, this results in a weight close to 0. Consequently, each new sample's weight in its first update is tiny, so it effectively does not participate in the update.
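A worked example with the numbers above (assuming beta = 1.0 and weight = min_priority / priority):

# at train loop 100 the smallest priority in the buffer is about 0.01
weight_with_historical_max = 0.01 / 0.9    # ~0.011: the new sample barely affects the loss
weight_with_buffer_max = 0.01 / 0.05       # 0.2: default taken from the current buffer's maximum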

For 7, I will try to test that and will provide feedback if I find anything.

vmoens commented 5 months ago

For 5, I completely understand your point. However, my key issue is that _max_priority tracks the highest priority ever recorded in history, whereas PER would use highest priority in the current buffer.

Oh yeah, I can see that. Then we should either record the index with the priority and erase it whenever that index is overwritten, or compute it on the fly every time (which would be quite expensive!).

wertyuilife2 commented 5 months ago

In fact, _min_tree is used to store the minimum priority in the current buffer, which is then used in sample() to calculate the PER weight. Therefore, we can simply create a _neg_min_tree that stores negated priorities to track the maximum priority in the current buffer in the same manner.

vmoens commented 5 months ago

Can I ask you to open one issue for each of these? It will be easier for me to track them!

wertyuilife2 commented 5 months ago

Of course, I will do it later!