pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[Feature Request] Pass additional information into the tensordict returned by the rollout method (EnvBase) #1808

Closed: felix-basiliskroko closed this issue 9 months ago

felix-basiliskroko commented 9 months ago

Motivation

I am currently working on a PPO algorithm for a pursuit-evasion game using a custom gymnasium environment. My observation space consists mainly of relative metrics. During evaluation, I run the trained policy without exploration mode using the rollout method of the TransformedEnv (EnvBase) class, and the tensordict it returns lacks critical additional information: in my case, the absolute positions of the evader and the pursuer.

Solution

An additional field named "info" (or similar) in the tensordict returned by the rollout method. It should hold the contents of the info value returned by the environment's step function (the last element of the return statement below), so that this extra data is available during inference:

return observation, reward, terminated, False, info
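For illustration, here is a minimal, dependency-free sketch of an environment whose step follows the Gymnasium five-tuple convention (observation, reward, terminated, truncated, info). The class name `ToyPursuitEnv`, its 1-D dynamics, and the `evader_pos`/`pursuer_pos` info keys are hypothetical; the point is only that the policy sees a relative quantity while the absolute positions travel in `info`.

```python
from typing import Any, Dict, Tuple


class ToyPursuitEnv:
    """Hypothetical 1-D pursuit-evasion env mimicking Gymnasium's step API."""

    def __init__(self, pursuer_start: float = 0.0, evader_start: float = 5.0) -> None:
        self._pursuer_start = pursuer_start
        self._evader_start = evader_start
        self.pursuer_pos = pursuer_start
        self.evader_pos = evader_start

    def _get_obs(self) -> float:
        # Relative metric only: the policy never sees absolute positions.
        return self.evader_pos - self.pursuer_pos

    def _get_info(self) -> Dict[str, Any]:
        # Absolute positions are exposed through the info dict instead.
        return {"evader_pos": self.evader_pos, "pursuer_pos": self.pursuer_pos}

    def reset(self) -> Tuple[float, Dict[str, Any]]:
        self.pursuer_pos = self._pursuer_start
        self.evader_pos = self._evader_start
        return self._get_obs(), self._get_info()

    def step(self, action: float) -> Tuple[float, float, bool, bool, Dict[str, Any]]:
        self.pursuer_pos += action  # pursuer moves by the action
        self.evader_pos += 0.5      # evader drifts at a constant speed
        distance = abs(self._get_obs())
        terminated = distance < 1.0
        reward = -distance
        # Same layout as: return observation, reward, terminated, False, info
        return self._get_obs(), reward, terminated, False, self._get_info()
```

Anything placed in `info` this way is invisible to the policy but is exactly the kind of data one would want to inspect after an evaluation rollout.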

Alternatives

-

Additional context

-


vmoens commented 9 months ago

Can you give some more details about where that info comes from? Is it in the info dict? Are you using gym? We do pass the info into the tensordict using info_dict_reader.

If you give me more details about how you are creating the env, I can guide you on how to retrieve this info.

felix-basiliskroko commented 9 months ago

Yes, I use a gymnasium environment, and this torchrl tutorial is basically where I got my main loop from.

def _get_info(self):
    return {"dummy_info": self._agent.get_position()}

The info I want to include in the tensordict during inference/evaluation comes from this method inside my custom gymnasium environment (which is also called in the step function). For reference, the basic structure of my environment is taken from the official gymnasium documentation. I was hoping for the 'info' returned by the method above to appear inside the 'eval_rollout' tensordict as a separate field, similar to 'actions' or 'observations'.

eval_rollout = env.rollout(1000, policy_module)

I hope that clarifies things. Thanks a lot for your help, I really appreciate it. Let me know if there's anything else you need to know.

vmoens commented 9 months ago

I made some improvements to this API in #1809

After I merge this PR, the only thing you will need to do is call env = env.auto_register_info_dict(), and all your info will be registered in the output tensordict.

We can add an option to register it all under an "info" key if necessary, though one "advantage" of tensordict is that you don't really need to put things in a separate container (but I can understand that one would wish to store that extra data apart).
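The trade-off mentioned above (info entries at the top level of each step's output versus grouped under a single "info" key) can be sketched in plain Python. This is not TorchRL's implementation: `make_step_record`, `rollout`, `CounterEnv`, and the `nest_info` flag are hypothetical, and plain dicts stand in for a tensordict.

```python
from typing import Any, Callable, Dict, List


def make_step_record(obs: Any, action: Any, next_obs: Any, reward: float,
                     done: bool, info: Dict[str, Any],
                     nest_info: bool) -> Dict[str, Any]:
    """Build one step's record, with info entries either flattened or nested."""
    record: Dict[str, Any] = {"observation": obs, "action": action,
                              "next_observation": next_obs,
                              "reward": reward, "done": done}
    if nest_info:
        # Keep the extra data apart under a dedicated "info" key.
        record["info"] = dict(info)
    else:
        # Register each info entry as a top-level key next to the observation.
        for key, value in info.items():
            if key in record:
                raise KeyError(f"info key {key!r} clashes with a standard key")
            record[key] = value
    return record


class CounterEnv:
    """Hypothetical stub env: state is a step counter, info echoes it."""

    def reset(self):
        self.t = 0
        return self.t, {"abs_t": self.t}

    def step(self, action):
        self.t += action
        terminated = self.t >= 3
        return self.t, 1.0, terminated, False, {"abs_t": self.t}


def rollout(env, policy: Callable[[Any], Any], max_steps: int,
            nest_info: bool = False) -> List[Dict[str, Any]]:
    """Run the policy and keep each step's info alongside the transition."""
    obs, _ = env.reset()
    records = []
    for _ in range(max_steps):
        action = policy(obs)
        next_obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        records.append(make_step_record(obs, action, next_obs, reward,
                                        done, info, nest_info))
        obs = next_obs
        if done:
            break
    return records
```

Flattening keeps lookups short (`record["abs_t"]`) but risks key clashes, which the sketch surfaces as a `KeyError`; nesting avoids clashes at the cost of one extra level of indexing.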

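from torchrl.envs import GymEnv

env = GymEnv("HalfCheetah-v4")
# register the keys of the env's info dict as entries of the output tensordict
env = env.auto_register_info_dict()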

should work out of the box.