pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl

[BUG] It's not clear how to call an advantage module with batched envs and pixel observations. #1522

Open skandermoalla opened 1 year ago

skandermoalla commented 1 year ago

Describe the bug

When you get a tensordict rollout of shape (N_envs, N_steps, C, H, W) out of a collector and you want to apply an advantage module that starts with conv2d layers:

  1. Directly applying the module crashes with the conv2d layer complaining about the input size, e.g. RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [2, 128, 4, 84, 84].
  2. Flattening the tensordict first with rollout.reshape(-1) so that it has shape [B, C, H, W] and then calling the advantage module runs, but issues the warning torchrl/objectives/value/advantages.py:99: UserWarning: Got a tensordict without a time-marked dimension, assuming time is along the last dimension., leaving you unsure of whether the advantages were computed correctly.

So it's not clear how one should proceed.
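For reference, here is a minimal sketch of the two cases. The key name "pixels" and the exact layer sizes are illustrative, not taken from the actual setup:

```python
import torch
from tensordict import TensorDict

# A rollout from 2 batched envs, 128 steps each, with 4x84x84 pixel frames.
rollout = TensorDict(
    {"pixels": torch.rand(2, 128, 4, 84, 84)},
    batch_size=[2, 128],
)
conv = torch.nn.Conv2d(4, 32, kernel_size=8, stride=4)

# Case 1: calling the conv layer on the un-flattened rollout raises
#   RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d,
#   but got input of size: [2, 128, 4, 84, 84]
# conv(rollout["pixels"])

# Case 2: flattening first lets conv2d run, but the advantage module called on the
# flat tensordict later warns that no time-marked dimension was found.
flat = rollout.reshape(-1)        # batch_size [256], pixels [256, 4, 84, 84]
out = conv(flat["pixels"])
```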

vmoens commented 12 months ago

Good point there. Regarding reshaping: you should reshape and refine_names; I believe the last dim will still be time-compliant (but you need to make sure you have truncated signals at the end of each time step). Other than that, we could consider falling back on vmap / first-class-dimensions whenever this situation is encountered. I will give it a look and ping you once it's on its way, as usual.
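To make that concrete, a rough sketch of the reshape + refine_names route; the ("next", "truncated") key and the way the flag is set are assumptions about the data layout, not a confirmed recipe:

```python
# rollout: TensorDict with batch_size [N_envs, N_steps] straight from the collector.

# 1) Flag the last step of every per-env trajectory as truncated so the value
#    estimator resets at trajectory boundaries once everything is concatenated
#    (assuming a ("next", "truncated") entry of shape [N_envs, N_steps, 1];
#    the corresponding done flag may need updating as well).
rollout[("next", "truncated")][..., -1, :] = True

# 2) Flatten all batch dims into a single one and mark it as the time dimension.
flat = rollout.reshape(-1).refine_names("time")

# 3) The advantage module should now find a time-marked dim and stop warning.
adv_module(flat)
```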

matteobettini commented 12 months ago

@vmoens, in some cases the env data may have an arbitrary batch size (*B) before the time dimension.

Is the current approach, before we land something like https://github.com/pytorch-labs/tensordict/pull/525, to try to flatten all these dims into one, making sure to add terminations when doing so?

vmoens commented 12 months ago

I don't think so; as I said in my answer, the proper approach should be to vmap over the leading dims up to the time dim. Wdyt?
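For illustration, something along these lines; whether the advantage module composes with torch.vmap out of the box like this is an assumption:

```python
from torch import vmap

# rollout: TensorDict with batch_size [N_envs, N_steps].
# Map the advantage module over the leading env dim so that each inner call only
# sees a [N_steps]-shaped tensordict whose pixel tensors are 4D, as conv2d expects.
adv_out = vmap(adv_module, in_dims=0, out_dims=0)(rollout)
```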

skandermoalla commented 12 months ago

Somehow, in the PPO example, the advantage module is called on the rollout batch shape https://github.com/pytorch/rl/blob/147de71d090d5705182bfabd24a99f3b2ee4cec9/examples/ppo/ppo.py#L103 and doesn't crash with conv2d complaining.

https://github.com/pytorch/rl/blob/147de71d090d5705182bfabd24a99f3b2ee4cec9/examples/ppo/utils.py#L341

I also managed to reproduce this behavior with the ConvNet and MLP modules of TorchRL, and my advantage module now runs without reshaping.

I'm sending more details to compare the settings.

skandermoalla commented 12 months ago

Okay, so the ConvNet of TorchRL actually flattens the batch before running the forward pass and then unflattens it back.

https://github.com/pytorch/rl/blob/147de71d090d5705182bfabd24a99f3b2ee4cec9/torchrl/modules/models/models.py#L479

Maybe this could be made clearer to users so that, when designing custom models, they know they have to do something similar.
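For anyone landing here, a minimal sketch of that flatten/unflatten pattern for a custom pixel model; the class and layer sizes are illustrative, not TorchRL's actual ConvNet:

```python
import torch
from torch import nn

class BatchAgnosticConv(nn.Module):
    """Applies conv2d to inputs with any number of leading batch dims."""
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(4, 32, kernel_size=8, stride=4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        *batch, C, H, W = x.shape
        # Flatten every leading batch dim into one so conv2d sees a 4D input ...
        out = self.conv(x.reshape(-1, C, H, W))
        # ... then restore the original leading batch dims on the way out.
        return out.reshape(*batch, *out.shape[1:])

# e.g. BatchAgnosticConv()(torch.rand(2, 128, 4, 84, 84)) -> shape [2, 128, 32, 20, 20]
```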

Otherwise, vmapping would be the way to go. I'm just concerned about memory requirements compared to flattening the tensordict.

vmoens commented 6 months ago

> Otherwise, vmapping would be the way to go. I'm just concerned about memory requirements compared to flattening the tensordict.

@skandermoalla Looking back at this comment, I wonder why vmap should have higher mem requirements?

skandermoalla commented 5 months ago

I'm not very familiar with vmap, but does the memory taken by the model weights stay the same when you vmap it?