Closed: skandermoalla closed this issue 8 months ago.
Thanks, that's very thorough.
We could definitely do a better job on CPU. Things should get better on CUDA; let me try to reproduce.
If it's of interest, you can also try the Multi(a)SyncDataCollector, which is usually faster than ParallelEnv. ParallelEnv is especially useful when you have large models that you want to execute over a batch of envs, but for simple envs and simple models, dispatching each env to a separate collector is usually faster.
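If it helps, here is a minimal sketch of that approach (a hedged example, not the exact setup used for the paper numbers: it assumes torchrl and gymnasium are installed, and argument defaults may differ across TorchRL versions):

```python
# Minimal sketch: collect CartPole frames with MultiaSyncDataCollector instead of
# rolling out a ParallelEnv. Argument names follow the TorchRL docs at the time of
# writing and may differ across versions.
from torchrl.collectors import MultiaSyncDataCollector
from torchrl.envs.libs.gym import GymEnv


def make_env():
    return GymEnv("CartPole-v1")


if __name__ == "__main__":
    collector = MultiaSyncDataCollector(
        [make_env] * 4,      # one sub-collector (and env) per worker process
        policy=None,         # None should fall back to a random policy over the action spec
        frames_per_batch=256,
        total_frames=10_000,
    )
    for batch in collector:  # batches arrive as soon as any worker has one ready
        pass                 # feed a replay buffer / loss here
    collector.shutdown()
```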
I'll keep you posted!
Yes, it's fair to mention that this doesn't include the overhead of moving data between CPU and GPU; maybe the TorchRL classes are better suited for that. I'll wait for the GPU benchmark!
There is certainly room for improvement, but there will always be the cost of building tensordicts out of the gym envs. It's part of the price to pay for envs that you can recycle across gym, dm_control, brax and many others, with or without rendering, with or without transforms: formatting the data into a common structure comes at a certain cost.
We're quite competitive when it comes to executing more demanding environments, e.g. the ones that require rendering, and when there are transforms to be executed.
For the simplest pendulum or cartpole it's hard to do better than gym...
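To give a sense of the setting where that common structure pays off, here is a sketch of a rendered env with a transform pipeline (assuming torchrl and gymnasium with the Atari extras are installed; the transform names follow the TorchRL docs and may vary between versions):

```python
# Sketch: the same rollout API over a pixel-based env with transforms.
from torchrl.envs import Compose, GrayScale, Resize, ToTensorImage, TransformedEnv
from torchrl.envs.libs.gym import GymEnv

env = TransformedEnv(
    GymEnv("ALE/Pong-v5", from_pixels=True),                # rendered observations
    Compose(ToTensorImage(), GrayScale(), Resize(84, 84)),  # executed on the env output
)
rollout = env.rollout(100)  # same tensordict layout as any other TorchRL env
print(rollout)
```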
There's more improvement to come, but with the 2 PRs linked above I managed to reduce the compute time for CartPole rollouts by roughly 1/3 to 1/2.

Before (1 and 4 envs):

| envs | Single, s | Serial, s | Parallel, s |
|---|---|---|---|
| 1 | 0.600 | 0.898 | 0.963 |
| 4 | 2.365 | 2.120 | 1.399 |

After:

| envs | Single, s | Serial, s | Parallel, s |
|---|---|---|---|
| 1 | 0.322 | 0.581 | 0.567 |
| 4 | 1.449 | 1.649 | 0.984 |
We're far from the 10x improvement we could have wished for, but it's some progress already!
On GPU, the gain is a bit less impressive. Here are some benchmarks on GPU with Pong-v5.

Before (1 and 4 envs):

| envs | Single, s | Serial, s | Parallel, s |
|---|---|---|---|
| 1 | 1.124 | 1.490 | 1.763 |
| 4 | 4.570 | 3.710 | 2.416 |

After:

| envs | Single, s | Serial, s | Parallel, s |
|---|---|---|---|
| 1 | 0.827 | 1.203 | 1.392 |
| 4 | 3.309 | 3.422 | 2.093 |
From what I can see, the bottlenecks are mostly tensordict-related (cloning, set, get, value checks, etc.).
Be assured that we'll work on making these more efficient! As usual: any help is welcome :)
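For a rough idea of what those hot paths look like in isolation, here is an illustrative micro-benchmark sketch (assuming the tensordict package is installed; the timings are machine- and version-dependent):

```python
# Rough micro-benchmark of the tensordict ops mentioned above (clone/get/set).
import timeit

import torch
from tensordict import TensorDict

td = TensorDict(
    {"observation": torch.randn(4, 3), "reward": torch.zeros(4, 1)},
    batch_size=[4],
)

print("clone:", timeit.timeit(td.clone, number=10_000))
print("get:  ", timeit.timeit(lambda: td.get("observation"), number=10_000))
print("set:  ", timeit.timeit(lambda: td.set("reward", torch.ones(4, 1)), number=10_000))
```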
In the paper, the speed results (still to be copied here) are achieved via MultiaSyncDataCollector, not ParallelEnv. For most off-policy algos this is a valid way of collecting data; for on-policy algos, MultiSyncDataCollector can be used.
After some more improvements, CartPole looks like this (1, 4 and 8 workers):

| workers | Single, s | Serial, s | Parallel, s |
|---|---|---|---|
| 1 | 0.281 | 0.504 | 0.783 |
| 4 | 1.275 | 1.436 | 0.838 |
| 8 | 2.424 | 2.986 | 1.275 |

On my machine, this compares with gym as follows:

| workers | Single, s | Serial, s | Parallel, s |
|---|---|---|---|
| 1 | 0.030 | 0.068 | 0.120 |
| 4 | 0.056 | 0.068 | 0.204 |
| 8 | 0.112 | 0.110 | 0.535 |

So the parallel version with 8 workers is "only" about twice as slow.
Have a look at the benchmarks too:
https://pytorch-labs.github.io/tensordict/dev/bench/
https://pytorch.org/rl/dev/bench/
I'm impressed by the responsiveness! Thanks a lot!
I also reran the benchmark on my side on an M1 chip at commit ea6f872 and got the results below.
The major points I found were:
This progress is already awesome, so feel free to close the issue!
|  | num_workers_1 | num_workers_4 | num_workers_8 |
|---|---|---|---|
| Single, s | 0.055 | 0.223 | 0.472 |
| Serial, s | 0.104 | 0.281 | 0.676 |
| Parallel, s | 0.113 | 0.218 | 0.399 |
| MultiSync, s | 0.075 | 0.079 | 0.165 |
| Single_gym, s | 0.004 | 0.015 | 0.031 |
| Serial_gym, s | 0.007 | 0.016 | 0.024 |
| Parallel_gym, s | 0.025 | 0.073 | 0.103 |
| relative_time parallel/single, % | 205.5 | 97.8 | 84.5 |
| relative_time parallel/serial, % | 108.7 | 77.6 | 59.0 |
| relative_time multisync/parallel, % | 66.4 | 36.2 | 41.4 |
| relative_time single_gym/single, % | 7.3 | 6.7 | 6.6 |
| relative_time serial_gym/serial, % | 6.7 | 5.7 | 3.6 |
| relative_time parallel_gym/parallel, % | 22.1 | 33.5 | 25.8 |
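The relative_time rows appear to be plain ratios of the timings above, expressed in percent; here is a quick check of that reading (an illustrative one-liner, not part of the original benchmark):

```python
# "relative_time parallel/single" for 1 worker: 0.113 s / 0.055 s, in percent.
print(round(0.113 / 0.055 * 100, 1))  # -> 205.5, matching the table
```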
There's still room for improvement, so I'll leave this open as a reminder that we should aim to match gym's speed even for simple envs! Glad to read that you observed the same improvements.
There have been multiple iterations on this, so I'm closing the issue for now. If we observe a regression, we can reopen it.
Describe the bug
I quickly adapted the benchmark for batched environments to compare against native gymnasium environment classes and got drastically worse performance on CPU.
Single envs and SerialEnvs are up to 50x slower, and ParallelEnv is up to 10x slower 😬
Could anyone look into this?
To Reproduce
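The exact script isn't reproduced here, but a rough sketch of the comparison looks like this (assuming torchrl and gymnasium are installed; class and argument names may differ across versions, and the numbers will not match the ones below exactly):

```python
# Rough sketch (not the original script): time random-action rollouts of TorchRL's
# SerialEnv/ParallelEnv against gymnasium's native Sync/AsyncVectorEnv on CartPole.
import time

import gymnasium as gym
from torchrl.envs import ParallelEnv, SerialEnv
from torchrl.envs.libs.gym import GymEnv

N_STEPS, N_WORKERS = 1_000, 4


def make_torchrl_env():
    return GymEnv("CartPole-v1")


def make_gym_env():
    return gym.make("CartPole-v1")


def time_torchrl(cls):
    env = cls(N_WORKERS, make_torchrl_env)
    env.rollout(2)  # warm-up: spawns workers and builds specs
    t0 = time.perf_counter()
    env.rollout(N_STEPS, break_when_any_done=False)
    elapsed = time.perf_counter() - t0
    env.close()
    return elapsed


def time_gym(vec_cls):
    env = vec_cls([make_gym_env] * N_WORKERS)
    env.reset()
    t0 = time.perf_counter()
    for _ in range(N_STEPS):
        env.step(env.action_space.sample())  # batched random actions, auto-reset on done
    elapsed = time.perf_counter() - t0
    env.close()
    return elapsed


if __name__ == "__main__":
    print("torchrl SerialEnv  :", time_torchrl(SerialEnv))
    print("torchrl ParallelEnv:", time_torchrl(ParallelEnv))
    print("gym SyncVectorEnv  :", time_gym(gym.vector.SyncVectorEnv))
    print("gym AsyncVectorEnv :", time_gym(gym.vector.AsyncVectorEnv))
```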
Output on macOS with an Apple M1
Output on Ubuntu with an Intel(R) Xeon(R) Gold 6240
System info
Describe the characteristics of your environment:
macOS with M1
Ubuntu with Intel(R) Xeon(R) Gold 6240
Checklist