
[BUG] MARL PPO render fails in observation deepcopy #1509

Closed Quinticx closed 10 months ago

Quinticx commented 10 months ago

Describe the bug

Following the tutorial on multi-agent reinforcement learning using PPO, rendering the rollout after training fails.

To Reproduce

I implemented the code from the MARL PPO tutorial (https://pytorch.org/rl/tutorials/multiagent_ppo.html#). The bug is only encountered when attempting to render the environment after training the policy. If the policy=policy line is commented out, the rollout runs smoothly, but the actions are then sampled randomly rather than drawn from the trained policy.

env.rollout(
    max_steps=max_steps,
    policy=policy,
    callback=lambda env, _: env.render(),
    auto_cast_to_device=True,
    break_when_any_done=False,
)
Traceback (most recent call last):
  File "/home/with0024/PycharmProjects/MARL_Example/main.py", line 186, in <module>
    env.rollout(
  File "/home/with0024/PycharmProjects/MARL_Example/venv/lib/python3.10/site-packages/torchrl/envs/common.py", line 1572, in rollout
    tensordict = self.step(tensordict)
  File "/home/with0024/PycharmProjects/MARL_Example/venv/lib/python3.10/site-packages/torchrl/envs/common.py", line 1136, in step
    next_tensordict = self._step(tensordict)
  File "/home/with0024/PycharmProjects/MARL_Example/venv/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py", line 638, in _step
    next_tensordict = self.base_env._step(tensordict_in)
  File "/home/with0024/PycharmProjects/MARL_Example/venv/lib/python3.10/site-packages/torchrl/envs/libs/vmas.py", line 335, in _step
    obs, rews, dones, infos = self._env.step(action)
  File "/home/with0024/PycharmProjects/MARL_Example/venv/lib/python3.10/site-packages/vmas/simulator/environment/environment.py", line 231, in step
    obs, rewards, dones, infos = self.get_from_scenario(
  File "/home/with0024/PycharmProjects/MARL_Example/venv/lib/python3.10/site-packages/vmas/simulator/environment/environment.py", line 149, in get_from_scenario
    observation = copy.deepcopy(self.scenario.observation(agent))
  File "/usr/lib/python3.10/copy.py", line 153, in deepcopy
    y = copier(memo)
  File "/home/with0024/PycharmProjects/MARL_Example/venv/lib/python3.10/site-packages/torch/_tensor.py", line 86, in __deepcopy__
    raise RuntimeError(
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment.  If you were attempting to deepcopy a module, this may be because of a torch.nn.utils.weight_norm usage, see https://github.com/pytorch/pytorch/pull/103001

Reason and Possible fixes

The traceback links to the following pull request in PyTorch: https://github.com/pytorch/pytorch/pull/103001
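
For context, the error comes from PyTorch's restriction that only leaf tensors support the deepcopy protocol. A minimal standalone sketch (not from the tutorial; names are illustrative) reproduces the same RuntimeError:

import copy
import torch

x = torch.ones(3, requires_grad=True)  # leaf tensor: created explicitly by the user
y = x * 2                              # non-leaf: produced by an op recorded in the autograd graph

copy.deepcopy(x)  # fine: x is a graph leaf
copy.deepcopy(y)  # RuntimeError: Only Tensors created explicitly by the user ...

Presumably the trained policy's actions carry grad history, so the observations VMAS computes from them are non-leaf tensors, and the deepcopy in get_from_scenario fails.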


matteobettini commented 10 months ago

Thanks for reporting, I am looking into this.

matteobettini commented 10 months ago

Could you try wrapping the rollout call in torch.no_grad() as below and tell me if the problem still occurs?

with torch.no_grad():
    env.rollout(
        max_steps=max_steps,
        policy=policy,
        callback=lambda env, _: env.render(),
        auto_cast_to_device=True,
        break_when_any_done=False,
    )
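
For reference, the reason this is expected to help: under torch.no_grad() no autograd graph is recorded, so every tensor the simulator produces is a leaf and can be deepcopied. A minimal sketch of that behavior (illustrative names):

import copy
import torch

x = torch.ones(3, requires_grad=True)
with torch.no_grad():
    y = x * 2      # no graph recorded; y.requires_grad is False
copy.deepcopy(y)   # works: y is a leaf tensor
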
vmoens commented 10 months ago

Thanks for reporting this @Quinticx!

As @matteobettini suggested, a no_grad could solve things.

To give a bit of context, gradient propagation in rollouts (and in RL in general) is a tough design decision. For instance, we could disable gradients for all rollout calls, but that would be a terrible choice for meta-RL, inverse RL, trajectory optimization and the like. We could add one more kwarg to rollout, but the advantage over explicitly putting your call under a no_grad decorator would be marginal IMO. We have similar issues with value computation (when and how to disable graph construction) and in many other places in the code. There isn't a single answer, I'm afraid. What we could do better is capture errors related to this.
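
As a usage note, the "explicit no_grad" option mentioned above can also be written as a decorator on a small evaluation helper. A sketch under the tutorial's setup (evaluate is a hypothetical helper name, not a torchrl API; env, policy and max_steps come from the tutorial):

import torch

@torch.no_grad()  # hypothetical helper, not part of torchrl: evaluation never builds a graph
def evaluate(env, policy, max_steps):
    return env.rollout(
        max_steps=max_steps,
        policy=policy,
        callback=lambda env, _: env.render(),
        auto_cast_to_device=True,
        break_when_any_done=False,
    )
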

As always, suggestions are welcome!