Closed vmoens closed 2 months ago
Note: Links to docs will display an error until the docs builds have been completed.
As of commit 99f5dccd8f84b9cd67f2cd5d002544abd933dcc3 with merge base 59d2ae1ec0294043bf3e808c81907d9f53796303:
* [Habitat Tests on Linux / tests (3.9, 12.1) / linux-job](https://hud.pytorch.org/pr/pytorch/rl/2351#28286588339) ([gh](https://github.com/pytorch/rl/actions/runs/10222290306/job/28286588339)) `RuntimeError: Command docker exec -t 033f6fd1c3de18a70660e60618d66c2e7396cab1814c3343fff5b670aea3970f /exec failed with exit code 139`
* [Build Windows Wheels / pytorch/rl (pytorch/rl, python packaging/wheel/relocate.py, test/smoke_test.py, torchrl) / upload / wheel-py3_9-cuda11_8](https://hud.pytorch.org/pr/pytorch/rl/2351#28293382738) ([gh](https://github.com/pytorch/rl/actions/runs/10222290355/job/28293382738)) ([similar failure](https://hud.pytorch.org/pytorch/rl/commit/99f5dccd8f84b9cd67f2cd5d002544abd933dcc3#28293382812)) `Unable to find any artifacts for the associated workflow`
👉 Rebase onto the `viable/strict` branch to avoid these failures
* [Build Windows Wheels / pytorch/rl (pytorch/rl, python packaging/wheel/relocate.py, test/smoke_test.py, torchrl) / upload / wheel-py3_9-cpu](https://hud.pytorch.org/pr/pytorch/rl/2351#28293382662) ([gh](https://github.com/pytorch/rl/actions/runs/10222290355/job/28293382662)) ([trunk failure](https://hud.pytorch.org/pytorch/rl/commit/59d2ae1ec0294043bf3e808c81907d9f53796303#28287579986)) `Unable to find any artifacts for the associated workflow`
* [Build Windows Wheels / pytorch/rl (pytorch/rl, python packaging/wheel/relocate.py, test/smoke_test.py, torchrl) / upload / wheel-py3_9-cuda12_1](https://hud.pytorch.org/pr/pytorch/rl/2351#28293382812) ([gh](https://github.com/pytorch/rl/actions/runs/10222290355/job/28293382812)) ([trunk failure](https://hud.pytorch.org/pytorch/rl/commit/59d2ae1ec0294043bf3e808c81907d9f53796303#28287580839)) `Unable to find any artifacts for the associated workflow`
* [Build Windows Wheels / pytorch/rl (pytorch/rl, python packaging/wheel/relocate.py, test/smoke_test.py, torchrl) / upload / wheel-py3_9-cuda12_4](https://hud.pytorch.org/pr/pytorch/rl/2351#28293382882) ([gh](https://github.com/pytorch/rl/actions/runs/10222290355/job/28293382882)) ([trunk failure](https://hud.pytorch.org/pytorch/rl/commit/59d2ae1ec0294043bf3e808c81907d9f53796303#28287581079)) `Unable to find any artifacts for the associated workflow`
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Could you explain why we need this?
Also, isn't having 2 copies of the parameters error-prone?
For example, in methods like https://github.com/facebookresearch/BenchMARL/blob/d260eea5d4ef2ff5f0bea8ae36f68638ecb14865/benchmarl/models/common.py#L165, or in any general case where users access `self.parameters()`, won't things break?
We test that nothing breaks. I don't think it's error-prone: you never see two copies (for instance, `parameters()` just returns one).
We need this mainly because it makes initialization of the params more natural.
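To illustrate the visibility point, here is a minimal sketch using generic `tensordict`/`nn` code, not the actual torchrl module (the `Holder` class and attribute names are hypothetical):

```python
import torch
from tensordict import TensorDict
from torch import nn

class Holder(nn.Module):
    def __init__(self):
        super().__init__()
        self._empty_net = nn.Linear(2, 2)
        # Same tensor objects, but a plain TensorDict is neither a Parameter
        # nor a Module, so nn.Module's __setattr__ does not register it.
        self.params = TensorDict.from_module(self._empty_net)

holder = Holder()
print(len(list(holder.parameters())))  # 2 (weight + bias), not 4
```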
So if a user modifies the content of one copy of the parameters, the change is reflected in the other copy? As in the function I sent.
But apart from being more natural, what use cases is this used for / envisioned for?
Maybe I am misreading the PR description: when you say 2 copies, you mean:
> So if a user modifies the content of one copy of the parameters, the change is reflected in the other copy? As in the function I sent.
They are exactly the same objects: one is in `self.params` and not seen by `self.modules()` or `self.parameters()`, and the other is in `self._empty_net`.
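As a sanity check, here is a hedged sketch of what "exactly the same objects" means in practice (generic `tensordict` code on a toy layer, not the torchrl module itself):

```python
import torch
from tensordict import TensorDict
from torch import nn

net = nn.Linear(3, 4)
params = TensorDict.from_module(net)

# The TensorDict "copy" and the module's parameters share storage, so there
# is no real duplication in memory and no way for the two to diverge.
assert params["weight"].data_ptr() == net.weight.data_ptr()
params["weight"].data.zero_()
assert (net.weight == 0).all()  # the change shows up through the module too
```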
> But apart from being more natural, what use cases is this used for / envisioned for?
Many people are used to doing

```python
def init(module):
    # zero the weights of every nn.Linear visited by apply()
    if isinstance(module, nn.Linear):
        module.weight.data.zero_()

self.apply(init)
```
which you can only do if the params are in the module, not in the TDParams. Moreover, TDParams carries some overhead, so the new version should be faster. On top of that, it's totally optional and 100% non-BC-breaking.
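For reference, the same idiom as a self-contained, runnable snippet (a toy `nn.Sequential` standing in for the real module):

```python
import torch
from torch import nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

def init(module):
    # apply() recurses over net.modules(), so this fires once per Linear.
    if isinstance(module, nn.Linear):
        module.weight.data.zero_()

net.apply(init)
assert all((m.weight == 0).all() for m in net if isinstance(m, nn.Linear))
```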
Description
We currently store the parameters of MARL modules in `self.params`, a `TensorDictParams`. During a call to `forward`, we call `vmap` and `to_module` to put the batched parameters in place within the module.

This PR proposes to optionally make `self.params` a regular `TensorDict` (i.e., `self.parameters()` will not see them because `self.params` is no longer within `self.modules()`), and to place the parameters in `self._empty_net` instead. With that in place, the module has two copies of the parameters, but one is not accessible via `self.parameters()`, so things don't change from the user perspective.

We test that these two scenarios are identical and that sending the module to a device does not create multiple distinct copies of the params.
cc @matteobettini