pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] SliceSampler breaks when at capacity #1969

Closed: nicklashansen closed this issue 5 months ago

nicklashansen commented 5 months ago

Hi @vmoens, I ran into this issue with SliceSampler, which does not appear to be intended behavior!

Describe the bug

The current implementation of SliceSampler will raise the exception

RuntimeError: Some stored trajectories have a length shorter than the slice that was asked for. Create the sampler with `strict_length=False` to allow shorter trajectories to appear in you batch.

when an episode longer than slice_len is added to the replay buffer while it is close to capacity. It appears that episodes are "wrapped around" to the beginning of the replay buffer, but the sampler does not account for this and thus raises an exception.

This issue affects all use cases in which the replay buffer capacity is not a multiple of the episode length (or in which episodes have varying lengths).
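To make the wrap-around concrete, here is a minimal pure-Python sketch that mimics a circular storage, using the sizes from the reproduction below. It assumes that a full storage keeps writing at a cursor taken modulo the capacity, and that the sampler sees each contiguous run of equal trajectory ids as one trajectory. The helpers `write_episode` and `segments` are hypothetical illustrations, not torchrl API.

```python
def write_episode(storage, cursor, ep_idx, ep_len):
    """Write ep_len copies of ep_idx starting at cursor, wrapping at capacity.

    Returns the new cursor (assumed circular-buffer behaviour)."""
    capacity = len(storage)
    for t in range(ep_len):
        storage[(cursor + t) % capacity] = ep_idx
    return (cursor + ep_len) % capacity


def segments(storage):
    """Return (traj_id, start, length) for each contiguous run, scanning left to right."""
    runs = []
    start = 0
    for i in range(1, len(storage) + 1):
        if i == len(storage) or storage[i] != storage[start]:
            runs.append((storage[start], start, i - start))
            start = i
    return runs


# Same sizes as the reproduction: capacity 1000, episodes of length 999.
capacity, ep_len, slice_len = 1000, 999, 4
storage = [-1] * capacity  # -1 marks an empty slot
cursor = write_episode(storage, 0, ep_idx=0, ep_len=ep_len)
cursor = write_episode(storage, cursor, ep_idx=1, ep_len=ep_len)

runs = segments(storage)
print(runs)  # [(1, 0, 998), (0, 998, 1), (1, 999, 1)]
# Runs a strict-length sampler would consider too short:
print([r for r in runs if r[2] < slice_len])  # [(0, 998, 1), (1, 999, 1)]
```

Under these assumptions, the second episode is split into a 998-step run at the start and a 1-step run at the end, and only a single step of the first episode survives, so two runs are shorter than slice_len.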

To Reproduce

The error can be reproduced by running the following example code:

import torch
from tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import SliceSampler

def test_slicesampler():
    ep_len = 999
    capacity = 1000
    num_slices = 256
    slice_len = 4
    batch_size = num_slices * slice_len
    sampler = SliceSampler(
        num_slices=num_slices,
        end_key=None,
        traj_key='episode',
        truncated_key=None,
        strict_length=True,
    )
    storage = LazyTensorStorage(capacity, device=torch.device('cpu'))
    buffer = ReplayBuffer(
        storage=storage,
        sampler=sampler,
        batch_size=batch_size,
    )

    def create_episode(ep_len, ep_idx):
        return TensorDict(dict(
            episode=torch.ones(ep_len) * ep_idx,
            obs=torch.randn(ep_len, 6),
        ), batch_size=ep_len)

    buffer.extend(create_episode(ep_len, 0)) # works
    buffer.extend(create_episode(ep_len, 1)) # buffer wraps around

    buffer.sample()  # raises: after wrapping, some stored fragments are shorter than slice_len

if __name__ == '__main__':
    test_slicesampler()

Expected behavior

I would expect the sampler to consider the off chance that an episode may be split between end indices and start indices due to the replay buffer being at capacity.


vmoens commented 5 months ago

Interesting! I hadn't thought about that, but it makes sense: the last traj of the storage spans across the end and the beginning of it.

I guess the natural fit would be to tell SliceSampler that this is the same traj!

I can patch that tomorrow

vmoens commented 5 months ago

I gave this more thought and it's even trickier than I initially expected. Imagine we collect trajectories of length 3 in a buffer of length 10. We have

t0: [0, 0, 0, -1, -1, -1, -1, -1, -1, -1]

and after a while

t3: [3, 3, 0, 1, 1, 1, 2, 2, 2, 3]

If strict_length=True, this will break because (1) trajectory 3 spans the end and the beginning of the storage and (2) trajectory 0 has been partially overwritten.

(1) can be relatively easily solved, but (2) will persist.
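Point (1) and why (2) persists can be checked on the t3 layout with a short pure-Python sketch. It assumes trajectory ids are unique per episode, so when the storage is full and its first and last runs share an id, they can be merged into one trajectory. The helpers `runs` and `merge_wraparound` are hypothetical, not torchrl's implementation.

```python
def runs(storage):
    """Contiguous runs of equal ids as [traj_id, start, length], left to right."""
    out = []
    for i, x in enumerate(storage):
        if out and out[-1][0] == x:
            out[-1][2] += 1
        else:
            out.append([x, i, 1])
    return out


def merge_wraparound(storage):
    """Merge the last run into the first one if they belong to the same trajectory.

    Only done when the buffer is full (no -1 placeholder). The merged run keeps
    the start index of the last run, so its storage indices are
    (start + k) % capacity for k in range(length).
    """
    r = runs(storage)
    full = -1 not in storage
    if full and len(r) > 1 and r[0][0] == r[-1][0]:
        first, last = r[0], r.pop()
        r[0] = [first[0], last[1], first[2] + last[2]]
    return r


t3 = [3, 3, 0, 1, 1, 1, 2, 2, 2, 3]
merged = merge_wraparound(t3)
print(merged)  # [[3, 9, 3], [0, 2, 1], [1, 3, 3], [2, 6, 3]]

slice_len = 3
# Merging repairs trajectory 3, but trajectory 0 still has only one step left.
print([traj for traj, _, length in merged if length < slice_len])  # [0]
```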

So we should consider one of these options or a combination of them:

@nicklashansen what's the expected behaviour in your case in the example I gave? What would be the "natural" thing to do?

IMO (1) needs to be fixed for sure, with strict_length set to False, though I would then expect that strict_length will never be used...

ccing @Cadene since we had a couple of conversations on the topic.

nicklashansen commented 5 months ago

@vmoens: this is the current usage of SliceSampler in TD-MPC2 https://github.com/nicklashansen/tdmpc2/blob/57158282b46ebc5c329c5be9cfe2b0094126d1ca/tdmpc2/common/buffer.py#L17 which sometimes results in the above error when the replay buffer becomes full. I think the most natural behavior would be to just not sample trajectories shorter than the specified length if strict_length=True, but also account for the possibility of the last trajectory wrapping around the end:start indices when sampling. In practice the capacity of the replay buffer will almost always be much larger than the length of an individual episode, so I doubt that excluding a single trajectory from sampling will have any noticeable effect on training. For example, TD-MPC2 uses a default buffer capacity of 1,000,000 while episode lengths are in the 10-10,000 range for most of the existing RL environments that I have encountered.
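The behaviour suggested here (skip trajectories shorter than the slice length, but treat a wrapped trajectory as one piece) could look roughly like the pure-Python sketch below. It is an illustration under assumptions, not torchrl's implementation. Trajectories are given as hypothetical `(start, length)` pairs over a circular storage with wrapped pieces already joined. Slice indices are taken modulo the capacity so a slice may cross the end of the storage. For simplicity, trajectories are picked uniformly, not in proportion to their length.

```python
import random


def sample_slices(trajs, capacity, num_slices, slice_len, rng=random):
    """Sample num_slices windows of slice_len consecutive steps.

    trajs: list of (start, length) pairs; a trajectory covers storage indices
    (start + k) % capacity for k in range(length), so a wrapped trajectory is
    a single entry. Trajectories shorter than slice_len are skipped entirely.
    """
    eligible = [(s, n) for s, n in trajs if n >= slice_len]
    if not eligible:
        raise RuntimeError("no trajectory is at least slice_len long")
    slices = []
    for _ in range(num_slices):
        start, length = rng.choice(eligible)
        # Offset within the trajectory so that the whole window fits inside it.
        offset = rng.randrange(length - slice_len + 1)
        slices.append([(start + offset + k) % capacity for k in range(slice_len)])
    return slices


# The 10-slot example from this thread: traj 3 wraps (start 9, length 3),
# traj 0 has a single step left at index 2.
trajs = [(9, 3), (2, 1), (3, 3), (6, 3)]
batch = sample_slices(trajs, capacity=10, num_slices=4, slice_len=3,
                      rng=random.Random(0))
print(batch)
```

Because every slice has exactly slice_len steps, the batch always contains num_slices * slice_len indices, and the one-step remnant of trajectory 0 is never sampled.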

cc colleagues @dasGringuen @aalmuzairee who encountered this error

nicklashansen commented 4 months ago

@vmoens sorry to reopen this but we're still encountering errors when the replay buffer hits capacity.

This is the error that I'm getting using torchrl-nightly==2024.3.18:

  File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/site-packages/torchrl/data/replay_buffers/replay_buffers.py", line 603, in sample
    ret = self._prefetch_queue.popleft().result()
  File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/concurrent/futures/_base.py", line 433, in result
    return self.__get_result()
  File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
  File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/concurrent/futures/thread.py", line 52, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/site-packages/torchrl/data/replay_buffers/utils.py", line 50, in decorated_fun
    output = fun(self, *args, **kwargs)
  File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/site-packages/torchrl/data/replay_buffers/replay_buffers.py", line 537, in _sample
    index, info = self._sampler.sample(self._storage, batch_size)
  File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/site-packages/torchrl/data/replay_buffers/samplers.py", line 975, in sample
    return self._sample_slices(
  File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/site-packages/torchrl/data/replay_buffers/samplers.py", line 1003, in _sample_slices
    raise RuntimeError(
RuntimeError: Did not find a single trajectory with sufficient length (length range: 2 - 251 / required=4)).

which is also encountered by @wertyuilife in issue https://github.com/nicklashansen/tdmpc2/issues/20. My specific implementation uses variable episode lengths of 4-251 and encounters this error only at capacity, regardless of what I set the capacity to. I believe @wertyuilife encounters this error using the official TD-MPC2 repo, which uses a fixed episode length, so it appears to be a more persistent issue.

vmoens commented 4 months ago

No need to be sorry, I will investigate!

vmoens commented 4 months ago

Have you set strict_length=False by the way?

vmoens commented 4 months ago

I can reproduce this, but it's expected when strict_length=True:

import torch
import tqdm
from tensordict import TensorDict

from torchrl.data import ReplayBuffer, SliceSampler, LazyTensorStorage

rb = ReplayBuffer(storage=LazyTensorStorage(1000),
                  sampler=SliceSampler(slice_len=4, traj_key="traj", strict_length=False), batch_size=256)  # set strict_length=True to get the error

for i in tqdm.tqdm(range(10_000)):
    n = torch.randint(2, 50, ()).item()
    td = TensorDict({"a": torch.randn(n, 3), "traj": torch.full((n, ), i, dtype=torch.float32)}, [n])
    rb.extend(td)
    if i > 10:
        rb.sample()

nicklashansen commented 4 months ago

> Have you set strict_length=False by the way?

This is the current usage of the SliceSampler: https://github.com/nicklashansen/tdmpc2/blob/57158282b46ebc5c329c5be9cfe2b0094126d1ca/tdmpc2/common/buffer.py#L17

self._sampler = SliceSampler(
  num_slices=self.cfg.batch_size,
  end_key=None,
  traj_key='episode',
  truncated_key=None,
)
self._batch_size = cfg.batch_size * (cfg.horizon+1)
[...]
return ReplayBuffer(
  storage=storage,
  sampler=self._sampler,
  pin_memory=True,
  prefetch=1,
  batch_size=self._batch_size,
)
[...]
td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)

which returns the error

RuntimeError: Did not find a single trajectory with sufficient length (length range: 2 - 251 / required=4))

when strict_length=True (default). Setting strict_length=False instead results in the error

    td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)
  File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/site-packages/tensordict/utils.py", line 1177, in new_func
    out = func(_self, *args, **kwargs)
  File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/site-packages/tensordict/base.py", line 1222, in view
    result = self._view(size=size) if size is not None else self._view(*shape)
  File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/site-packages/tensordict/_td.py", line 1062, in _view
    shape = infer_size_impl(shape, self.numel())
  File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/site-packages/torch/jit/_shape_functions.py", line 144, in infer_size_impl
    raise AssertionError("invalid shape")
AssertionError: invalid shape

presumably because the implementation returns a different number of elements than expected?
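The "invalid shape" error is consistent with plain arithmetic: `view(-1, horizon + 1)` can only infer the leading dimension when the number of sampled elements is a multiple of `horizon + 1`. The sketch below mimics that check in pure Python. `infer_rows` is a hypothetical stand-in for the shape inference, and the sizes (256 slices, `horizon + 1 = 4`, one slice truncated to 2 steps) are illustrative assumptions.

```python
def infer_rows(numel, row_len):
    """Mimic the -1 inference of view(-1, row_len): numel must divide evenly."""
    if numel % row_len != 0:
        raise AssertionError("invalid shape")
    return numel // row_len


num_slices, slice_len = 256, 4  # slice_len plays the role of horizon + 1
expected = num_slices * slice_len
print(infer_rows(expected, slice_len))  # 256

# With strict_length=False, suppose one sampled slice only has 2 steps.
lengths = [slice_len] * (num_slices - 1) + [2]
try:
    infer_rows(sum(lengths), slice_len)
except AssertionError as err:
    print("view fails:", err, "numel =", sum(lengths))  # numel = 1022
```

Even if short slices happened to add up to a multiple of `horizon + 1`, the rows produced by `view` would straddle slice boundaries and mix steps from different trajectories, so the shape check passing would not make the batch correct.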

@vmoens What is the recommended way of using this sampler? Is there any way to just not sample sequences that are invalid? As it is, it seems like my implementation will break one way or the other as soon as the buffer hits capacity. I'm happy to change my own implementation to accommodate the SliceSampler but don't quite see an easy solution at the moment.

Also CC @wertyuilife @dasGringuen @rokas-bendikas @jyothirsv who are affected by this

vmoens commented 4 months ago

I think I see more clearly now. The second error you're getting when strict_length=False is due to the fact that you're sampling at least one traj with fewer elements than expected. I will relax strict_length so that it does not throw an exception when there are too few elements but just discards those episodes. It'll be in the minor release early next week!

nicklashansen commented 4 months ago

Sounds great. Thank you!

nicklashansen commented 4 months ago

@vmoens I have checked that it no longer throws an exception for me and pushed an update to tdmpc2 here: https://github.com/nicklashansen/tdmpc2/commit/5f6fadec0fec78304b4b53e8171d348b58cac486

Thanks again for your help!