pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] `TensorDictPrimer` overwrites nested specs in the environment #2331

Closed matteobettini closed 1 month ago

matteobettini commented 1 month ago

part of #2327

The primer overwrites any nested spec

Consider an env with nested specs

    from torchrl.envs.libs.vmas import VmasEnv

    env = VmasEnv(
        scenario="balance",
        num_envs=5,
    )

Add to it a primer for a nested hidden state:

    env = TransformedEnv(
        env,
        TensorDictPrimer(
            {
                "agents": CompositeSpec(
                    {
                        "h": UnboundedContinuousTensorSpec(
                            shape=(*env.shape, env.n_agents, 2, 128)
                        )
                    },
                    shape=(*env.shape, env.n_agents),
                )
            }
        ),
    )

The primer code in https://github.com/pytorch/rl/blob/0063741839a3e5e1a527947945494d54f91bc629/torchrl/envs/transforms/transforms.py#L4649 will overwrite the observation spec instead of updating it, resulting in the loss of all the spec keys that were previously in the "agents" spec.

The same result is obtained with

    env = TransformedEnv(
        env,
        TensorDictPrimer(
            {
                ("agents", "h"): UnboundedContinuousTensorSpec(
                    shape=(*env.shape, env.n_agents, 2, 128)
                )
            }
        ),
    )

Here, updating the spec instead of overwriting it should do the job.
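To make the difference between the two behaviours concrete, here is a minimal sketch that uses plain nested dicts as a stand-in for the nested specs. It is not the torchrl implementation: `overwrite` and `deep_update` are hypothetical helpers, and the `"observation"` key and the string values are placeholders chosen for illustration only.

```python
from copy import deepcopy


def overwrite(base, new):
    """Top-level assignment: each key in `new` replaces the whole entry in `base`."""
    out = deepcopy(base)
    for key, value in new.items():
        out[key] = deepcopy(value)
    return out


def deep_update(base, new):
    """Recursive merge: nested dicts are merged key by key instead of replaced."""
    out = deepcopy(base)
    for key, value in new.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = deep_update(out[key], value)
        else:
            out[key] = deepcopy(value)
    return out


# Hypothetical stand-in for the env's observation spec (plain dicts, not real specs).
observation_spec = {"agents": {"observation": "obs_spec"}}
# Stand-in for the nested primer spec from the issue.
primer = {"agents": {"h": "hidden_state_spec"}}

overwritten = overwrite(observation_spec, primer)
merged = deep_update(observation_spec, primer)

print(sorted(overwritten["agents"]))  # ['h'] -> the pre-existing key is lost
print(sorted(merged["agents"]))       # ['h', 'observation'] -> both keys survive
```

With the overwrite, the existing entry under `"agents"` disappears, which matches the behaviour reported above. With the recursive merge, the primed `"h"` entry sits next to the keys that were already there.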