pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] Issues with `TensorDictPrimer` #2327

Closed matteobettini closed 1 month ago

matteobettini commented 1 month ago

Without the primer, the collector does not feed any hidden state to the policy

In the RNN tutorial it is stated that the primer is optional and that it is only used to store the hidden states in the buffer.

This is not true in practice. Not adding the primer results in the collector not feeding the hidden states to the policy during execution, which silently causes the RNN to lose all recurrence.

To reproduce, comment out this line

https://github.com/pytorch/rl/blob/0063741839a3e5e1a527947945494d54f91bc629/tutorials/sphinx-tutorials/dqn_with_rnn.py#L269

and print the policy input at this line

https://github.com/pytorch/rl/blob/0063741839a3e5e1a527947945494d54f91bc629/torchrl/collectors/collectors.py#L733

You will see that no hidden state is fed to the RNN during execution, and no errors or warnings are thrown.
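For concreteness, here is a condensed diagnostic sketch (the `env`, `lstm` and `policy` objects are assumed to be the ones built in the tutorial; names and sizes are illustrative, not the actual tutorial code):

    # Sketch only: `env`, `lstm` and `policy` are the tutorial objects (assumed).
    from torchrl.collectors import SyncDataCollector

    # env.append_transform(lstm.make_tensordict_primer())  # <-- the line commented out

    collector = SyncDataCollector(env, policy, frames_per_batch=50, total_frames=200)
    for batch in collector:
        # without the primer, no recurrent_state_h / recurrent_state_c entries
        # show up here, and nothing warns about it
        print(list(batch.keys(include_nested=True, leaves_only=True)))
        break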

The primer overwrites any nested spec

Consider an env with nested specs

    env = VmasEnv(
        scenario="balance",
        num_envs=5,
    )

Add to it a primer for a nested hidden state:

    env = TransformedEnv(
        env,
        TensorDictPrimer(
            {
                "agents": CompositeSpec(
                    {
                        "h": UnboundedContinuousTensorSpec(
                            shape=(*env.shape, env.n_agents, 2, 128)
                        )
                    },
                    shape=(*env.shape, env.n_agents),
                )
            }
        ),
    )

The primer code at https://github.com/pytorch/rl/blob/0063741839a3e5e1a527947945494d54f91bc629/torchrl/envs/transforms/transforms.py#L4649 will overwrite the observation spec instead of updating it, resulting in the loss of all the spec keys that were previously in the "agents" spec.

The same result is obtained with

    env = TransformedEnv(
        env,
        TensorDictPrimer(
            {
                ("agents", "h"): UnboundedContinuousTensorSpec(
                    shape=(*env.shape, env.n_agents, 2, 128)
                )
            }
        ),
    )

Here, updating the spec instead of overwriting it should do the job.
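For illustration, a minimal sketch of that update-instead-of-overwrite behaviour, written against plain nested dicts rather than actual `CompositeSpec` objects so it stays self-contained (the real fix would operate on the spec classes):

    # Toy sketch: dicts stand in for CompositeSpec. The primer's entries are
    # merged into the existing spec, so sibling keys under "agents" survive.
    def merge_spec(existing: dict, primer: dict) -> dict:
        for key, value in primer.items():
            if isinstance(value, dict) and isinstance(existing.get(key), dict):
                merge_spec(existing[key], value)  # recurse into nested composites
            else:
                existing[key] = value             # only leaves are overwritten
        return existing

    obs_spec = {"agents": {"observation": "obs_spec", "info": "info_spec"}}
    primer_spec = {"agents": {"h": "hidden_spec"}}
    print(merge_spec(obs_spec, primer_spec))
    # {'agents': {'observation': 'obs_spec', 'info': 'info_spec', 'h': 'hidden_spec'}}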

The order of the primer in the transforms seems to have an impact

In the same VMAS environment as above, if I put the primer first and then the RewardSum

    env = TransformedEnv(
        env,
        Compose(
            TensorDictPrimer(
                {
                    "agents": CompositeSpec(
                        {
                            "h": UnboundedContinuousTensorSpec(
                                shape=(*env.shape, env.n_agents, 2, 128)
                            )
                        },
                        shape=(*env.shape, env.n_agents),
                    )
                }
            ),
            RewardSum(
                in_keys=[env.reward_key],
                out_keys=[("agents", "episode_reward")],
            ),
        ),
    )

all works well.

But the opposite order

    env = TransformedEnv(
        env,
        Compose(
            RewardSum(
                in_keys=[env.reward_key],
                out_keys=[("agents", "episode_reward")],
            ),
            TensorDictPrimer(
                {
                    "agents": CompositeSpec(
                        {
                            "h": UnboundedContinuousTensorSpec(
                                shape=(*env.shape, env.n_agents, 2, 128)
                            )
                        },
                        shape=(*env.shape, env.n_agents),
                    )
                }
            ),
        ),
    )

causes

Traceback (most recent call last):
  File "/Users/Matteo/PycharmProjects/torchrl/sota-implementations/multiagent/mappo_ippo.py", line 302, in train
    collector = SyncDataCollector(
                ^^^^^^^^^^^^^^^^^^
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/collectors/collectors.py", line 644, in __init__
    self._make_shuttle()
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/collectors/collectors.py", line 661, in _make_shuttle
    self._shuttle = self.env.reset()
                    ^^^^^^^^^^^^^^^^
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/envs/common.py", line 2143, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/envs/transforms/transforms.py", line 814, in _reset
    tensordict_reset = self.transform._reset(tensordict, tensordict_reset)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/envs/transforms/transforms.py", line 1129, in _reset
    tensordict_reset = t._reset(tensordict, tensordict_reset)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/envs/transforms/transforms.py", line 4722, in _reset
    value = self.default_value[key]
            ~~~~~~~~~~~~~~~~~~^^^^^
KeyError: ('agents', 'episode_reward')
matteobettini commented 1 month ago

For the first issue, I think we should go for a solution where the Primer becomes optional and needed only if users want the hidden states in the collection buffer.

But without the primer, users should still be able to use RNNs in collectors, with the logic that anything coming out of step_mdp is re-fed to the policy.
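To illustrate the mechanism (a toy sketch with a hand-built tensordict, not collector code; the key names are just examples): `step_mdp` already moves whatever was written under `"next"` to the root, so the collector would only need to keep such entries in its shuttle and feed them back at the next step.

    # Toy sketch: a hidden state written under ("next", "recurrent_state_h")
    # survives step_mdp and could be re-fed to the policy at the next step.
    import torch
    from tensordict import TensorDict
    from torchrl.envs.utils import step_mdp

    td = TensorDict(
        {
            "observation": torch.zeros(3),
            "action": torch.zeros(1),
            "next": TensorDict(
                {
                    "observation": torch.ones(3),
                    "reward": torch.zeros(1),
                    "done": torch.zeros(1, dtype=torch.bool),
                    "recurrent_state_h": torch.randn(1, 128),  # written by the policy
                },
                [],
            ),
        },
        [],
    )
    print(list(step_mdp(td).keys()))  # includes "recurrent_state_h"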

vmoens commented 1 month ago

Since there are multiple issues, I'd suggest opening a tracker. I'll comment on the first one here: it's optional in the sense that you can make the env run without the primer if no other module is involved. If a ParallelEnv or a collector is used, things will indeed break.

matteobettini commented 1 month ago

Yeah I'll eventually spread them into separate issues.

If that is the case regarding the first, I suggest we make it extra clear that in the tutorial the Primer is not optional, as we are using a collector.

In general, do we really have no way to make the collector work without the primer? It would be nice to have it optional in collectors, for users that do not want the hidden states as part of the output buffer.

matteobettini commented 1 month ago

Adding to my previous comment: I think that, to solve the first issue, we could add the output of the policy (looking at the "next" key) to the shuttle:

https://github.com/pytorch/rl/blob/371181cb067a4cd0456b2016d9035140bbb2adae/torchrl/collectors/collectors.py#L733

I would really like not to use the Primer, as it is a huge pain in large projects.

vmoens commented 1 month ago

I don't see a way of not having a primer: we need to let the env know about extra entries in the tensordict. Is that such a "huge" pain though? We've worked hard with @albertbou92 to provide all the tooling to make this work as seamlessly as possible.

albertbou92 commented 1 month ago

In the collector, we could automatically check if any primer is missing and append it. Or raise a warning. We can extract the expected primers from the actor.

I don't find it that inconvenient to use the primer. I simply got used to adding:

# get_primers_from_module is available from torchrl.modules.utils (in recent versions)
from torchrl.modules.utils import get_primers_from_module

if primers := get_primers_from_module(actor):
    env.append_transform(primers)

However, if a user is not aware of the primer or forgets to add it for some reason, silently losing recurrence can cause a lot of headaches.

matteobettini commented 1 month ago

What is the technical limitation preventing us from reading the hidden state from the policy output? It seems to me that, since we are running the policy at collector init time, its outputs in the "next" tensordict could be captured and accounted for during the collector rollout (i.e., moved to the policy input at every step_mdp); a rough sketch is below.

With this you would maybe lose the possibility of obtaining zeroed states after resets, but the absence of the hidden state (i.e. it being None) should be possible to handle in the RNN.

The pain I am referring to is that, in the multi-agent setting, the Primer requires knowing the number of agents, the group names, and the hidden sizes, and all of this information is not immediately available. Plus, having it optional would make the approach easier for new users and less bug-prone, IMO.
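A rough sketch of that init-time capture idea (a hypothetical helper, not actual collector code; `env` and `policy` stand for the collector's env and policy):

    # Hypothetical helper: run the policy once at init time and record the
    # ("next", ...) entries it produces that the env does not declare; those
    # are the entries the collector would carry over at every step_mdp call.
    from tensordict import TensorDict

    def extra_policy_out_keys(env, policy):
        td = policy(env.reset())
        produced = set(td.get("next", TensorDict({}, td.batch_size)).keys(True, True))
        declared = set(env.observation_spec.keys(True, True))
        # e.g. {"recurrent_state_h", "recurrent_state_c"} for an LSTM policy
        return produced - declared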

vmoens commented 1 month ago

Sorry to ask again, but can we split this issue? I'd like to close the pieces that are solved.

matteobettini commented 1 month ago

Yeah, we solved everything apart from the first one; I'll make an issue for that.