pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl

[BUG] Collector deletes new tensordict keys #1625

Open Acciorocketships opened 1 year ago

Acciorocketships commented 1 year ago

New TensorDict keys created by policy modules used in the collector are not returned in the rollout tensordict.

The issue is in this line, as setting "out" for torch.stack deletes any new entries that have been created: https://github.com/pytorch/rl/blob/70c650ec8c946f36fd8d57c11612548da2251128/torchrl/collectors/collectors.py#L877

I did find a workaround, which is to use TensorDictPrimer to create a TransformedEnv that declares the new keys you wish to be saved.
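For reference, here is a minimal version of that workaround (the base env, the "logits" key and its shape are placeholders for my setup; the policy is assumed to write a "logits" entry that should be kept in the rollout):

from torchrl.envs import GymEnv, TransformedEnv
from torchrl.envs.transforms import TensorDictPrimer
from torchrl.data import UnboundedContinuousTensorSpec

base_env = GymEnv("CartPole-v1")  # placeholder base environment
# Declare the extra key the policy writes so the collector keeps it
env = TransformedEnv(
    base_env,
    TensorDictPrimer(logits=UnboundedContinuousTensorSpec(shape=(2,))),
)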

vmoens commented 1 year ago

Yes, this is the "official" way of doing it. But we could do it automatically by running the policy once, checking the extra keys and adding them with a primer. UX would only be positively impacted, I think!

matteobettini commented 1 year ago

> Yes, this is the "official" way of doing it. But we could do it automatically by running the policy once, checking the extra keys and adding them with a primer. UX would only be positively impacted, I think!

With the recent change to the collector init we already do this: https://github.com/pytorch/rl/blob/c7d4764e787e4be903f7b5f03b6008f00e9b23a1/torchrl/collectors/collectors.py#L670 (only in the case where a policy spec is not found).

We could add a flag to ask to do this all the time. (Or just do it all the time anyway?)
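Roughly, the automatic version could look like this (just a sketch, not the actual collector code; env, policy and the choice of spec type are placeholders):

from torchrl.envs import TransformedEnv
from torchrl.envs.transforms import TensorDictPrimer
from torchrl.data import UnboundedContinuousTensorSpec

# Run the policy once on reset data to discover the keys it adds
reset_td = env.reset()
policy_out = policy(reset_td.clone())
extra_keys = set(policy_out.keys(True, True)) - set(reset_td.keys(True, True)) - {env.action_key}

# Prime those keys so the collector's stacking step does not drop them
primers = {
    key: UnboundedContinuousTensorSpec(shape=policy_out.get(key).shape, dtype=policy_out.get(key).dtype)
    for key in extra_keys
}
env = TransformedEnv(env, TensorDictPrimer(primers))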

vmoens commented 1 year ago

Do we actually want to return all the intermediate values? Imagine the following use case: I programmatically write a module

from tensordict.nn import TensorDictModule, TensorDictSequential

mods = []
for i in range(10):
    # make_layer is a placeholder callable mapping "data_i" to "data_{i+1}"
    mods.append(TensorDictModule(make_layer, in_keys=[f"data_{i}"], out_keys=[f"data_{i+1}"]))
policy = TensorDictSequential(*mods)

My dataset will be full of intermediate values. If you put an underscore in front of a key it is considered private and discarded, but this may still be a bit dangerous. (This is an extreme example not to be taken at face value; my point is just that intermediate values in a TensorDictModule may be irrelevant for training.) The original idea was that the policy construction was the step responsible for indicating what you want to see in your dataset.

I understand that the UX can be negatively impacted by this choice, and the recent change pointed out by @matteobettini is actually more an inconsistency than anything else (since we don't explicitly say: if you want the extra entries to be returned, do not put a spec in your policy -- which would be weird advice to give).

So my current take is this: add a flag that controls whether all keys produced by the policy run on reset data are kept, with a default of false (i.e. keep the current spec-driven behaviour).

matteobettini commented 1 year ago

I am aligned with this. I think a flag controlling whether to keep all keys from the policy run on reset data is a nice addition.

But what would happen with that flag set to false (the default you are proposing) and no spec? Fall back to true? Error?

Also, another reason why I preferred it true is that I often find myself accessing intermediate values such as logits.

c3-utsavdutta98 commented 4 months ago

@matteobettini, @vmoens what is the difference between using set_info_dict_reader and TensorDictPrimer in this context? Are they the same in terms of passing data across the stepped tensordicts?

Additionally, what does adding this key to the observation spec do? Does that also tell the collector to pass these keys on?

Would really appreciate any insight into these three approaches.
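For context, here is roughly how I understand the first two (the env and the key names are placeholders, and I may well be misusing them, hence the question):

from torchrl.envs import GymEnv, TransformedEnv, default_info_dict_reader
from torchrl.envs.transforms import TensorDictPrimer
from torchrl.data import UnboundedContinuousTensorSpec

# set_info_dict_reader: copy an entry the wrapped env already reports in its info dict
env = GymEnv("CartPole-v1")
env.set_info_dict_reader(default_info_dict_reader(["my_stat"]))  # "my_stat" is hypothetical

# TensorDictPrimer: declare an entry the policy (not the env) will write
env = TransformedEnv(env, TensorDictPrimer(logits=UnboundedContinuousTensorSpec(shape=(2,))))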

c3-utsavdutta98 commented 4 months ago

make_final_rollout has the following lines, which suggest that adding these keys to the observation spec would ensure they are passed on. Is there any harm in adding all of this 'bloat' into the observation spec?

for spec in ["full_observation_spec", "full_done_spec", "full_reward_spec"]:
    _env_output_keys += list(self.env.output_spec[spec].keys(True, True))
self._env_output_keys = _env_output_keys
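And a rough sketch of what I mean by adding the key to the observation spec directly (the key name and shape are made up, and whether this is advisable is exactly my question):

from torchrl.data import UnboundedContinuousTensorSpec

# "my_extra_key" is hypothetical: an entry the policy writes that should be collected
env.observation_spec = env.observation_spec.clone().set(
    "my_extra_key", UnboundedContinuousTensorSpec(shape=(*env.batch_size, 4))
)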