araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)

Added support for large values for gradient_steps to SAC, TD3, and TQC #21

Closed. jan1854 closed this pull request 8 months ago.

jan1854 commented 9 months ago

Description

closes #14

Addresses #14 by replacing the Python loop in _train() with a jax.lax.fori_loop construct for SAC, TD3, and TQC. This avoids explicitly unrolling the loop at compile time, which significantly reduces compile times and the size of the compiled program for large gradient_steps values.
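For illustration, here is a minimal, self-contained sketch of the pattern (not the actual _train() code of this PR; the toy quadratic loss, the plain SGD update, and the names loss_fn/train are stand-ins for SBX's actor/critic updates):

import jax
import jax.numpy as jnp
from functools import partial


def loss_fn(params, batch):
    # Toy quadratic loss standing in for the actual actor/critic losses.
    pred = batch["x"] @ params["w"]
    return jnp.mean((pred - batch["y"]) ** 2)


@partial(jax.jit, static_argnames=("gradient_steps",))
def train(params, batches, gradient_steps):
    def one_gradient_step(i, carry):
        params, total_loss = carry
        # Select the i-th minibatch inside the compiled loop.
        batch = jax.tree_util.tree_map(lambda x: x[i], batches)
        loss, grads = jax.value_and_grad(loss_fn)(params, batch)
        params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
        return params, total_loss + loss

    # A Python "for i in range(gradient_steps)" here would be unrolled at trace time,
    # so the compiled program grows linearly with gradient_steps. fori_loop keeps the
    # program size constant regardless of the number of gradient steps.
    params, total_loss = jax.lax.fori_loop(
        0, gradient_steps, one_gradient_step, (params, jnp.zeros(()))
    )
    return params, total_loss / gradient_steps


# Example usage with dummy data (128 gradient steps, batch size 32, 3 features):
key = jax.random.PRNGKey(0)
params = {"w": jnp.zeros(3)}
batches = {
    "x": jax.random.normal(key, (128, 32, 3)),
    "y": jax.random.normal(key, (128, 32)),
}
params, avg_loss = train(params, batches, gradient_steps=128)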

Since performance was a concern in https://github.com/araffin/sbx/issues/14#issuecomment-1687684993, I did some benchmarking. To keep compile time from dominating the results, I trained each agent for 1 million steps. I used different values for train_freq (with gradient_steps=-1) to highlight the differences between the implementations. I used the following script for benchmarking.

import argparse
import time

from sbx import SAC, TD3, TQC

parser = argparse.ArgumentParser()
parser.add_argument("algo", type=str)
parser.add_argument("train_freq", type=int)
parser.add_argument("--steps", type=int, default=1_000_000)
args = parser.parse_args()

# Resolve the algorithm class (SAC, TD3, or TQC) from its name.
algo_class = eval(args.algo.upper())

start = time.time()
# gradient_steps=-1 means "do as many gradient steps as environment steps collected",
# so train_freq controls both the rollout length and the number of updates per training call.
algo = algo_class(
    "MlpPolicy",
    "Walker2d-v4",
    train_freq=args.train_freq,
    gradient_steps=-1,
    tensorboard_log=f"tensorboard/{args.algo}_{args.train_freq}_{args.steps}",
)
algo.learn(args.steps)
print(f"{args.algo.upper()}, train_freq: {args.train_freq}, steps: {args.steps}: {time.time() - start:.2f}s")

I got the following results on Ubuntu 22.04 with an NVIDIA RTX A6000 GPU.

SAC:

| train_freq and gradient_steps | Execution time in seconds (original) | Execution time in seconds (PR) | Relative speedup (t_orig / t_PR) |
|---|---|---|---|
| 1 | 2716 | 2713 | 100.1 % |
| 4 | 1621 | 1659 | 97.7 % |
| 16 | 1498 | 1490 | 100.5 % |
| 64 | 1784 | 1360 | 130.2 % |
| 128 | 2587 | 1256 | 206.0 % |

TD3:

| train_freq and gradient_steps | Execution time in seconds (original) | Execution time in seconds (PR) | Relative speedup (t_orig / t_PR) |
|---|---|---|---|
| 1 | 2281 | 2270 | 100.5 % |
| 4 | 1315 | 1404 | 93.7 % |
| 16 | 1094 | 1070 | 102.2 % |
| 64 | 1103 | 950 | 116.1 % |
| 128 | 1311 | 939 | 139.6 % |

TQC:

| train_freq and gradient_steps | Execution time in seconds (original) | Execution time in seconds (PR) | Relative speedup (t_orig / t_PR) |
|---|---|---|---|
| 1 | 2995 | 2936 | 102.0 % |
| 4 | 2070 | 1926 | 107.5 % |
| 16 | 2035 | 1721 | 118.2 % |
| 64 | 2090 | 1443 | 144.8 % |
| 128 | 3359 | 1418 | 236.9 % |

For small gradient_steps values, the new implementation is on par with the original one for all three algorithms. For large values, it leads to significant speed-ups due to faster compilation and a smaller compiled program. I also tried running the original implementation with gradient_steps=1000, but it ran out of memory during compilation on my 64 GB RAM machine.


araffin commented 9 months ago

Hello, thanks a lot for the PR =) I'll try to look into it today.

I assume you also checked that the performance was the same with this change?

EDIT: I started a report here: https://wandb.ai/openrlbenchmark/sbx/reports/SBX-v0-9-1-unroll-vs-PR-21-for-i-loop---Vmlldzo2MjUwNDIy

araffin commented 9 months ago

btw, as you used the master branch of your fork to open the PR, I won't be able to push any edits :/ (it's protected by default)

araffin commented 9 months ago

I will be off until January, will continue testing/reviewing at that time ;)

jan1854 commented 9 months ago

I assume you also checked that the performance was the same with this change?

I benchmarked the performance of the three algorithms (after the bugfix) on Walker2d-v4 for 10 random seeds. I used the default hyperparameters (the train_freq=1 case of the computation-time benchmark above). Here are the results (mean and standard deviation across the seeds).

In all cases, the average performance is roughly the same for the original implementation and the PR. The standard deviation differs between the implementations for SAC and TD3, but I assume this is due to noise and that the difference would vanish with more seeds.

araffin commented 9 months ago

Hello, I'm back =) thanks for the benchmark, I will try to run an additional one soon with train_freq=gradient_steps=12.

araffin commented 8 months ago

The ongoing report is here: https://wandb.ai/openrlbenchmark/sbx/reports/SBX-TD3-PR-21-perf-check--Vmlldzo2NDk2NDUy

Performance-wise, it looks good. The compilation is much faster; however, I observed slower training after compilation compared to the current SBX version (we should probably play with the loop unroll parameter that was introduced in a newer jax version, but #22 must be fixed first).
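For reference, a minimal sketch of how different unroll factors could be compared is below. It assumes a jax version where jax.lax.fori_loop accepts an unroll keyword (recent releases do; with older versions the same effect can be obtained through jax.lax.scan(..., unroll=...)), and the dummy matmul update is only a placeholder for a real gradient step:

import time

import jax
import jax.numpy as jnp


def body(i, carry):
    # Dummy "gradient step": one matmul-based update on a weight matrix.
    w, x = carry
    w = w - 1e-3 * (w @ x) @ x.T
    return w, x


def time_loop(unroll):
    @jax.jit
    def loop(w, x):
        # unroll is a keyword-only argument of fori_loop in recent jax versions.
        return jax.lax.fori_loop(0, 1000, body, (w, x), unroll=unroll)[0]

    w = jnp.ones((256, 256))
    x = jnp.ones((256, 256))

    start = time.time()
    loop(w, x).block_until_ready()  # first call includes compilation
    first_call = time.time() - start

    start = time.time()
    loop(w, x).block_until_ready()  # second call measures pure runtime
    steady_state = time.time() - start
    print(f"unroll={unroll}: first call {first_call:.2f}s, steady state {steady_state:.3f}s")


for unroll in (1, 4, 16):
    time_loop(unroll)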

jan1854 commented 8 months ago

Sorry for the late reply. I am a bit busy with deadlines at the moment, but I'll try to have a look at the loop unrolling next week.

araffin commented 8 months ago

After some local tests:

  • there seems to be a bug with newer jax versions (I get NaN when I try to unroll on CUDA)
  • the performance drop also depends on the jax/cuda version, and so does the compilation time (the latest jax version, 0.4.23, is slower to compile and yields NaNs when trying to unroll, but it is the fastest: no perf drop w.r.t. master)

Tested jax versions:

  • jaxlib==0.4.12+cuda11.cudnn86 (works with unroll, but slower when using the for i loop)
  • jaxlib==0.4.23+cuda11.cudnn86 (NaN with unroll, faster than master when using the for i loop, couldn't test master though)

I would be happy if you could help isolate the NaN bug so we can send a minimal working example to the Jax team.

jan1854 commented 7 months ago

After some local tests:

  • there seems to be a bug with newer jax versions (I get NaN when I try to unroll on CUDA)
  • the performance drop also depends on the jax/cuda version, and so does the compilation time (the latest jax version, 0.4.23, is slower to compile and yields NaNs when trying to unroll, but it is the fastest: no perf drop w.r.t. master)

Tested jax versions:

  • jaxlib==0.4.12+cuda11.cudnn86 (works with unroll, but slower when using the for i loop)
  • jaxlib==0.4.23+cuda11.cudnn86 (NaN with unroll, faster than master when using the for i loop, couldn't test master though)

Hi @araffin, I was able to reproduce the NaN problem on my machine with jaxlib==0.4.23+cuda11.cudnn86. However, I noticed that the problem vanishes with jaxlib version 0.4.24. It also seems to be independent of whether the CUDA 11 or CUDA 12 build is used (it works for both jaxlib==0.4.24+cuda11.cudnn86 and jaxlib==0.4.24+cuda12.cudnn89). So I guess the Jax people fixed the issue in the latest jaxlib version, but I could not find anything about it in the changelogs.
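In case it helps with a minimal working example for the Jax team, a self-contained check along these lines could be used to test a given jaxlib build; the toy update rule, shapes, and unroll values are arbitrary assumptions (and, again, the unroll keyword of fori_loop is only available in recent jax versions), so it is not guaranteed to trigger the original bug:

import jax
import jax.numpy as jnp


def has_nans(unroll, n_steps=100):
    # Run a small loop of dummy updates on the default device (GPU if available)
    # and report whether the result contains NaNs.
    def body(i, w):
        grad = jnp.tanh(w @ w.T) @ w
        return w - 1e-3 * grad

    loop = jax.jit(lambda w: jax.lax.fori_loop(0, n_steps, body, w, unroll=unroll))
    w = jax.random.normal(jax.random.PRNGKey(0), (64, 64))
    return bool(jnp.isnan(loop(w)).any())


print("jax", jax.__version__, "on", jax.devices())
for unroll in (1, 8):
    print(f"unroll={unroll}: NaNs:", has_nans(unroll))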

I will do some benchmarking of the performance with different values for the unroll parameter to figure out whether it helps. I'll get back to you as soon as I have the data.

araffin commented 7 months ago

Thanks, I can confirm that the bug is gone with the newest version of jax. After some testing, it also seems that keeping the default unroll value is best for speed.

jan1854 commented 7 months ago

I can confirm that observation. In my tests, the runs without loop unrolling were almost always faster (in terms of the time/fps metric). I also compared against the old implementation from before the PR. Previously, you reported that time/fps was slightly higher for the old implementation. In my experiments, this does not seem to be the case anymore. The new implementation was always faster, so perhaps the new version of Jax also improved the efficiency of the fori_loop construct in the meantime.

The results seem to be consistent across all three algorithms.

araffin commented 7 months ago

Previously, you reported that time/fps was slightly higher for the old implementation. In my experiments, this does not seem to be the case anymore.

yes, that was the case for older versions of jax. It is not the case anymore with newer versions.

Thanks for checking =)