nicklashansen closed this issue 5 months ago.
Interesting! I hadn't thought about that, but it makes sense: the last trajectory in the storage spans the end and the beginning of it.
I guess the natural fit would be to tell SliceSampler that this is the same traj!
I can patch that tomorrow
I gave this more thought and it's even trickier than I initially expected. Imagine we collect trajectories of length 3 in a buffer of length 10; we have
t0: [0, 0, 0, -1, -1, -1, -1, -1, -1, -1]
and after a while
t3: [3, 3, 0, 1, 1, 1, 2, 2, 2, 3]
If strict_length=True, this will break because (1) trajectory 3 spans the end and the beginning of the storage and (2) trajectory 0 has been partially overwritten.
(1) can be relatively easily solved, but (2) will persist.
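To make the two failure modes concrete, here is a toy model in plain Python (not TorchRL's actual storage code): a ring buffer that records which trajectory id occupies each slot, plus a naive left-to-right scan for contiguous runs. The helper names fill_ring and naive_runs are hypothetical.

```python
def fill_ring(capacity, traj_lengths):
    """Write trajectories back to back into a ring buffer of size `capacity`.
    Each slot stores the id of the trajectory written there; -1 marks an empty slot."""
    storage = [-1] * capacity
    cursor = 0
    for traj_id, length in enumerate(traj_lengths):
        for _ in range(length):
            storage[cursor] = traj_id
            cursor = (cursor + 1) % capacity  # wrap to index 0 past the end
    return storage


def naive_runs(storage):
    """Split the storage into (traj_id, start, length) runs of equal ids,
    scanning left to right and ignoring wrap-around."""
    runs = []
    start = 0
    for i in range(1, len(storage) + 1):
        if i == len(storage) or storage[i] != storage[start]:
            if storage[start] != -1:  # skip empty slots
                runs.append((storage[start], start, i - start))
            start = i
    return runs


t0 = fill_ring(10, [3])
t3 = fill_ring(10, [3, 3, 3, 3])
print(t0)  # [0, 0, 0, -1, -1, -1, -1, -1, -1, -1]
print(t3)  # [3, 3, 0, 1, 1, 1, 2, 2, 2, 3]
print(naive_runs(t3))
# [(3, 0, 2), (0, 2, 1), (1, 3, 3), (2, 6, 3), (3, 9, 1)]
```

The naive scan sees trajectory 3 as two runs of length 2 and 1 (problem (1)), and trajectory 0 has a single step left (problem (2)); with a slice length of 3 and strict_length=True, only trajectories 1 and 2 would qualify.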
So we should consider one of these options or a combination of them:
- Set strict_length to False by default. That will solve both, but if we don't fix (1) you will always sample incomplete pieces of traj 3 in my example.
- If strict_length is True, only trajectories of sufficient size are considered; the rest are simply discarded. Not sure it's a very wise thing to propose or that anyone will use that...
- Regarding strict_length itself (in both options above it becomes less relevant anyway): simply give the user the option of what to do with incomplete trajectories (pad and stack, not pad and cat along time with truncated...).
@nicklashansen what's the expected behaviour in your case in the example I gave? What would be the "natural" thing to do?
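A rough sketch of how the first two options would treat the t3 layout, assuming a slice length of 3 and that trajectories are found by a naive scan that splits at the end of the storage (so (1) is not fixed yet). It only models which pieces become sampleable; the function eligible_pieces is hypothetical and is not SliceSampler's code.

```python
def eligible_pieces(runs, slice_len, strict_length):
    """Return the (traj_id, start, length) runs a sampler could draw from.

    strict_length=True:  runs shorter than slice_len are discarded.
    strict_length=False: short runs are kept and can only yield truncated slices.
    """
    pieces = []
    for traj_id, start, length in runs:
        if length >= slice_len:
            pieces.append((traj_id, start, length))
        elif not strict_length:
            pieces.append((traj_id, start, length))  # only a truncated slice is possible
    return pieces


# Runs of the t3 layout [3, 3, 0, 1, 1, 1, 2, 2, 2, 3] when split naively at the end of the storage.
runs_t3 = [(3, 0, 2), (0, 2, 1), (1, 3, 3), (2, 6, 3), (3, 9, 1)]

strict = eligible_pieces(runs_t3, slice_len=3, strict_length=True)
relaxed = eligible_pieces(runs_t3, slice_len=3, strict_length=False)
print(strict)   # [(1, 3, 3), (2, 6, 3)]
print(relaxed)  # [(3, 0, 2), (0, 2, 1), (1, 3, 3), (2, 6, 3), (3, 9, 1)]
```

With strict_length=False, trajectory 3 can only ever come back as pieces of length 2 or 1, which is the "incomplete pieces of traj 3" mentioned in the first option.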
IMO (1) needs to be fixed for sure and strict_length set to False, though I would then expect that strict_length will never be used...
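Fixing (1) amounts to recognising that the first and last runs belong to the same trajectory once the buffer has wrapped. A minimal sketch, assuming the storage is full and that trajectory ids are unique (as in the example above); merge_wrapped_run is a hypothetical helper, not part of TorchRL.

```python
def merge_wrapped_run(runs, capacity):
    """If the first run starts at index 0, the last run ends at capacity - 1, and both
    carry the same trajectory id, treat them as one trajectory that starts at the
    last run's start index and wraps around the end of the storage."""
    if len(runs) < 2:
        return list(runs)
    first_id, first_start, first_len = runs[0]
    last_id, last_start, last_len = runs[-1]
    if first_id == last_id and first_start == 0 and last_start + last_len == capacity:
        merged = (last_id, last_start, last_len + first_len)
        return [merged] + list(runs[1:-1])
    return list(runs)


# Runs of the t3 layout [3, 3, 0, 1, 1, 1, 2, 2, 2, 3] when split naively at the end of the storage.
runs_t3 = [(3, 0, 2), (0, 2, 1), (1, 3, 3), (2, 6, 3), (3, 9, 1)]
merged = merge_wrapped_run(runs_t3, capacity=10)
print(merged)  # [(3, 9, 3), (0, 2, 1), (1, 3, 3), (2, 6, 3)]

long_enough = [r for r in merged if r[2] >= 3]
print(long_enough)  # [(3, 9, 3), (1, 3, 3), (2, 6, 3)]
```

Trajectory 3 becomes usable again, but trajectory 0 still has a single step left: that is problem (2), which merging cannot fix.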
ccing @Cadene since we had a couple of conversations on the topic.
@vmoens: this is the current usage of SliceSampler in TD-MPC2 https://github.com/nicklashansen/tdmpc2/blob/57158282b46ebc5c329c5be9cfe2b0094126d1ca/tdmpc2/common/buffer.py#L17 which sometimes results in the above error when the replay buffer becomes full. I think the most natural behavior would be to simply not sample trajectories shorter than the specified length if strict_length=True, while also accounting for the possibility that the last trajectory wraps around from the end to the start of the storage. In practice the capacity of the replay buffer will almost always be much larger than the length of an individual episode, so I doubt that excluding a single trajectory from sampling will have any noticeable effect on training. For example, TD-MPC2 uses a default buffer capacity of 1,000,000 while episode lengths are in the 10-10,000 range for most of the existing RL environments that I have encountered.
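The behaviour proposed here (skip trajectories shorter than the slice, but let a trajectory wrap from the end of the storage to the start) can be sketched as follows. Trajectories are given as hypothetical (traj_id, start, length) tuples where start + length may exceed the capacity, and storage indices are reduced modulo the capacity. This illustrates the idea only; it is not TorchRL's implementation, and sample_slice_indices is a made-up name.

```python
import random


def sample_slice_indices(trajs, capacity, slice_len, num_slices, rng=random):
    """Draw num_slices slices of exactly slice_len steps.

    trajs: list of (traj_id, start, length); a trajectory may wrap past the end of
    the storage, so storage indices are taken modulo capacity.
    Trajectories shorter than slice_len are skipped instead of raising an error.
    """
    valid = [t for t in trajs if t[2] >= slice_len]
    if not valid:
        raise RuntimeError("no trajectory is long enough for the requested slice_len")
    slices = []
    for _ in range(num_slices):
        traj_id, start, length = rng.choice(valid)
        offset = rng.randint(0, length - slice_len)  # randint's upper bound is inclusive
        slices.append([(start + offset + k) % capacity for k in range(slice_len)])
    return slices


# t3 layout with the wrapped trajectory merged: traj 3 occupies indices 9, 0, 1.
trajs = [(3, 9, 3), (0, 2, 1), (1, 3, 3), (2, 6, 3)]
for idx in sample_slice_indices(trajs, capacity=10, slice_len=3, num_slices=4, rng=random.Random(0)):
    print(idx)
```

Because indices are reduced modulo the capacity, the wrapped trajectory 3 yields the slice [9, 0, 1], while the partially overwritten trajectory 0 is never drawn.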
cc colleagues @dasGringuen @aalmuzairee who encountered this error
@vmoens sorry to reopen this but we're still encountering errors when the replay buffer hits capacity.
This is the error that I'm getting using torchrl-nightly==2024.3.18:
File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/site-packages/torchrl/data/replay_buffers/replay_buffers.py", line 603, in sample
ret = self._prefetch_queue.popleft().result()
File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/concurrent/futures/_base.py", line 433, in result
return self.__get_result()
File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/concurrent/futures/_base.py", line 389, in __get_result
raise self._exception
File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/concurrent/futures/thread.py", line 52, in run
result = self.fn(*self.args, **self.kwargs)
File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/site-packages/torchrl/data/replay_buffers/utils.py", line 50, in decorated_fun
output = fun(self, *args, **kwargs)
File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/site-packages/torchrl/data/replay_buffers/replay_buffers.py", line 537, in _sample
index, info = self._sampler.sample(self._storage, batch_size)
File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/site-packages/torchrl/data/replay_buffers/samplers.py", line 975, in sample
return self._sample_slices(
File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/site-packages/torchrl/data/replay_buffers/samplers.py", line 1003, in _sample_slices
raise RuntimeError(
RuntimeError: Did not find a single trajectory with sufficient length (length range: 2 - 251 / required=4)).
which is also encountered by @wertyuilife in issue https://github.com/nicklashansen/tdmpc2/issues/20. My specific implementation here uses variable episode lengths of 4-251 and encounters this error only at capacity, regardless of what I set the capacity to. I believe @wertyuilife encounters this error using the official TD-MPC2 repo, which uses a fixed episode length, so it appears to be a more persistent issue.
No need to be sorry, I will investigate!
Have you set strict_length=False by the way?
I can reproduce this, but it's expected when strict_length=True:
import torch
import tqdm
from tensordict import TensorDict
from torchrl.data import ReplayBuffer, SliceSampler, LazyTensorStorage

rb = ReplayBuffer(
    storage=LazyTensorStorage(1000),
    sampler=SliceSampler(slice_len=4, traj_key="traj", strict_length=False),  # set strict_length=True to get the error
    batch_size=256,
)
for i in tqdm.tqdm(range(10_000)):
    # add an episode of random length (2 to 49 steps) tagged with its trajectory id
    n = torch.randint(2, 50, ()).item()
    td = TensorDict({"a": torch.randn(n, 3), "traj": torch.full((n,), i, dtype=torch.float32)}, [n])
    rb.extend(td)
    if i > 10:
        rb.sample()
Re: "Have you set strict_length=False by the way?"
This is the current usage of the SliceSampler: https://github.com/nicklashansen/tdmpc2/blob/57158282b46ebc5c329c5be9cfe2b0094126d1ca/tdmpc2/common/buffer.py#L17
self._sampler = SliceSampler(
    num_slices=self.cfg.batch_size,
    end_key=None,
    traj_key='episode',
    truncated_key=None,
)
self._batch_size = cfg.batch_size * (cfg.horizon+1)
[...]
return ReplayBuffer(
    storage=storage,
    sampler=self._sampler,
    pin_memory=True,
    prefetch=1,
    batch_size=self._batch_size,
)
[...]
td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)
which returns the error
RuntimeError: Did not find a single trajectory with sufficient length (length range: 2 - 251 / required=4))
when strict_length=True (default). Setting strict_length=False instead results in the error
td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)
File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/site-packages/tensordict/utils.py", line 1177, in new_func
out = func(_self, *args, **kwargs)
File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/site-packages/tensordict/base.py", line 1222, in view
result = self._view(size=size) if size is not None else self._view(*shape)
File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/site-packages/tensordict/_td.py", line 1062, in _view
shape = infer_size_impl(shape, self.numel())
File "/data/nihansen/miniconda3/envs/tdmpc2pp/lib/python3.9/site-packages/torch/jit/_shape_functions.py", line 144, in infer_size_impl
raise AssertionError("invalid shape")
AssertionError: invalid shape
presumably because the implementation returns a different number of elements than expected?
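A small pure-Python model of why the reshape can fail: the snippet above requests batch_size = num_slices * (horizon + 1) elements and then calls .view(-1, horizon + 1), which only lines up when every slice has exactly horizon + 1 steps. view_rows below is a hypothetical stand-in for that reshape on a flat list, and the slice contents are made up for illustration.

```python
def view_rows(flat, row_len):
    """Mimic tensor.view(-1, row_len) on a flat list: only valid if the length divides evenly."""
    if len(flat) % row_len != 0:
        raise ValueError(f"invalid shape: {len(flat)} elements cannot be viewed as (-1, {row_len})")
    return [flat[i:i + row_len] for i in range(0, len(flat), row_len)]


def flatten(slices):
    """Concatenate slices along time, as in the flat batch returned by the sampler."""
    return [step for s in slices for step in s]


horizon = 3
row_len = horizon + 1  # the batch is viewed as (-1, horizon + 1)

# Each inner list holds the trajectory id of every step in one sampled slice.
flat_ok = flatten([[0] * 4, [1] * 4, [2] * 4])
flat_short = flatten([[0] * 4, [1] * 4, [2] * 2])  # one slice is 2 steps short
flat_mixed = flatten([[0] * 4, [1] * 2, [2] * 2])  # two short slices

print(view_rows(flat_ok, row_len))
# [[0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]
try:
    view_rows(flat_short, row_len)
except ValueError as err:
    print(err)  # invalid shape: 10 elements cannot be viewed as (-1, 4)
print(view_rows(flat_mixed, row_len))
# [[0, 0, 0, 0], [1, 1, 2, 2]]
```

The last case is arguably worse than an exception: the reshape succeeds, but the second row mixes the tails of two different trajectories.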
@vmoens What is the recommended way of using this sampler? Is there any way to just not sample sequences that are invalid? As it is, it seems like my implementation will break one way or the other as soon as the buffer hits capacity. I'm happy to change my own implementation to accommodate the SliceSampler but don't quite see an easy solution at the moment.
Also CC @wertyuilife @dasGringuen @rokas-bendikas @jyothirsv who are affected by this
I think I see it more clearly now. The second error you're getting with strict_length=False happens because you're sampling at least one trajectory with fewer elements than expected. I will relax strict_length so that, instead of throwing an exception when there are too few elements, it just discards those episodes. It'll be in the minor release early next week!
Sounds great. Thank you!
@vmoens I have checked that it no longer throws an exception for me and pushed an update to tdmpc2 here: https://github.com/nicklashansen/tdmpc2/commit/5f6fadec0fec78304b4b53e8171d348b58cac486
Thanks again for your help!
Hi @vmoens, I ran into this issue with the SliceSampler, which does not appear to be intentional behavior!
Describe the bug
The current implementation of SliceSampler will raise an exception when an episode longer than slice_len is added to the replay buffer while it is close to capacity. It appears that episodes are "wrapped around" to the beginning of the replay buffer, but that the sampler does not account for this and thus raises an exception.
This issue affects all use cases in which the replay buffer capacity is not a multiple of the episode length, or in which episodes vary in length.
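A quick plain-Python check of the fixed-length case: when episodes of length L are written back to back into a buffer of capacity C, the first episode that reaches the end either fits exactly or straddles it, so some episode wraps exactly when C is not a multiple of L. episode_ever_wraps is a hypothetical helper that simulates the write cursor; it does not model variable-length episodes.

```python
def episode_ever_wraps(capacity, ep_len):
    """Write fixed-length episodes back to back into a ring buffer and report whether
    any episode straddles the end of the storage (assumes ep_len >= 1)."""
    cursor = 0
    while True:
        if cursor + ep_len > capacity:
            return True  # this episode starts near the end and continues at index 0
        cursor = (cursor + ep_len) % capacity
        if cursor == 0:
            return False  # episodes tiled the buffer exactly; the pattern now repeats


print(episode_ever_wraps(10, 3))           # True: 10 is not a multiple of 3
print(episode_ever_wraps(9, 3))            # False: 9 is a multiple of 3
print(episode_ever_wraps(1_000_000, 999))  # True
```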
To Reproduce
The error can be reproduced by running the following example code:
Expected behavior
I would expect the sampler to account for the possibility that, once the replay buffer is at capacity, an episode is split between the last indices and the first indices of the storage.