skandermoalla opened 1 year ago
Good point there. Regarding reshaping: you should reshape and refine_names; I believe the last dim will still be time-compliant (but you need to make sure you have truncated signals at the end of each time step). Other than that, we could consider falling back on vmap / first-class dimensions whenever this situation is encountered. I will give it a look and ping you once it's on its way, as usual.
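Roughly something along these lines (a rough sketch only; the keys, shapes, and the "time" dimension name are illustrative placeholders, not a confirmed API contract):

```python
# Sketch of the reshape + refine_names idea; keys and shapes are made up.
import torch
from tensordict import TensorDict

# Pretend rollout: env batch (2, 3), 128 time steps, 84x84 frames with 4 channels.
rollout = TensorDict(
    {
        "observation": torch.randn(2, 3, 128, 4, 84, 84),
        "done": torch.zeros(2, 3, 128, 1, dtype=torch.bool),
    },
    batch_size=[2, 3, 128],
)

# Merge every leading dim into one, keep time last, then mark it as such so the
# value/advantage modules can identify the time dimension.
flat = rollout.reshape(-1, rollout.shape[-1])   # batch_size becomes [6, 128]
flat = flat.refine_names("batch", "time")
```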
@vmoens in some cases the env data may have an arbitrary batch size (*B) before the time dimension.
Is the current approach, before we land something like https://github.com/pytorch-labs/tensordict/pull/525, to try to flatten all these dims into one, making sure to add terminations when doing so?
I don't think so; as I said in my answer, the proper approach should be to vmap over the leading dims up to the time dim. Wdyt?
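Roughly what I have in mind, as a plain-PyTorch sketch (not TorchRL's actual mechanism; the value net and shapes are placeholders):

```python
# vmap over the leading env dim so the conv net only ever sees a [T, C, H, W] batch.
import torch
from torch import nn

value_net = nn.Sequential(
    nn.Conv2d(4, 32, kernel_size=8, stride=4),  # 84x84 -> 20x20
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(32 * 20 * 20, 1),
)

def trajectory_values(obs):            # obs: [T, C, H, W], a single trajectory
    return value_net(obs)              # -> [T, 1]

obs = torch.randn(2, 128, 4, 84, 84)   # [N_envs, T, C, H, W]
values = torch.vmap(trajectory_values)(obs)  # -> [2, 128, 1]
```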
Somehow, in the PPO example, the advantage module is called on the rollout batch shape https://github.com/pytorch/rl/blob/147de71d090d5705182bfabd24a99f3b2ee4cec9/examples/ppo/ppo.py#L103 and doesn't crash with conv2d complaining.
I also managed to reproduce this with the ConvNet and MLP modules of TorchRL, and my advantage module now runs without reshaping. I'm sending more details to compare the settings.
Okay, so the ConvNet of TorchRL actually flattens the batch before running a forward pass and then unflattens it back. Maybe this could be made clearer to the user, so that when designing custom models they know they have to do something similar.
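For reference, a minimal sketch of the flatten-then-unflatten pattern a custom model would have to apply (this is not TorchRL's actual ConvNet code; layer sizes are placeholders):

```python
import torch
from torch import nn

class AnyBatchConvEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Flatten(),
        )

    def forward(self, x):                           # x: [*B, C, H, W]
        *batch, C, H, W = x.shape
        out = self.net(x.reshape(-1, C, H, W))      # collapse all leading dims into one
        return out.reshape(*batch, out.shape[-1])   # restore the original leading dims

features = AnyBatchConvEncoder()(torch.randn(2, 128, 4, 84, 84))  # -> [2, 128, 12800]
```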
Otherwise, using vmap would be the way to go. I'm just concerned about memory requirements compared to flattening the tensordict.
@skandermoalla Looking back at this comment, I wonder why vmap should have higher memory requirements?
I'm not very familiar with vmap, but does the memory taken by the model weights stay the same when you vmap it?
Describe the bug
When you get a tensordict rollout of shape (N_envs, N_steps, C, H, W) out of a collector and you want to apply an advantage module that starts with conv2d layers:
- Calling the advantage module directly fails, with the conv2d layer complaining about the input size, e.g. RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [2, 128, 4, 84, 84].
- Reshaping with rollout.reshape(-1) so that it has shape [B, C, H, W] and then calling the advantage module will run, but it issues the warning torchrl/objectives/value/advantages.py:99: UserWarning: Got a tensordict without a time-marked dimension, assuming time is along the last dimension., leaving you unsure of whether the advantages were computed correctly.
So it's not clear how one should proceed.
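A tiny plain-PyTorch illustration of the shape problem described above (no TorchRL modules involved; layer sizes are placeholders):

```python
import torch
from torch import nn

conv = nn.Conv2d(4, 32, kernel_size=8, stride=4)
obs = torch.randn(2, 128, 4, 84, 84)   # [N_envs, N_steps, C, H, W] out of the collector
# conv(obs)                            # RuntimeError: Expected 3D (unbatched) or 4D (batched) input
out = conv(obs.flatten(0, 1))          # works: [2 * 128, 32, 20, 20], but after flattening
                                       # there is no explicit time dimension left for the
                                       # advantage computation, hence the UserWarning above.
```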