DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Bug:] Cannot use the `fused` flag in default optimizer of PPO #1770

Closed cmangla closed 10 months ago

cmangla commented 10 months ago

The default Adam optimizer has a fused flag which, according to the PyTorch docs, is significantly faster than the default implementation when used on CUDA. Using it with PPO raises an exception complaining that the parameters are not on a CUDA device.

The fused flag can be passed to PPO with policy_kwargs = dict(optimizer_kwargs={'fused': True}). However, the problem lies in the following lines of code: https://github.com/DLR-RM/stable-baselines3/blob/c8fda060d4bcb283eea6ddd385be5a46a54d3356/stable_baselines3/common/on_policy_algorithm.py#L133-L136

Before line 133 above, the correct device has already been stored in self.device. However, policy_class is instantiated in line 133 without it, so the policy is created on the cpu device, and so is its optimizer. Line 136 moves the policy to the correct device, but by then it is too late: the optimizer has already been constructed and has only seen cpu parameters.

This is a problem with the fused flag because the Adam optimizer checks the flag at construction time and then verifies that the parameters it was given are on a supported device. In my case, it complains that they are not of cuda type.

If the policy_class in line 133 above were passed the correct device (i.e. self.device) at initialization, it could set the device before MlpExtractor gets initialized. MlpExtractor is initialized with the parent class's device in the lines below: https://github.com/DLR-RM/stable-baselines3/blob/c8fda060d4bcb283eea6ddd385be5a46a54d3356/stable_baselines3/common/policies.py#L568-L581
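The ordering problem described above can be illustrated with a toy model that does not use PyTorch or SB3 at all. Everything in this sketch (ToyParam, ToyFusedOptimizer, ToyPolicy and the two build functions) is a hypothetical stand-in, not SB3 or PyTorch code. The only assumption it encodes is the behaviour reported in this issue: the optimizer validates parameter devices once, when it is constructed.

```python
class ToyParam:
    """Stand-in for a tensor parameter; it only tracks which device it is on."""

    def __init__(self, device="cpu"):
        self.device = device


class ToyFusedOptimizer:
    """Stand-in for an optimizer that validates devices once, at construction."""

    SUPPORTED_DEVICES = ("cuda",)

    def __init__(self, params, fused=False):
        self.params = list(params)
        if fused and any(p.device not in self.SUPPORTED_DEVICES for p in self.params):
            raise RuntimeError("fused=True requires all params on a supported device")


class ToyPolicy:
    """Stand-in for a policy that builds its optimizer inside __init__."""

    def __init__(self, optimizer_kwargs=None, device="cpu"):
        self.params = [ToyParam(device), ToyParam(device)]
        self.optimizer = ToyFusedOptimizer(self.params, **(optimizer_kwargs or {}))

    def to(self, device):
        # Moves the parameters in place; the optimizer is not rebuilt.
        for p in self.params:
            p.device = device
        return self


def build_then_move(device, optimizer_kwargs=None):
    """Order described in the issue: construct on cpu, move afterwards."""
    return ToyPolicy(optimizer_kwargs).to(device)


def build_on_device(device, optimizer_kwargs=None):
    """Alternative order: the device is known before the optimizer exists."""
    return ToyPolicy(optimizer_kwargs, device=device)


if __name__ == "__main__":
    build_then_move("cuda")                       # fine without fused
    build_on_device("cuda", {"fused": True})      # fine: params already on cuda
    try:
        build_then_move("cuda", {"fused": True})  # optimizer sees cpu params
    except RuntimeError as err:
        print("construct-then-move failed:", err)
```

In the toy model, only the construct-then-move order combined with fused=True fails, which mirrors the traceback below.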

Here is the traceback I get:

Traceback (most recent call last):
  File "x.py", line 350, in <module>
    main(sys.argv)
  File "x.py", line 254, in main
    model = PPO(
            ^^^^
  File "/z/venv/lib64/python3.11/site-packages/stable_baselines3/ppo/ppo.py", line 164, in __init__
    self._setup_model()
  File "/z/venv/lib64/python3.11/site-packages/stable_baselines3/ppo/ppo.py", line 167, in _setup_model
    super()._setup_model()
  File "/z/venv/lib64/python3.11/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 123, in _setup_model
    self.policy = self.policy_class(  # type: ignore[assignment]
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/z/venv/lib64/python3.11/site-packages/stable_baselines3/common/policies.py", line 857, in __init__
    super().__init__(
  File "/z/venv/lib64/python3.11/site-packages/stable_baselines3/common/policies.py", line 507, in __init__
    self._build(lr_schedule)
  File "/z/venv/lib64/python3.11/site-packages/stable_baselines3/common/policies.py", line 610, in _build
    self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/z/venv/lib64/python3.11/site-packages/torch/optim/adam.py", line 60, in __init__
    raise RuntimeError("`fused=True` requires all the params to be floating point Tensors of "
RuntimeError: `fused=True` requires all the params to be floating point Tensors of supported devices: ['cuda', 'xpu', 'privateuseone'].
araffin commented 10 months ago

Hello,

has a fused flag, which, according to the docs, is significantly faster than the default when used on CUDA.

Could you please share some runs/runtime on standard environments? (to have a better idea of the potential gain)

I would recommend using the RL Zoo and logging things with W&B (or doing runs with hyperfine).

Right now, the solution to your problem is to define a custom policy or to fork SB3. We use self.policy = self.policy.to(self.device) because that is the cleanest way to avoid forgetting any parameter. We used to pass a device parameter in the past but removed it.

EDIT: if you want a significant runtime boost, you can have a look at https://github.com/araffin/sbx

cmangla commented 10 months ago

Could you please share some runs/runtime on standard environments? (to have a better idea of the potential gain)

To do so, I would have to fix this issue locally and test on all the standard environments. I don't have the capacity for that just now, unfortunately.

I did test this locally for my use case and I noticed a small improvement in very short runs, but not enough for me to spend much more time in fixing this issue upstream.

The PyTorch docs also state a speed-up, but don't quantify it.

cmangla commented 10 months ago

@araffin I have a draft PR #1771 that fixes this and is backwards compatible.

araffin commented 10 months ago

Hello, thanks for the PR.

The benchmark is still needed to know if it's something that should be done on other algorithms or not, or if it should only be mentioned in the doc (with a link to your fork).

Btw, as I mentioned before, if you want a real performance boost, you can have a look at https://github.com/araffin/sbx (usually faster when run on cpu only when not using CNN, see https://arxiv.org/abs/2310.05808)

cmangla commented 10 months ago

The benchmark is still needed to know if it's something that should be done on other algorithms or not, or if it should only be mentioned in the doc (with a link to your fork).

Actually, my PR is generic to anything that inherits from OnPolicyAlgorithm. It doesn't set the fused flag itself; it just makes it possible to set it through the existing SB3 API. That is why I have submitted it as a generic bug fix.

I will update this issue and the PR with benchmark results when I am able to run them.

Btw, as I mentioned before, if you want a real performance boost, you can have a look at https://github.com/araffin/sbx (usually faster when run on cpu only when not using CNN, see https://arxiv.org/abs/2310.05808)

Thanks for the suggestion. I will consider sbx in the future, but for my current project I'll have to stick with SB3.

cmangla commented 10 months ago

@araffin The current GitHub Actions based CI will skip our GPU tests, but it looks like GPU runners entered beta last month. Just pointing it out since there is a waiting list to sign up for. See: https://github.com/github/roadmap/issues/505 https://resources.github.com/devops/accelerate-your-cicd-with-arm-and-gpu-runners-in-github-actions/

cmangla commented 10 months ago

I did an initial test running the following with and without the fused optimizer argument:

python -m rl_zoo3.train --algo ppo --env BipedalWalker-v3 --conf-file venv/lib/python3.11/site-packages/rl_zoo3/hyperparams/ppo.yml

Without fused, I got:

real 37m50.874s
user 36m57.949s
sys 0m14.338s

With fused, I got:

real 35m59.855s
user 34m56.937s
sys 0m13.729s

This was a single run. It shows a 5.5% reduction in user CPU time (about 4.9% in wall-clock time). Since it is only one run, I don't know how noisy the measurement is yet. The GPU is an RTX A5000 and the CPU is an AMD Ryzen 9 7900X with hyperthreading disabled.
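For reference, the percentage can be recomputed from the two timings quoted above. This is a small sketch in plain Python; the helper names parse_time and reduction_percent are made up for illustration.

```python
import re


def parse_time(text):
    """Convert a shell `time` value such as '37m50.874s' into seconds."""
    match = re.fullmatch(r"(\d+)m([\d.]+)s?", text.strip())
    if match is None:
        raise ValueError(f"unrecognised time value: {text!r}")
    minutes, seconds = match.groups()
    return int(minutes) * 60 + float(seconds)


def reduction_percent(before, after):
    """Relative reduction of `after` with respect to `before`, in percent."""
    return 100.0 * (before - after) / before


# Timings reported in this comment (without fused, with fused).
real = reduction_percent(parse_time("37m50.874s"), parse_time("35m59.855s"))
user = reduction_percent(parse_time("36m57.949s"), parse_time("34m56.937s"))

if __name__ == "__main__":
    print(f"real: {real:.1f}% shorter, user: {user:.1f}% shorter")
```

With these numbers, the wall-clock (real) time drops by roughly 4.9% and the user time by roughly 5.5%.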

araffin commented 10 months ago

Hello, thanks for the quick try =) I did the following runs:

hyperfine -m 2 "python -m rl_zoo3.train --algo ppo --env BipedalWalker-v3 --eval-freq -1 -n 200000 -P"

CPU only:

hyperfine -m 2 "CUDA_VISIBLE_DEVICES= python -m rl_zoo3.train --algo ppo --env BipedalWalker-v3 --eval-freq -1 -n 200000 -P"
Time (mean ± σ):     142.949 s ±  0.273 s 

One cpu:

hyperfine -m 2 "CUDA_VISIBLE_DEVICES= OMP_NUM_THREADS=1 python -m rl_zoo3.train --algo ppo --env BipedalWalker-v3 --eval-freq -1 -n 200000 -P"
Time (mean ± σ):     135.304 s ±  0.991 s 

GPU, no fused:

Time (mean ± σ):     169.488 s ±  0.387 s

GPU, fused:

hyperfine -m 2 "python -m rl_zoo3.train --algo ppo --env BipedalWalker-v3 --eval-freq -1 -n 200000 -P  -params policy_kwargs:'dict(optimizer_kwargs=dict(fused=True))'"
Time (mean ± σ):     165.571 s ±  0.712 s 

SBX CPU (still using RL Zoo: https://rl-baselines3-zoo.readthedocs.io/en/master/guide/sbx.html):

Time (mean ± σ):     94.541 s ±  4.244 s

SBX GPU:

Time (mean ± σ):     114.366 s ±  0.240 s

As I wrote before, when not using a CNN, PPO on CPU is usually the fastest (20-25% faster). So far, the performance gain with the fused optimizer is not really significant (only 2% faster). I'm going to do some runs on Atari games to check.

For now:

  1. SBX CPU (1x, 94s)
  2. SBX GPU (1.2x)
  3. SB3 CPU (1.4x 1 CPU, 1.5x multi-threading)
  4. SB3 GPU fused (1.75x)
  5. SB3 GPU (1.8x)
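
The ranking above can be reproduced from the hyperfine means in this comment by dividing each mean by the fastest one. A minimal sketch, with the dictionary keys chosen here only as labels:

```python
# Mean runtimes in seconds, copied from the hyperfine outputs above.
mean_runtime_s = {
    "SBX CPU": 94.541,
    "SBX GPU": 114.366,
    "SB3 CPU (1 thread)": 135.304,
    "SB3 CPU (multi-threading)": 142.949,
    "SB3 GPU fused": 165.571,
    "SB3 GPU": 169.488,
}


def relative_slowdown(runtimes):
    """Return (label, runtime / fastest runtime) pairs, fastest first."""
    fastest = min(runtimes.values())
    ranked = sorted(runtimes.items(), key=lambda item: item[1])
    return [(label, seconds / fastest) for label, seconds in ranked]


ranking = relative_slowdown(mean_runtime_s)

if __name__ == "__main__":
    for label, factor in ranking:
        print(f"{label}: {factor:.2f}x")
```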
araffin commented 10 months ago

For Atari games, with CNN

hyperfine -m 2 "python -m rl_zoo3.train --algo ppo --env PongNoFrameskip-v4 --eval-freq -1 -n 50000 -P "

CPU: Time (mean ± σ): 93.623 s ± 0.901 s

GPU: Time (mean ± σ): 55.244 s ± 1.622 s

GPU fused: Time (mean ± σ): 53.143 s ± 0.393 s

cmangla commented 10 months ago

I did an initial test running the following with and without the fused optimizer argument:

python -m rl_zoo3.train --algo ppo --env BipedalWalker-v3 --conf-file venv/lib/python3.11/site-packages/rl_zoo3/hyperparams/ppo.yml

I ran this again 10 times over, with and without fused. The fused version only reduces the average elapsed time from 36:55 to 36:35, so that is only a 0.9% saving.

It is still worth merging #1771, which doesn't introduce the fused flag, but makes it possible to use it if anyone wishes to. It also ensures the optimizers don't see parameter types changing underneath them, between initialization and first use.

cmangla commented 10 months ago

For Atari games, with CNN

hyperfine -m 2 "python -m rl_zoo3.train --algo ppo --env PongNoFrameskip-v4 --eval-freq -1 -n 50000 -P "

I ran the same on my machine with '-m 5' and I got:

| Command | Mean [s] | Min [s] | Max [s] |
|:---|---:|---:|---:|
| `python -m rl_zoo3.train --algo ppo --env PongNoFrameskip-v4 --eval-freq -1 -n 50000 -P -params policy_kwargs:'dict(optimizer_kwargs=dict(fused=False))'` | 34.798 ± 0.503 | 34.149 | 35.211 |
| `python -m rl_zoo3.train --algo ppo --env PongNoFrameskip-v4 --eval-freq -1 -n 50000 -P -params policy_kwargs:'dict(optimizer_kwargs=dict(fused=True))'` | 35.491 ± 0.497 | 34.875 | 36.112 |
| `python -m rl_zoo3.train --algo ppo --env PongNoFrameskip-v4 --eval-freq -1 -n 50000 -P --device cpu` | 64.738 ± 2.490 | 62.592 | 68.854 |

So in this case I actually experience a slowdown with fused, but the GPU runs are much faster than the CPU.
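One rough way to judge whether the fused/non-fused gap exceeds run-to-run noise is to express the difference of the two means in units of its standard error. The sketch below assumes that the ± values reported by hyperfine are standard deviations over the 5 runs and that the runs are independent. The function names are made up for illustration, and this is not a substitute for a proper statistical test.

```python
import math


def standard_error_of_difference(std_a, n_a, std_b, n_b):
    """Standard error of (mean_b - mean_a) for two independent samples."""
    return math.sqrt(std_a ** 2 / n_a + std_b ** 2 / n_b)


def difference_in_standard_errors(mean_a, std_a, n_a, mean_b, std_b, n_b):
    """How many standard errors separate the two means."""
    return (mean_b - mean_a) / standard_error_of_difference(std_a, n_a, std_b, n_b)


# GPU without fused (a) versus GPU with fused (b), 5 runs each, values from above.
gap = difference_in_standard_errors(34.798, 0.503, 5, 35.491, 0.497, 5)

if __name__ == "__main__":
    print(f"fused is slower by about {gap:.1f} standard errors")
```

With only five runs per configuration, the result is best read as a hint about whether more runs are needed rather than as a conclusion.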

cmangla commented 10 months ago

As I wrote before, when not using CNN, PPO on CPU is usually the fastest (20-25% faster).

I should mention that although this holds for my workload too (a single training instance runs nearly as fast on CPUs, 20 of them, as it does on the GPU), it is not the case when I do hyperparameter tuning with many parallel instances of rl_zoo3 (in separate processes, not using optuna threads).

It seems to me that when GPU use is enabled, the CPU usage of a single training instance is limited to a single core and the GPU is only partially utilized. I am then able to run enough instances in parallel to max out my hardware. In contrast, training a single instance on the CPU alone seems to max out all CPU cores, so running hyperparameter tuning in parallel brings no benefit.

The difference between the two setups above, for me, is enormous. For me it is the difference between getting nothing at all versus getting some results after a weekend run of rl_zoo3 on my environment. That is why I still use the GPU with PPO.

araffin commented 10 months ago

I think you should take a look at the run I did with a single CPU (to disable inter-op parallelism) and related issues (search "num threads pytorch").
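As a sketch of how the single-CPU setting could be combined with parallel instances, the helper below builds one rl_zoo3 command per configuration and sets OMP_NUM_THREADS=1 for each process, as in the "One cpu" run earlier in this thread. The helper names and the learning_rate override are hypothetical examples. It also assumes that -params accepts a plain key:value pair in the same way as the policy_kwargs override used above. Whether this matches the throughput of GPU-backed instances has not been measured here.

```python
import os
import subprocess
import sys


def build_command(env_id, n_timesteps, learning_rate):
    """Command line for one CPU-only rl_zoo3 PPO run (flags as used in this thread)."""
    return [
        sys.executable, "-m", "rl_zoo3.train",
        "--algo", "ppo",
        "--env", env_id,
        "--eval-freq", "-1",
        "-n", str(n_timesteps),
        "--device", "cpu",
        # Hypothetical hyperparameter override, for illustration only.
        "-params", f"learning_rate:{learning_rate}",
    ]


def single_thread_env(base_env=None):
    """Copy of the environment with OpenMP limited to one thread per process."""
    env = dict(os.environ if base_env is None else base_env)
    env["OMP_NUM_THREADS"] = "1"
    return env


def launch_all(learning_rates, env_id="BipedalWalker-v3", n_timesteps=200000):
    """Start one process per learning rate and wait for all of them to finish."""
    env = single_thread_env()
    processes = [
        subprocess.Popen(build_command(env_id, n_timesteps, lr), env=env)
        for lr in learning_rates
    ]
    return [process.wait() for process in processes]
```

Calling launch_all([0.0003, 0.0001]) would start two single-threaded training processes side by side, provided rl_zoo3 is installed.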

I appreciate the PR you did, but the current results don't justify the change. This would also introduce inconsistency between on/off policy and all algorithms in sb3 contrib would have to be adjusted too.

cmangla commented 10 months ago

I think you should take a look at the run I did with a single CPU (to disable inter-op parallelism) and related issues (search "num threads pytorch").

Thanks for the tip, I will look into it.

I appreciate the PR you did, but the current results don't justify the change. This would also introduce inconsistency between on/off policy and all algorithms in sb3 contrib would have to be adjusted too.

Ok, no problem. Feel free to close this issue and the PR. Thanks for your inputs on this matter.

araffin commented 10 months ago

Closing as not planned for now, will re-open in case of new results/other cases that justify the change.