pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] SyncDataCollector Crashes when init_random_frames=0 #2534

Open AlexandreBrown opened 4 weeks ago

AlexandreBrown commented 4 weeks ago

Describe the bug

Iterating over a SyncDataCollector that uses a standard Actor (not a random policy) with init_random_frames=0 crashes on the first rollout.

policy = Actor(
    agent,
    in_keys=["your_key"],
    out_keys=["action"],
    spec=train_env.action_spec,
)
train_data_collector = SyncDataCollector(
    create_env_fn=train_env,
    policy=policy,
    init_random_frames=0,
    ...
)

Iterating over the collector like this triggers the crash:

for data in tqdm(train_data_collector, "Env Data Collection"):
    pass  # the crash happens while the first batch is being collected

To Reproduce

  1. Create an Actor policy (anything other than RandomPolicy).
  2. Create a SyncDataCollector with that actor and set init_random_frames=0.
  3. Iterate over the data collector.
  4. Observe the crash.

Stack trace:

2024-11-04 12:04:33,606 [torchrl][INFO] check_env_specs succeeded!
2024-11-04 12:04:36.365 | INFO     | __main__:main:60 - Policy Device: cuda
2024-11-04 12:04:36.365 | INFO     | __main__:main:61 - Env Device: cpu
2024-11-04 12:04:36.365 | INFO     | __main__:main:62 - Storage Device: cpu
Env Data Collection:   0%|                                                                                                                                      | 0/1000000 [00:00<?, ?it/s]
Error executing job with overrides: ['env=dmc_reacher_hard', 'algo=sac_pixels']
Traceback (most recent call last):
  File "/home/user/Documents/SegDAC/./scripts/train_rl.py", line 119, in main
    trainer.train()
  File "/home/user/Documents/SegDAC/segdac_dev/src/segdac_dev/trainers/rl_trainer.py", line 40, in train
    for data in tqdm(self.train_data_collector, "Env Data Collection"):
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/tqdm/std.py", line 1181, in __iter__
    for obj in iterable:
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 247, in __iter__
    yield from self.iterator()
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1035, in iterator
    tensordict_out = self.rollout()
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/_utils.py", line 481, in unpack_rref_and_invoke_function
    return func(self, *args, **kwargs)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1166, in rollout
    env_output, env_next_output = self.env.step_and_maybe_reset(env_input)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/common.py", line 2862, in step_and_maybe_reset
    tensordict = self.step(tensordict)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/common.py", line 1505, in step
    next_tensordict = self._step(tensordict)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 783, in _step
    tensordict_in = self.transform.inv(tensordict)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/tensordict/nn/common.py", line 314, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 357, in inv
    out = self._inv_call(clone(tensordict))
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 1084, in _inv_call
    tensordict = t._inv_call(tensordict)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 3656, in _inv_call
    return super()._inv_call(tensordict)
  File "/home/user/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 342, in _inv_call
    raise KeyError(f"'{in_key}' not found in tensordict {tensordict}")
KeyError: "'action' not found in tensordict TensorDict(\n    fields={\n        collector: TensorDict(\n            fields={\n                traj_ids: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},\n            batch_size=torch.Size([]),\n            device=cpu,\n            is_shared=False),\n        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n        is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n        pixels: Tensor(shape=torch.Size([3, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),\n        pixels_transformed: Tensor(shape=torch.Size([3, 84, 84]), device=cpu, dtype=torch.float32, is_shared=False),\n        step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),\n        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},\n    batch_size=torch.Size([]),\n    device=cpu,\n    is_shared=False)"

Expected behavior

Iterating over the collector should work when init_random_frames=0.

Reason and Possible fixes

It seems that self._policy_output_keys, which comes from SyncDataCollector::_make_final_rollout, is set to {} when init_random_frames=0, and this causes unwanted behavior in SyncDataCollector::rollout.
More precisely, the problem is in these lines from SyncDataCollector::rollout:

policy_output = self.policy(policy_input)
if self._shuttle is not policy_output:
    # ad-hoc update shuttle
    self._shuttle.update(
        policy_output, keys_to_update=self._policy_output_keys
    )

In my case, policy_output was a tensordict containing the action key, but since self._policy_output_keys is {}, self._shuttle is never updated with that key. The environment step then crashes with KeyError: "'action' not found in tensordict ...".
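To illustrate the suspected mechanism without needing torchrl, here is a minimal sketch that uses plain dicts. update_selected is a hypothetical stand-in for the filtered update in the snippet above, not the library's implementation. It assumes that a filtered update copies only the listed keys, so an empty filter copies nothing.

```python
def update_selected(target, source, keys_to_update=None):
    """Copy entries from source into target, restricted to keys_to_update.

    Hypothetical stand-in for a filtered update: None means "copy everything",
    while an empty collection means "copy nothing".
    """
    if keys_to_update is None:
        target.update(source)
        return target
    for key in keys_to_update:
        if key in source:
            target[key] = source[key]
    return target


# Plain dicts standing in for the collector's shuttle and the policy output.
shuttle = {"pixels": "obs", "done": False}
policy_output = {"pixels": "obs", "done": False, "action": [0.1, -0.2]}

# Empty filter, as suspected in this report: the action is never copied over.
update_selected(shuttle, policy_output, keys_to_update={})
missing_with_empty_filter = "action" not in shuttle

# A filter that lists the policy's out_keys: the action reaches the shuttle.
update_selected(shuttle, policy_output, keys_to_update={"action"})
present_with_action_filter = "action" in shuttle

print(missing_with_empty_filter, present_with_action_filter)
```

If the collector behaves like this sketch, any transform that later looks up "action" in the shuttle would fail in the same way as the stack trace above.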

vmoens commented 4 weeks ago

On it!