pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] Error reloading a non-initialised buffer #2264

Open matteobettini opened 1 week ago

matteobettini commented 1 week ago
from torchrl.data import RandomSampler, TensorDictReplayBuffer, LazyTensorStorage

# The storage is lazy and has not been written to yet.
buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(
        100,
        device="cpu",
    ),
    sampler=RandomSampler(),
    batch_size=100,
)

# Serialising and immediately reloading the never-initialised buffer fails.
buffer.load_state_dict(buffer.state_dict())
Traceback (most recent call last):
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/prova.py", line 17, in <module>
    buffer.load_state_dict(buffer.state_dict())
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/data/replay_buffers/replay_buffers.py", line 420, in load_state_dict
    self._storage.load_state_dict(state_dict["_storage"])
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/data/replay_buffers/storages.py", line 594, in load_state_dict
    self._storage = TensorDict({}, []).load_state_dict(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 2604, in load_state_dict
    batch_size = state_dict.pop("__batch_size")
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyError: '__batch_size'
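For reference, the missing key can be traced to the storage itself: serialising a storage that has never been written to does not produce the nested TensorDict state dict (with its __batch_size entry) that load_state_dict expects. A minimal sketch of how to see this (the exact contents of the printed dict depend on the torchrl version):

from torchrl.data import LazyTensorStorage

storage = LazyTensorStorage(100, device="cpu")
# The storage has never been written to, so its serialised "_storage"
# entry is not a TensorDict state dict and lacks the "__batch_size"
# key that TensorDict.load_state_dict pops in the traceback above.
print(storage.state_dict())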
vmoens commented 1 week ago

Is that supposed to be supported? What's the use case?

matteobettini commented 1 week ago

I have a list of buffers I am reloading; the storage of some of them has been used, while for others it hasn't.

Why would this not be supported?

vmoens commented 1 week ago

Because it's not initialized, we don't know what to write. If you have a buffer that is initialized and you load a non-initialized one into it, what would the behaviour be? Unless that is clearly defined, I would go for a proper exception in this case.
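To make the ambiguity concrete, a sketch of the init <- uninit case (illustrative only, reusing the classes from the repro above):

import torch
from tensordict import TensorDict
from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage

# Buffer a has been written to, buffer b has not.
a = TensorDictReplayBuffer(storage=LazyTensorStorage(100), batch_size=10)
a.extend(TensorDict({"obs": torch.zeros(5, 3)}, [5]))  # a's storage is now initialised
b = TensorDictReplayBuffer(storage=LazyTensorStorage(100), batch_size=10)

# Should a's existing data be dropped, kept, or should this raise?
# None of these semantics is clearly defined.
a.load_state_dict(b.state_dict())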

matteobettini commented 1 week ago

Yes, on that I agree, but in this case we are loading the state_dict of a non-initialised buffer into a non-initialised buffer. I think we should be able to do that (it could be a no-op or something similar).

vmoens commented 1 week ago

hmmm I'd rather avoid things that open the door to UBs (undefined behaviours). We can have a flag in the state_dict that says it isn't initialized, and I would raise a warning telling the user that this is a dangerous thing to do:

- state_dict from uninit -> ok, but with a warning
- load_state_dict from uninit on uninit -> ok
- load_state_dict from uninit on init -> error

The problem is that the last case falls upon load_state_dict to handle, and because it's UB it may or may not lead to a proper error (if a user subclasses a storage, we can't know in advance whether they will or will not get an error in that case).
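For concreteness, a sketch of what such a flag-based guard could look like inside the storage's load_state_dict (the _is_init key and initialized attribute are hypothetical names for this sketch, not the current torchrl API):

import warnings

def load_state_dict(self, state_dict):
    # Hypothetical flag written by state_dict() on the serialising side.
    if not state_dict.get("_is_init", True):
        if self.initialized:
            # uninit -> init: behaviour is undefined, refuse loudly.
            raise RuntimeError(
                "Cannot load the state dict of an uninitialised storage "
                "into an initialised one."
            )
        # uninit -> uninit: nothing to restore, warn and no-op.
        warnings.warn(
            "Loading the state dict of an uninitialised storage is a no-op."
        )
        return
    ...  # existing loading logic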

Really not sure what to do here. IMO this should be:

if storage.is_init:  # is_init: the storage has been written to at least once
    sd = storage.state_dict()
else:
    sd = None
# later
if sd is not None:
    storage.load_state_dict(sd)

At the RB level, the storage.is_init check can be replaced with if len(rb) > 0:.
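Applied to the original use case (a list of buffers, some used and some not), the caller-side version of this could look like the following sketch, relying on len(buffer) counting the items written so far:

# Serialise only the buffers whose storage has been written to.
state_dicts = [
    buffer.state_dict() if len(buffer) > 0 else None for buffer in buffers
]

# Later, skip the entries that were never initialised.
for buffer, sd in zip(buffers, state_dicts):
    if sd is not None:
        buffer.load_state_dict(sd)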