Hello, thanks a lot for the PR =) I'll try to look into it today.
I assume you also checked that the performance was the same with this change?
EDIT: I started a report here: https://wandb.ai/openrlbenchmark/sbx/reports/SBX-v0-9-1-unroll-vs-PR-21-for-i-loop---Vmlldzo2MjUwNDIy
Btw, as you use the master branch of your fork for the PR, I won't be able to push any edits :/ (it's protected by default)
I will be off until January, will continue testing/reviewing at that time ;)
> I assume you also checked that the performance was the same with this change?
I benchmarked the performance of the three algorithms (after the bugfix) on Walker2d-v4 for 10 random seeds. I used the default hyperparameters (the `train_freq=1` case of the computation-time benchmark above). Here are the results (mean and standard deviation across the seeds).
In all cases, the average performance is roughly the same for the original implementation and the PR. The standard deviation differs between the implementations for SAC and TD3, but I assume this is due to noise and that the difference would vanish with more seeds.
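For reference, a run of this kind can be sketched with sbx's SB3-style API (an assumed reconstruction, not the exact script behind these numbers):

```python
# Assumed sketch of the multi-seed benchmark described above, not the exact
# script used: train each algorithm on Walker2d-v4 with default
# hyperparameters and report mean/std of the final return across 10 seeds.
import numpy as np
from sbx import SAC, TD3, TQC
from stable_baselines3.common.evaluation import evaluate_policy

for algo in (SAC, TD3, TQC):
    returns = []
    for seed in range(10):
        model = algo("MlpPolicy", "Walker2d-v4", seed=seed, verbose=0)
        model.learn(total_timesteps=1_000_000)
        mean_return, _ = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
        returns.append(mean_return)
    print(f"{algo.__name__}: {np.mean(returns):.1f} +/- {np.std(returns):.1f}")
```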
Hello,
I'm back =)
Thanks for the benchmark, I will try to run an additional one soon with `train_freq=gradient_steps=12`.
The ongoing report is here: https://wandb.ai/openrlbenchmark/sbx/reports/SBX-TD3-PR-21-perf-check--Vmlldzo2NDk2NDUy
Performance-wise, it looks good. Compilation is much faster; however, I observed slower training after compilation compared to the current SBX version (we should probably play with the loop `unroll` parameter that was introduced in a newer jax version, but #22 must be fixed first).
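(For reference, the knob in question is the `unroll` argument of `jax.lax.fori_loop`; a toy sketch of how it would be tuned, with a placeholder body instead of the SBX gradient step:)

```python
# Toy sketch of the `unroll` argument on jax.lax.fori_loop (placeholder body,
# not SBX code): larger values emit several copies of the body per loop
# iteration, trading compile time and program size for potentially faster steps.
import jax
import jax.numpy as jnp

def body(i, x):
    return x + 0.1 * jnp.sin(x)  # stand-in for one gradient update

@jax.jit
def run(x):
    return jax.lax.fori_loop(0, 1000, body, x, unroll=4)

print(run(jnp.ones(8)))
```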
Sorry for the late reply. I am a bit busy with deadlines at the moment, but I will try to have a look at the loop unrolling next week.
After some local tests:

- there seems to be a bug with newer jax versions (I get NaN when I try to unroll on CUDA)
- the performance drop also depends on the jax/cuda version, and so does the compilation time (the latest jax version, 0.4.23, is slower to compile and yields NaNs when trying to unroll, but it is the fastest when using the `for i` loop, with no perf drop w.r.t. master)

Tested jax versions:

- `jaxlib==0.4.12+cuda11.cudnn86` (works with unroll but slower when using the `for i` loop)
- `jaxlib==0.4.23+cuda11.cudnn86` (NaN with unroll, faster than master when using the `for i` loop; couldn't test master though)

I would be happy if you could help isolate the NaN bug so we can send a minimal working example to the Jax team.
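A possible starting point for such a reproduction, as a rough sketch (the loop body here is a toy stand-in and is not known to actually trigger the NaNs):

```python
# Rough sketch of a NaN-isolation harness (assumed approach, toy body that is
# not a confirmed trigger): run the same fori_loop with and without unroll on
# GPU and check the outputs for NaNs and divergence.
import jax
import jax.numpy as jnp

def body(i, x):
    return x * 1.0001 + jnp.tanh(x)  # placeholder for one critic/actor update

def run(x, unroll):
    return jax.jit(lambda v: jax.lax.fori_loop(0, 512, body, v, unroll=unroll))(x)

x0 = jnp.ones(256)
ref, unrolled = run(x0, 1), run(x0, 8)
print("NaNs without unroll:", bool(jnp.isnan(ref).any()))
print("NaNs with unroll:", bool(jnp.isnan(unrolled).any()))
print("max abs diff:", float(jnp.abs(ref - unrolled).max()))
```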
Hi @araffin, I was able to reproduce the NaN problem on my machine for `jaxlib==0.4.23+cuda11.cudnn86`. However, I noticed that the problem vanishes for jaxlib version 0.4.24. It also seems to be independent of whether the CUDA 11 or CUDA 12 version is used (it works for both `jaxlib==0.4.24+cuda11.cudnn86` and `jaxlib==0.4.24+cuda12.cudnn89`). So I guess the Jax people fixed the issue in the latest jaxlib version, but I could not find anything about it in the changelogs.
I will do some benchmarking of the performance with different values for the `unroll` parameter to figure out whether it helps. I'll come back to you as soon as I have the data.
Thanks, I can confirm that the bug is gone with the newest version of jax.
After some testing, it also seems that keeping the default `unroll` value is best for speed.
I can confirm that observation. In my tests, the runs without loop unrolling were almost always faster (in terms of the `time/fps` metric). I also compared against the old implementation from before the PR. Previously, you reported that `time/fps` was slightly higher for the old implementation; in my experiments, this no longer seems to be the case. The new implementation was always faster, so perhaps the new version of Jax also improved the efficiency of the `fori_loop` construct in the meantime. The results seem to be consistent across all three algorithms.
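(For anyone re-running such comparisons: timing jitted JAX code needs a warm-up call to exclude compilation, plus `block_until_ready()` to account for async dispatch; a sketch with a placeholder loop body:)

```python
# Sketch of the timing methodology (assumed, with a toy loop body): warm up
# once to exclude compilation, then block on the queued GPU work before
# reading the clock.
import time
import jax
import jax.numpy as jnp

def make_train(unroll):
    @jax.jit
    def train(x):
        return jax.lax.fori_loop(0, 1000, lambda i, v: v + 0.001 * v, x, unroll=unroll)
    return train

x = jnp.ones((256, 256))
for unroll in (1, 4, 16):
    train = make_train(unroll)
    train(x).block_until_ready()  # warm-up triggers compilation
    t0 = time.perf_counter()
    for _ in range(100):
        out = train(x)
    out.block_until_ready()  # wait for async GPU work to finish
    print(f"unroll={unroll}: {(time.perf_counter() - t0) * 10:.2f} ms/call")
```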
> Previously, you reported that `time/fps` was slightly higher for the old implementation. In my experiments, this does not seem to be the case anymore.
Yes, that was the case for older versions of jax; it is not the case anymore with newer versions.
Thanks for checking =)
Description
closes #14
Addresses #14 by replacing the Python loop in `_train()` with a `jax.lax.fori_loop` construct for SAC, TD3, and TQC. This avoids explicitly unrolling the loop at compile time, which significantly reduces compile times and the size of the compiled program for large `gradient_steps` values. A simplified sketch of the change is shown below.
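In the sketch, `carry` and `update_step` are hypothetical stand-ins for the real training state and a single gradient update; this is not the actual SBX code:

```python
# Simplified sketch of the change (not the actual SBX _train(); `carry` and
# `update_step` are stand-ins for the training state and one gradient step).
import jax

# Before: a Python loop that gets fully unrolled into the compiled program,
# so compile time and program size grow with gradient_steps.
def train_unrolled(carry, update_step, gradient_steps):
    for i in range(gradient_steps):
        carry = update_step(i, carry)
    return carry

# After: jax.lax.fori_loop compiles the body once and loops at runtime.
def train_fori(carry, update_step, gradient_steps):
    return jax.lax.fori_loop(0, gradient_steps, update_step, carry)
```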
Since performance was a concern in https://github.com/araffin/sbx/issues/14#issuecomment-1687684993, I did some benchmarking. To not exacerbate the influence of the compile time, I trained each agent for 1 million steps. I used different values for `train_freq` (with `gradient_steps=-1`) to highlight the differences between the implementations. Using a small standalone benchmarking script, I got the following results on Ubuntu 22.04 with an NVIDIA RTX A6000 GPU.
(Benchmark result plots for SAC, TD3, and TQC.)
For small `gradient_steps` values, the implementation is on par with the original implementation for all three algorithms. For large values, it leads to significant speed-ups due to faster compilation and smaller program sizes. I also tried running the original implementation with `gradient_steps=1000`, but it ran out of memory during compilation on my 64GB RAM machine.

Motivation and Context
Types of changes
Checklist:
- I have updated the changelog (`docs/misc/changelog.rst`)
- I have reformatted the code using `make format` (required)
- I have checked the codestyle using `make check-codestyle` and `make lint` (required)
- I have ensured that `make pytest` and `make type` both pass. (required)
- I have checked that the documentation builds using `make doc` (required)

Note: You can run most of the checks using `make commit-checks`.

Note: we are using a maximum length of 127 characters per line