pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[Feature] Store MARL parameters in module #2351

Closed vmoens closed 2 months ago

vmoens commented 2 months ago

Description

MARL modules currently store their parameters in `self.params`, a `TensorDictParams`. During a call to `forward`, we call `vmap` and `to_module` to put the batched parameters in place within the module.
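
For readers unfamiliar with the pattern, here is a minimal sketch of running several per-agent networks through one batched call. It deliberately uses plain `torch.func` (`stack_module_state`, `functional_call`) instead of the `TensorDictParams` / `to_module` machinery the PR touches, so it illustrates the idea rather than the library's actual code. The names `agent_nets`, `template_net` and `call_one_agent` are hypothetical.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state

torch.manual_seed(0)
n_agents, in_dim, out_dim = 3, 4, 2

# One independent network per agent (toy setup, assumed for illustration).
agent_nets = [nn.Linear(in_dim, out_dim) for _ in range(n_agents)]

# Stack every parameter along a new leading "agent" dimension.
stacked_params, stacked_buffers = stack_module_state(agent_nets)

# Template network: only its structure matters, parameters are swapped in per call.
template_net = copy.deepcopy(agent_nets[0])


def call_one_agent(params, buffers, x):
    # Run the template with one agent's slice of the stacked parameters.
    return functional_call(template_net, (params, buffers), (x,))


x = torch.randn(5, in_dim)
# vmap maps over the agent dimension of the parameters and shares the input.
batched_out = torch.vmap(call_one_agent, in_dims=(0, 0, None))(
    stacked_params, stacked_buffers, x
)
```

`batched_out` has one leading entry per agent, and each entry should match what the corresponding network in `agent_nets` produces on its own.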

This PR proposes to optionally make `self.params` a regular `TensorDict` (i.e., `self.parameters()` will not see them, because `self.params` is no longer within `self.modules()`) and to place them in `self._empty_net` instead. With that in place, the module has two copies of the parameters, but one is not accessible via `self.parameters()`, so nothing changes from the user's perspective.

We test that these two scenarios are identical and that sending the module to device does not create multiple distinct copies of the params.

cc @matteobettini

pytorch-bot[bot] commented 2 months ago

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2351

Note: Links to docs will display an error until the docs builds have been completed.

:x: 1 New Failure, 4 Unrelated Failures

As of commit 99f5dccd8f84b9cd67f2cd5d002544abd933dcc3 with merge base 59d2ae1ec0294043bf3e808c81907d9f53796303:

NEW FAILURE - The following job has failed:

* [Habitat Tests on Linux / tests (3.9, 12.1) / linux-job](https://hud.pytorch.org/pr/pytorch/rl/2351#28286588339) ([gh](https://github.com/pytorch/rl/actions/runs/10222290306/job/28286588339)) `RuntimeError: Command docker exec -t 033f6fd1c3de18a70660e60618d66c2e7396cab1814c3343fff5b670aea3970f /exec failed with exit code 139`

FLAKY - The following job failed but was likely due to flakiness present on trunk:

* [Build Windows Wheels / pytorch/rl (pytorch/rl, python packaging/wheel/relocate.py, test/smoke_test.py, torchrl) / upload / wheel-py3_9-cuda11_8](https://hud.pytorch.org/pr/pytorch/rl/2351#28293382738) ([gh](https://github.com/pytorch/rl/actions/runs/10222290355/job/28293382738)) ([similar failure](https://hud.pytorch.org/pytorch/rl/commit/99f5dccd8f84b9cd67f2cd5d002544abd933dcc3#28293382812)) `Unable to find any artifacts for the associated workflow`

BROKEN TRUNK - The following jobs failed but were also failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

* [Build Windows Wheels / pytorch/rl (pytorch/rl, python packaging/wheel/relocate.py, test/smoke_test.py, torchrl) / upload / wheel-py3_9-cpu](https://hud.pytorch.org/pr/pytorch/rl/2351#28293382662) ([gh](https://github.com/pytorch/rl/actions/runs/10222290355/job/28293382662)) ([trunk failure](https://hud.pytorch.org/pytorch/rl/commit/59d2ae1ec0294043bf3e808c81907d9f53796303#28287579986)) `Unable to find any artifacts for the associated workflow`
* [Build Windows Wheels / pytorch/rl (pytorch/rl, python packaging/wheel/relocate.py, test/smoke_test.py, torchrl) / upload / wheel-py3_9-cuda12_1](https://hud.pytorch.org/pr/pytorch/rl/2351#28293382812) ([gh](https://github.com/pytorch/rl/actions/runs/10222290355/job/28293382812)) ([trunk failure](https://hud.pytorch.org/pytorch/rl/commit/59d2ae1ec0294043bf3e808c81907d9f53796303#28287580839)) `Unable to find any artifacts for the associated workflow`
* [Build Windows Wheels / pytorch/rl (pytorch/rl, python packaging/wheel/relocate.py, test/smoke_test.py, torchrl) / upload / wheel-py3_9-cuda12_4](https://hud.pytorch.org/pr/pytorch/rl/2351#28293382882) ([gh](https://github.com/pytorch/rl/actions/runs/10222290355/job/28293382882)) ([trunk failure](https://hud.pytorch.org/pytorch/rl/commit/59d2ae1ec0294043bf3e808c81907d9f53796303#28287581079)) `Unable to find any artifacts for the associated workflow`

This comment was automatically generated by Dr. CI and updates every 15 minutes.

github-actions[bot] commented 2 months ago

⚠ Warning: Result of CPU Benchmark Tests

Total Benchmarks: 91. Improved: 7. Worsened: 1.

Detailed results:

| Name | Max | Mean | Ops | Ops on Repo `HEAD` | Change |
| --- | --- | --- | --- | --- | --- |
| test_single | 60.8594ms | 58.0787ms | 17.2180 Ops/s | 16.9606 Ops/s | +1.52% |
| test_sync | 33.4960ms | 31.6156ms | 31.6300 Ops/s | 30.9966 Ops/s | +2.04% |
| test_async | 53.9812ms | 30.1618ms | 33.1545 Ops/s | 33.1318 Ops/s | +0.07% |
| test_simple | 0.4764s | 0.4100s | 2.4393 Ops/s | 2.3884 Ops/s | +2.13% |
| test_transformed | 0.6281s | 0.5667s | 1.7647 Ops/s | 1.7258 Ops/s | +2.25% |
| test_serial | 1.3178s | 1.2514s | 0.7991 Ops/s | 0.7892 Ops/s | +1.25% |
| test_parallel | 1.1867s | 1.1220s | 0.8913 Ops/s | 0.8981 Ops/s | -0.76% |
| test_step_mdp_speed[True-True-True-True-True] | 77.6570μs | 23.8122μs | 41.9952 KOps/s | 41.0001 KOps/s | +2.43% |
| test_step_mdp_speed[True-True-True-True-False] | 40.1650μs | 13.9823μs | 71.5188 KOps/s | 69.3805 KOps/s | +3.08% |
| test_step_mdp_speed[True-True-True-False-True] | 49.6730μs | 13.7194μs | 72.8893 KOps/s | 70.5843 KOps/s | +3.27% |
| test_step_mdp_speed[True-True-True-False-False] | 56.1460μs | 8.1598μs | 122.5516 KOps/s | 120.6058 KOps/s | +1.61% |
| test_step_mdp_speed[True-True-False-True-True] | 80.8810μs | 25.4408μs | 39.3070 KOps/s | 38.7792 KOps/s | +1.36% |
| test_step_mdp_speed[True-True-False-True-False] | 60.3230μs | 15.4593μs | 64.6861 KOps/s | 63.6353 KOps/s | +1.65% |
| test_step_mdp_speed[True-True-False-False-True] | 48.6910μs | 15.2752μs | 65.4657 KOps/s | 63.9940 KOps/s | +2.30% |
| test_step_mdp_speed[True-True-False-False-False] | 45.5450μs | 9.5504μs | 104.7072 KOps/s | 103.2983 KOps/s | +1.36% |
| test_step_mdp_speed[True-False-True-True-True] | 60.5730μs | 27.3040μs | 36.6246 KOps/s | 36.4190 KOps/s | +0.56% |
| test_step_mdp_speed[True-False-True-True-False] | 55.2530μs | 17.0829μs | 58.5379 KOps/s | 58.0164 KOps/s | +0.90% |
| test_step_mdp_speed[True-False-True-False-True] | 41.0870μs | 15.3451μs | 65.1673 KOps/s | 64.4278 KOps/s | +1.15% |
| test_step_mdp_speed[True-False-True-False-False] | 36.3380μs | 9.5931μs | 104.2420 KOps/s | 103.3176 KOps/s | +0.89% |
| test_step_mdp_speed[True-False-False-True-True] | 57.4970μs | 28.6329μs | 34.9248 KOps/s | 35.0743 KOps/s | -0.43% |
| test_step_mdp_speed[True-False-False-True-False] | 53.9300μs | 18.4106μs | 54.3166 KOps/s | 53.9439 KOps/s | +0.69% |
| test_step_mdp_speed[True-False-False-False-True] | 62.1160μs | 16.5997μs | 60.2419 KOps/s | 59.7351 KOps/s | +0.85% |
| test_step_mdp_speed[True-False-False-False-False] | 41.6880μs | 10.8472μs | 92.1901 KOps/s | 89.4092 KOps/s | +3.11% |
| test_step_mdp_speed[False-True-True-True-True] | 93.3140μs | 26.8498μs | 37.2443 KOps/s | 36.4847 KOps/s | +2.08% |
| test_step_mdp_speed[False-True-True-True-False] | 44.5730μs | 16.9294μs | 59.0689 KOps/s | 57.6231 KOps/s | +2.51% |
| test_step_mdp_speed[False-True-True-False-True] | 49.7330μs | 17.5997μs | 56.8190 KOps/s | 54.6237 KOps/s | +4.02% |
| test_step_mdp_speed[False-True-True-False-False] | 33.6140μs | 10.7646μs | 92.8974 KOps/s | 90.8619 KOps/s | +2.24% |
| test_step_mdp_speed[False-True-False-True-True] | 55.6250μs | 28.6919μs | 34.8531 KOps/s | 34.6712 KOps/s | +0.52% |
| test_step_mdp_speed[False-True-False-True-False] | 44.8740μs | 18.6361μs | 53.6594 KOps/s | 53.4642 KOps/s | +0.37% |
| test_step_mdp_speed[False-True-False-False-True] | 44.2230μs | 19.0267μs | 52.5576 KOps/s | 51.1201 KOps/s | +2.81% |
| test_step_mdp_speed[False-True-False-False-False] | 64.9710μs | 12.0870μs | 82.7335 KOps/s | 81.0897 KOps/s | +2.03% |
| test_step_mdp_speed[False-False-True-True-True] | 4.9703ms | 30.9542μs | 32.3058 KOps/s | 32.7407 KOps/s | -1.33% |
| test_step_mdp_speed[False-False-True-True-False] | 54.4520μs | 19.6727μs | 50.8317 KOps/s | 49.6940 KOps/s | +2.29% |
| test_step_mdp_speed[False-False-True-False-True] | 49.0020μs | 18.9595μs | 52.7439 KOps/s | 51.7012 KOps/s | +2.02% |
| test_step_mdp_speed[False-False-True-False-False] | 38.5020μs | 12.0400μs | 83.0561 KOps/s | 81.1532 KOps/s | +2.34% |
| test_step_mdp_speed[False-False-False-True-True] | 80.1400μs | 31.2660μs | 31.9836 KOps/s | 31.9307 KOps/s | +0.17% |
| test_step_mdp_speed[False-False-False-True-False] | 44.5740μs | 20.9597μs | 47.7107 KOps/s | 46.5263 KOps/s | +2.55% |
| test_step_mdp_speed[False-False-False-False-True] | 51.4460μs | 20.0099μs | 49.9754 KOps/s | 48.7567 KOps/s | +2.50% |
| test_step_mdp_speed[False-False-False-False-False] | 43.6820μs | 13.2395μs | 75.5316 KOps/s | 73.3613 KOps/s | +2.96% |
| test_values[generalized_advantage_estimate-True-True] | 11.5159ms | 9.8473ms | 101.5509 Ops/s | 105.3125 Ops/s | -3.57% |
| test_values[vec_generalized_advantage_estimate-True-True] | 37.5409ms | 33.6259ms | 29.7390 Ops/s | 27.7375 Ops/s | **+7.22%** |
| test_values[td0_return_estimate-False-False] | 0.2253ms | 0.1632ms | 6.1288 KOps/s | 5.9978 KOps/s | +2.18% |
| test_values[td1_return_estimate-False-False] | 25.7453ms | 24.1383ms | 41.4279 Ops/s | 41.7063 Ops/s | -0.67% |
| test_values[vec_td1_return_estimate-False-False] | 35.6314ms | 33.6180ms | 29.7460 Ops/s | 27.6331 Ops/s | **+7.65%** |
| test_values[td_lambda_return_estimate-True-False] | 35.6141ms | 34.7343ms | 28.7900 Ops/s | 28.8062 Ops/s | -0.06% |
| test_values[vec_td_lambda_return_estimate-True-False] | 35.3133ms | 33.6304ms | 29.7350 Ops/s | 27.5942 Ops/s | **+7.76%** |
| test_gae_speed[generalized_advantage_estimate-False-1-512] | 8.6173ms | 8.5050ms | 117.5776 Ops/s | 120.1911 Ops/s | -2.17% |
| test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 1.8700ms | 1.7815ms | 561.3352 Ops/s | 441.4320 Ops/s | **+27.16%** |
| test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.5924ms | 0.3565ms | 2.8054 KOps/s | 2.7594 KOps/s | +1.67% |
| test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 41.2907ms | 38.7780ms | 25.7878 Ops/s | 22.7053 Ops/s | **+13.58%** |
| test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 3.9704ms | 3.0346ms | 329.5341 Ops/s | 328.7155 Ops/s | +0.25% |
| test_dqn_speed | 6.3492ms | 1.2916ms | 774.2185 Ops/s | 763.7180 Ops/s | +1.37% |
| test_ddpg_speed | 2.9762ms | 2.6971ms | 370.7623 Ops/s | 366.8248 Ops/s | +1.07% |
| test_sac_speed | 8.9942ms | 7.8987ms | 126.6024 Ops/s | 124.6071 Ops/s | +1.60% |
| test_redq_speed | 13.5127ms | 12.5030ms | 79.9806 Ops/s | 79.6195 Ops/s | +0.45% |
| test_redq_deprec_speed | 13.8810ms | 12.5049ms | 79.9688 Ops/s | 78.5457 Ops/s | +1.81% |
| test_td3_speed | 8.3465ms | 7.8654ms | 127.1398 Ops/s | 126.3317 Ops/s | +0.64% |
| test_cql_speed | 35.9901ms | 35.1285ms | 28.4669 Ops/s | 28.2031 Ops/s | +0.94% |
| test_a2c_speed | 7.8295ms | 7.2417ms | 138.0886 Ops/s | 138.1495 Ops/s | -0.04% |
| test_ppo_speed | 8.9631ms | 7.4977ms | 133.3734 Ops/s | 132.8584 Ops/s | +0.39% |
| test_reinforce_speed | 7.6140ms | 6.4028ms | 156.1813 Ops/s | 156.1572 Ops/s | +0.02% |
| test_iql_speed | 33.3337ms | 31.7687ms | 31.4775 Ops/s | 28.8420 Ops/s | **+9.14%** |
| test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 6.6381ms | 4.7600ms | 210.0833 Ops/s | 208.6635 Ops/s | +0.68% |
| test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.7460ms | 0.4744ms | 2.1078 KOps/s | 2.0964 KOps/s | +0.54% |
| test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.6694ms | 0.4510ms | 2.2172 KOps/s | 2.1870 KOps/s | +1.38% |
| test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.6852ms | 4.7144ms | 212.1183 Ops/s | 205.5682 Ops/s | +3.19% |
| test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.7908ms | 0.4939ms | 2.0247 KOps/s | 2.0859 KOps/s | -2.94% |
| test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.6686ms | 0.4568ms | 2.1890 KOps/s | 2.1842 KOps/s | +0.22% |
| test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 2.4522ms | 1.6914ms | 591.2126 Ops/s | 590.0084 Ops/s | +0.20% |
| test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 2.1586ms | 1.6032ms | 623.7356 Ops/s | 623.7380 Ops/s | -0.00% |
| test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.8665ms | 5.0988ms | 196.1264 Ops/s | 204.7254 Ops/s | -4.20% |
| test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.0157ms | 0.6216ms | 1.6089 KOps/s | 1.6261 KOps/s | -1.06% |
| test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.7771ms | 0.5967ms | 1.6759 KOps/s | 1.7160 KOps/s | -2.34% |
| test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 5.2075ms | 4.9130ms | 203.5424 Ops/s | 210.4294 Ops/s | -3.27% |
| test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.2371ms | 0.4866ms | 2.0550 KOps/s | 2.0949 KOps/s | -1.91% |
| test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.6645ms | 0.4617ms | 2.1659 KOps/s | 2.1789 KOps/s | -0.60% |
| test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 5.9631ms | 4.8569ms | 205.8944 Ops/s | 212.9380 Ops/s | -3.31% |
| test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.7190ms | 0.4780ms | 2.0923 KOps/s | 2.1093 KOps/s | -0.81% |
| test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 7.1771ms | 0.4672ms | 2.1405 KOps/s | 2.2368 KOps/s | -4.31% |
| test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 5.2972ms | 4.9849ms | 200.6069 Ops/s | 205.8461 Ops/s | -2.55% |
| test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 0.9501ms | 0.6104ms | 1.6384 KOps/s | 1.6238 KOps/s | +0.90% |
| test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.7855ms | 0.5871ms | 1.7033 KOps/s | 1.6627 KOps/s | +2.44% |
| test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.1221s | 8.2732ms | 120.8728 Ops/s | 167.5156 Ops/s | **-27.84%** |
| test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 16.8086ms | 12.9537ms | 77.1981 Ops/s | 77.2236 Ops/s | -0.03% |
| test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 1.8267ms | 1.1390ms | 877.9705 Ops/s | 903.3859 Ops/s | -2.81% |
| test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.1099s | 5.8167ms | 171.9196 Ops/s | 171.0075 Ops/s | +0.53% |
| test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 17.0531ms | 12.9615ms | 77.1517 Ops/s | 77.3634 Ops/s | -0.27% |
| test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 1.8565ms | 1.1393ms | 877.7313 Ops/s | 877.4761 Ops/s | +0.03% |
| test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.1100s | 8.1448ms | 122.7777 Ops/s | 123.2273 Ops/s | -0.36% |
| test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 17.5152ms | 13.0745ms | 76.4848 Ops/s | 76.7339 Ops/s | -0.32% |
| test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 2.0098ms | 1.2830ms | 779.4535 Ops/s | 728.1783 Ops/s | **+7.04%** |

github-actions[bot] commented 2 months ago

⚠ Warning: Result of GPU Benchmark Tests

Total Benchmarks: 94. Improved: 3. Worsened: 3.

Detailed results:

| Name | Max | Mean | Ops | Ops on Repo `HEAD` | Change |
| --- | --- | --- | --- | --- | --- |
| test_single | 0.1083s | 0.1081s | 9.2478 Ops/s | 9.1035 Ops/s | +1.59% |
| test_sync | 96.2380ms | 95.8970ms | 10.4279 Ops/s | 10.4743 Ops/s | -0.44% |
| test_async | 0.1790s | 89.8904ms | 11.1247 Ops/s | 11.2981 Ops/s | -1.54% |
| test_single_pixels | 0.1201s | 0.1191s | 8.3950 Ops/s | 8.3807 Ops/s | +0.17% |
| test_sync_pixels | 77.4057ms | 75.0139ms | 13.3309 Ops/s | 13.3348 Ops/s | -0.03% |
| test_async_pixels | 0.1398s | 69.3810ms | 14.4132 Ops/s | 14.4059 Ops/s | +0.05% |
| test_simple | 0.7820s | 0.7813s | 1.2799 Ops/s | 1.2466 Ops/s | +2.67% |
| test_transformed | 1.1076s | 1.0335s | 0.9676 Ops/s | 0.9782 Ops/s | -1.09% |
| test_serial | 2.3195s | 2.2473s | 0.4450 Ops/s | 0.4416 Ops/s | +0.76% |
| test_parallel | 2.0396s | 1.9841s | 0.5040 Ops/s | 0.5066 Ops/s | -0.51% |
| test_step_mdp_speed[True-True-True-True-True] | 0.1040ms | 36.0206μs | 27.7619 KOps/s | 26.7325 KOps/s | +3.85% |
| test_step_mdp_speed[True-True-True-True-False] | 39.3010μs | 20.7227μs | 48.2563 KOps/s | 46.5723 KOps/s | +3.62% |
| test_step_mdp_speed[True-True-True-False-True] | 49.5510μs | 20.6293μs | 48.4748 KOps/s | 47.4369 KOps/s | +2.19% |
| test_step_mdp_speed[True-True-True-False-False] | 30.9610μs | 11.9492μs | 83.6878 KOps/s | 84.2822 KOps/s | -0.71% |
| test_step_mdp_speed[True-True-False-True-True] | 59.5810μs | 38.4785μs | 25.9885 KOps/s | 25.3073 KOps/s | +2.69% |
| test_step_mdp_speed[True-True-False-True-False] | 42.5110μs | 23.2164μs | 43.0730 KOps/s | 42.0332 KOps/s | +2.47% |
| test_step_mdp_speed[True-True-False-False-True] | 43.1620μs | 23.0089μs | 43.4614 KOps/s | 41.9601 KOps/s | +3.58% |
| test_step_mdp_speed[True-True-False-False-False] | 35.4310μs | 14.2273μs | 70.2876 KOps/s | 69.2938 KOps/s | +1.43% |
| test_step_mdp_speed[True-False-True-True-True] | 71.2520μs | 40.7142μs | 24.5614 KOps/s | 23.7132 KOps/s | +3.58% |
| test_step_mdp_speed[True-False-True-True-False] | 53.2810μs | 25.5470μs | 39.1435 KOps/s | 38.0121 KOps/s | +2.98% |
| test_step_mdp_speed[True-False-True-False-True] | 57.1310μs | 22.6112μs | 44.2259 KOps/s | 42.8372 KOps/s | +3.24% |
| test_step_mdp_speed[True-False-True-False-False] | 32.5900μs | 14.1789μs | 70.5274 KOps/s | 69.9104 KOps/s | +0.88% |
| test_step_mdp_speed[True-False-False-True-True] | 72.2510μs | 42.4975μs | 23.5308 KOps/s | 22.5736 KOps/s | +4.24% |
| test_step_mdp_speed[True-False-False-True-False] | 60.1110μs | 27.4178μs | 36.4726 KOps/s | 35.3670 KOps/s | +3.13% |
| test_step_mdp_speed[True-False-False-False-True] | 44.1910μs | 24.8130μs | 40.3015 KOps/s | 38.6465 KOps/s | +4.28% |
| test_step_mdp_speed[True-False-False-False-False] | 36.1010μs | 16.3322μs | 61.2289 KOps/s | 61.6458 KOps/s | -0.68% |
| test_step_mdp_speed[False-True-True-True-True] | 60.9210μs | 40.4013μs | 24.7517 KOps/s | 23.8815 KOps/s | +3.64% |
| test_step_mdp_speed[False-True-True-True-False] | 50.5810μs | 25.3500μs | 39.4478 KOps/s | 38.5502 KOps/s | +2.33% |
| test_step_mdp_speed[False-True-True-False-True] | 52.4410μs | 26.9628μs | 37.0882 KOps/s | 35.9671 KOps/s | +3.12% |
| test_step_mdp_speed[False-True-True-False-False] | 37.3200μs | 16.0720μs | 62.2200 KOps/s | 60.9309 KOps/s | +2.12% |
| test_step_mdp_speed[False-True-False-True-True] | 64.3600μs | 42.7812μs | 23.3747 KOps/s | 22.6899 KOps/s | +3.02% |
| test_step_mdp_speed[False-True-False-True-False] | 45.7710μs | 27.5551μs | 36.2909 KOps/s | 35.5163 KOps/s | +2.18% |
| test_step_mdp_speed[False-True-False-False-True] | 73.3910μs | 28.8504μs | 34.6615 KOps/s | 33.0408 KOps/s | +4.91% |
| test_step_mdp_speed[False-True-False-False-False] | 41.6300μs | 18.1248μs | 55.1730 KOps/s | 53.3433 KOps/s | +3.43% |
| test_step_mdp_speed[False-False-True-True-True] | 4.0032ms | 45.0151μs | 22.2148 KOps/s | 21.4147 KOps/s | +3.74% |
| test_step_mdp_speed[False-False-True-True-False] | 47.9010μs | 29.7437μs | 33.6205 KOps/s | 32.7864 KOps/s | +2.54% |
| test_step_mdp_speed[False-False-True-False-True] | 52.3610μs | 29.2682μs | 34.1668 KOps/s | 33.1629 KOps/s | +3.03% |
| test_step_mdp_speed[False-False-True-False-False] | 49.0700μs | 18.1907μs | 54.9732 KOps/s | 53.8108 KOps/s | +2.16% |
| test_step_mdp_speed[False-False-False-True-True] | 72.0910μs | 46.9558μs | 21.2966 KOps/s | 20.8103 KOps/s | +2.34% |
| test_step_mdp_speed[False-False-False-True-False] | 56.4510μs | 31.8132μs | 31.4334 KOps/s | 30.4653 KOps/s | +3.18% |
| test_step_mdp_speed[False-False-False-False-True] | 59.9610μs | 30.7492μs | 32.5212 KOps/s | 31.3747 KOps/s | +3.65% |
| test_step_mdp_speed[False-False-False-False-False] | 42.3610μs | 20.2891μs | 49.2875 KOps/s | 48.6610 KOps/s | +1.29% |
| test_values[generalized_advantage_estimate-True-True] | 24.6960ms | 24.3358ms | 41.0917 Ops/s | 41.0758 Ops/s | +0.04% |
| test_values[vec_generalized_advantage_estimate-True-True] | 96.6017ms | 2.8266ms | 353.7806 Ops/s | 373.1887 Ops/s | **-5.20%** |
| test_values[td0_return_estimate-False-False] | 92.3010μs | 64.8003μs | 15.4320 KOps/s | 15.3579 KOps/s | +0.48% |
| test_values[td1_return_estimate-False-False] | 54.9978ms | 54.4949ms | 18.3504 Ops/s | 18.3411 Ops/s | +0.05% |
| test_values[vec_td1_return_estimate-False-False] | 1.3319ms | 1.0804ms | 925.5448 Ops/s | 926.7632 Ops/s | -0.13% |
| test_values[td_lambda_return_estimate-True-False] | 87.4524ms | 86.6773ms | 11.5370 Ops/s | 11.4541 Ops/s | +0.72% |
| test_values[vec_td_lambda_return_estimate-True-False] | 1.2641ms | 1.0758ms | 929.5132 Ops/s | 924.3211 Ops/s | +0.56% |
| test_gae_speed[generalized_advantage_estimate-False-1-512] | 24.7570ms | 24.4916ms | 40.8303 Ops/s | 40.7605 Ops/s | +0.17% |
| test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 0.9360ms | 0.7141ms | 1.4003 KOps/s | 1.4078 KOps/s | -0.53% |
| test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.7471ms | 0.6654ms | 1.5029 KOps/s | 1.5003 KOps/s | +0.18% |
| test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 1.6187ms | 1.4656ms | 682.3087 Ops/s | 684.6305 Ops/s | -0.34% |
| test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 0.7069ms | 0.6803ms | 1.4700 KOps/s | 1.4683 KOps/s | +0.12% |
| test_dqn_speed | 7.2455ms | 1.3561ms | 737.4004 Ops/s | 713.1605 Ops/s | +3.40% |
| test_ddpg_speed | 2.9393ms | 2.7271ms | 366.6902 Ops/s | 362.8527 Ops/s | +1.06% |
| test_sac_speed | 8.1415ms | 7.8818ms | 126.8739 Ops/s | 125.9786 Ops/s | +0.71% |
| test_redq_speed | 12.0670ms | 10.1998ms | 98.0416 Ops/s | 99.6837 Ops/s | -1.65% |
| test_redq_deprec_speed | 11.2003ms | 10.8145ms | 92.4685 Ops/s | 92.5732 Ops/s | -0.11% |
| test_td3_speed | 7.9883ms | 7.8344ms | 127.6424 Ops/s | 127.1104 Ops/s | +0.42% |
| test_cql_speed | 25.6982ms | 25.1240ms | 39.8026 Ops/s | 39.5185 Ops/s | +0.72% |
| test_a2c_speed | 5.7846ms | 5.5162ms | 181.2830 Ops/s | 177.4163 Ops/s | +2.18% |
| test_ppo_speed | 6.0877ms | 5.8146ms | 171.9818 Ops/s | 167.5109 Ops/s | +2.67% |
| test_reinforce_speed | 5.2625ms | 4.4583ms | 224.3018 Ops/s | 219.7956 Ops/s | +2.05% |
| test_iql_speed | 20.0205ms | 19.3390ms | 51.7090 Ops/s | 50.7365 Ops/s | +1.92% |
| test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 6.8335ms | 6.6412ms | 150.5745 Ops/s | 147.2232 Ops/s | +2.28% |
| test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.6549ms | 0.5261ms | 1.9007 KOps/s | 1.9035 KOps/s | -0.14% |
| test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.6929ms | 0.5049ms | 1.9804 KOps/s | 2.0010 KOps/s | -1.03% |
| test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.8353ms | 6.5292ms | 153.1573 Ops/s | 151.2057 Ops/s | +1.29% |
| test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.4001ms | 0.5122ms | 1.9522 KOps/s | 1.9352 KOps/s | +0.88% |
| test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.6761ms | 0.4945ms | 2.0221 KOps/s | 2.0211 KOps/s | +0.05% |
| test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 2.1617ms | 1.9770ms | 505.8102 Ops/s | 501.6605 Ops/s | +0.83% |
| test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 2.0618ms | 1.8767ms | 532.8376 Ops/s | 529.1907 Ops/s | +0.69% |
| test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.9180ms | 6.7587ms | 147.9576 Ops/s | 146.7965 Ops/s | +0.79% |
| test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 0.9750ms | 0.6724ms | 1.4872 KOps/s | 1.4719 KOps/s | +1.04% |
| test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.8180ms | 0.6511ms | 1.5359 KOps/s | 1.5323 KOps/s | +0.23% |
| test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 6.7845ms | 6.6517ms | 150.3382 Ops/s | 146.4199 Ops/s | +2.68% |
| test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.9348ms | 0.5272ms | 1.8967 KOps/s | 1.9114 KOps/s | -0.77% |
| test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.6551ms | 0.5056ms | 1.9780 KOps/s | 2.0061 KOps/s | -1.40% |
| test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.9231ms | 6.5829ms | 151.9097 Ops/s | 149.5321 Ops/s | +1.59% |
| test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.7462ms | 0.5389ms | 1.8557 KOps/s | 1.9251 KOps/s | -3.61% |
| test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 1.4684ms | 0.5761ms | 1.7359 KOps/s | 2.0065 KOps/s | **-13.49%** |
| test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.9649ms | 6.8082ms | 146.8828 Ops/s | 144.8607 Ops/s | +1.40% |
| test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.9442ms | 0.6847ms | 1.4605 KOps/s | 1.4693 KOps/s | -0.60% |
| test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.7870ms | 0.6504ms | 1.5375 KOps/s | 1.5177 KOps/s | +1.31% |
| test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.1318s | 7.6483ms | 130.7474 Ops/s | 100.2972 Ops/s | **+30.36%** |
| test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 19.6142ms | 15.8788ms | 62.9771 Ops/s | 60.9847 Ops/s | +3.27% |
| test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 2.4495ms | 1.2261ms | 815.6235 Ops/s | 732.7017 Ops/s | **+11.32%** |
| test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.1241s | 7.4996ms | 133.3407 Ops/s | 132.1132 Ops/s | +0.93% |
| test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 18.5292ms | 15.8524ms | 63.0821 Ops/s | 61.3859 Ops/s | +2.76% |
| test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 2.4271ms | 1.1931ms | 838.1335 Ops/s | 849.0529 Ops/s | -1.29% |
| test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.1275s | 10.1190ms | 98.8238 Ops/s | 129.7365 Ops/s | **-23.83%** |
| test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 18.7779ms | 16.2217ms | 61.6456 Ops/s | 61.0035 Ops/s | +1.05% |
| test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 2.1934ms | 1.3140ms | 761.0181 Ops/s | 670.2804 Ops/s | **+13.54%** |

matteobettini commented 2 months ago

Could you explain why we need this?

Also, isn't having two copies of the parameters error-prone?

For example, in methods like https://github.com/facebookresearch/BenchMARL/blob/d260eea5d4ef2ff5f0bea8ae36f68638ecb14865/benchmarl/models/common.py#L165, or in any general case where users access self.parameters(), won't things break?

vmoens commented 2 months ago

We test that nothing breaks. I don't think it's error-prone: you never see two copies (for instance, parameters() just returns one).

We need this because it makes initialization of the params more natural, mainly.

matteobettini commented 2 months ago

So if a user modifies the content of one copy of the parameters, the change is reflected in the other copy? As in the function I sent.

But apart from being more natural, what use cases is it used for/ envisioned for?

matteobettini commented 2 months ago

Maybe I am misreading the PR description: when you say 2 copies you mean:

  1. Two different sets of parameters?
  2. Two objects that point to the same parameter tensors?

vmoens commented 2 months ago

> So if a user modifies the content of one copy of the parameters, the change is reflected in the other copy? As in the function I sent.

They are exactly the same objects, just one is in self.params and not seen by self.modules() or self.parameters() and the other is in self._empty_net.
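
A minimal sketch of this "same objects, two views" situation, using only plain PyTorch. The class `TwoViews` and the attribute `param_refs` are hypothetical stand-ins (a plain `dict` plays the role of the unregistered container here), so this shows the general mechanism rather than the PR's implementation.

```python
import torch
from torch import nn


class TwoViews(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered submodule: its parameters show up in self.parameters().
        self.net = nn.Linear(3, 2)
        # Plain dict attribute: nn.Module does not register its contents,
        # so these references are invisible to self.parameters().
        self.param_refs = dict(self.net.named_parameters())


m = TwoViews()

# Only the registered weight and bias are counted.
n_visible = len(list(m.parameters()))

# Both views hold the very same Parameter object.
same_object = m.param_refs["weight"] is m.net.weight

# An in-place change made through one view is seen through the other.
with torch.no_grad():
    m.param_refs["weight"].zero_()
reflected = bool((m.net.weight == 0).all())
```

Because the dict stores references and not clones, there is a single set of tensors in memory; only the way they are reached differs.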

> But apart from being more natural, what use cases is it used for/ envisioned for?

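Many people are used to doing

```python
from torch import nn

def init(module):
    # Zero the weights of every linear layer.
    if isinstance(module, nn.Linear):
        module.weight.data.zero_()

self.apply(init)  # called from inside the module, e.g. in __init__
```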

which you can only do if the params are in the module, not in the TensorDictParams. Moreover, TensorDictParams carries some overhead, so the new version should be faster. On top of that, it's totally optional and 100% non-BC-breaking.
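
To make the point about `apply` concrete, here is a small self-contained sketch. It assumes a toy `nn.Sequential` network, and a dict of detached clones (`external_copy`, a hypothetical name) stands in for any parameter container that lives outside the module tree. `apply` walks registered submodules only, so it never reaches the external tensors.

```python
import torch
from torch import nn

torch.manual_seed(0)


def init(module):
    # Zero the weights of every linear layer visited by apply().
    if isinstance(module, nn.Linear):
        module.weight.data.zero_()


net = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))

# Tensors held outside the module tree (illustrative stand-in only).
external_copy = {name: p.detach().clone() for name, p in net.named_parameters()}

# apply() recurses over registered submodules, then the module itself.
net.apply(init)

weights_zeroed = all(
    bool((layer.weight == 0).all()) for layer in net if isinstance(layer, nn.Linear)
)
external_untouched = all(
    bool((value != 0).any())
    for name, value in external_copy.items()
    if name.endswith("weight")
)
```

The registered weights end up zeroed while the external tensors keep their initial values, which is why this style of initialization only works when the parameters are reachable through the module itself.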