The primer overwrites any nested spec
Status: closed by matteobettini 1 month ago.
part of #2327
Consider an env with nested specs:
```python
from torchrl.envs.libs.vmas import VmasEnv

env = VmasEnv(
    scenario="balance",
    num_envs=5,
)
```
Add to it a primer for a nested hidden state:
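```python
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import TensorDictPrimer, TransformedEnv

env = TransformedEnv(
    env,
    TensorDictPrimer(
        {
            "agents": CompositeSpec(
                {
                    "h": UnboundedContinuousTensorSpec(
                        shape=(*env.shape, env.n_agents, 2, 128)
                    )
                },
                shape=(*env.shape, env.n_agents),
            )
        }
    ),
)
```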
The primer code at https://github.com/pytorch/rl/blob/0063741839a3e5e1a527947945494d54f91bc629/torchrl/envs/transforms/transforms.py#L4649 overwrites the observation spec instead of updating it, so every key that was previously in the "agents" spec is lost.
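To make the failure mode concrete, here is a minimal sketch of overwrite semantics. It uses plain nested Python dicts as a stand-in for `CompositeSpec`, and the keys `"observation"` and `"info"` are hypothetical placeholders for whatever the VMAS scenario actually stores under `"agents"`. It illustrates the behaviour described above; it is not the TorchRL primer code itself.

```python
# Plain nested dicts stand in for CompositeSpec (assumption for illustration);
# the leaf strings stand in for real tensor specs.
observation_spec = {
    "agents": {
        "observation": "obs_spec",  # hypothetical pre-existing key
        "info": "info_spec",        # hypothetical pre-existing key
    }
}
primer_spec = {"agents": {"h": "hidden_state_spec"}}

# Overwrite semantics: each top-level primer entry replaces the existing entry wholesale.
overwritten = dict(observation_spec)
for key, value in primer_spec.items():
    overwritten[key] = value

print(sorted(overwritten["agents"]))  # ['h']: "observation" and "info" are gone
```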
The same result is obtained when the primer is given the nested key directly:
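```python
from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs import TensorDictPrimer, TransformedEnv

env = TransformedEnv(
    env,
    TensorDictPrimer(
        {
            ("agents", "h"): UnboundedContinuousTensorSpec(
                shape=(*env.shape, env.n_agents, 2, 128)
            )
        }
    ),
)
```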
Updating the observation spec at that line instead of overwriting it should fix the issue.
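A minimal sketch of the update semantics, again with plain nested dicts standing in for `CompositeSpec`. `recursive_update` is a hypothetical helper written for illustration, not a TorchRL function: it descends into entries that are nested on both sides and only replaces leaves, so the existing keys under `"agents"` survive next to the new `"h"` entry.

```python
from typing import Any, Dict


def recursive_update(target: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of `target` with `update` merged in,
    recursing into nested dicts instead of replacing them."""
    merged = dict(target)
    for key, value in update.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are nested: merge level by level.
            merged[key] = recursive_update(merged[key], value)
        else:
            # Leaf (or new key): the primer's value wins.
            merged[key] = value
    return merged


# Placeholder keys, as in the overwrite example (assumption for illustration).
observation_spec = {"agents": {"observation": "obs_spec", "info": "info_spec"}}
primer_spec = {"agents": {"h": "hidden_state_spec"}}

updated = recursive_update(observation_spec, primer_spec)
print(sorted(updated["agents"]))  # ['h', 'info', 'observation']
```

The design choice is to merge rather than assign wherever both sides are nested; a leaf that exists on both sides is still replaced by the primer's value.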