luchris429 / purejaxrl

Really Fast End-to-End Jax RL Implementations

S5: Longer compilation times #25

Open stergiosba opened 1 month ago

stergiosba commented 1 month ago

Hey, thanks for providing purejaxrl, it's pretty awesome.

I have been using the experimental S5 code you provide for part of my research, and after version 0.4.27 of jaxlib (same for 0.4.28) I have been getting 5x longer compilation times when I increase the n_layers of the S5. Any ideas why this might happen?
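For reference, a minimal sketch of one way to time the tracing and XLA compilation stages separately with JAX's ahead-of-time API (the dummy train_fn below is just a stand-in; in practice it would be jax.jit(make_train(config)) as in the purejaxrl examples):

import time
import jax
import jax.numpy as jnp

# Stand-in for the real jitted training function; replace with
# jax.jit(make_train(config)) from the purejaxrl S5 example.
@jax.jit
def train_fn(rng):
    return jax.random.normal(rng, (1000, 1000)) @ jnp.ones((1000,))

rng = jax.random.PRNGKey(42)

t0 = time.time()
lowered = train_fn.lower(rng)    # trace to jaxpr / MLIR
t1 = time.time()
compiled = lowered.compile()     # XLA compilation (the stage that got slower)
t2 = time.time()
print(f"trace/lower: {t1 - t0:.2f} s, XLA compile: {t2 - t1:.2f} s")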

luchris429 commented 1 month ago

Interesting! I'm not sure why that would happen. However, I do think the flax RNN documentation/structure may have changed after version 0.4.27, which could be why you're getting the error.

https://flax.readthedocs.io/en/latest/guides/converting_and_upgrading/rnncell_upgrade_guide.html
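For context, the core of that change is that flax RNN cells now take their feature size at construction time and the carry is initialized from the cell instance rather than a class method. A rough sketch (double-check the guide for the exact signatures):

import flax.linen as nn
import jax

rng = jax.random.PRNGKey(0)
batch, features = 4, 16

# Old-style flax (pre-upgrade), for comparison:
#   cell = nn.LSTMCell()
#   carry = nn.LSTMCell.initialize_carry(rng, (batch,), features)

# New-style flax: feature size at construction, carry from the instance
# given a sample input shape.
cell = nn.LSTMCell(features=features)
carry = cell.initialize_carry(rng, (batch, features))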

stergiosba commented 1 month ago

Well, technically speaking there is no error. The code runs fine and the model trains fine. The only problem is that compilation on the XLA side takes longer. Version 7.4 of XLA (JAX 0.4.26) is much faster than 8.3 (JAX 0.4.28) at generating code for a GPU device (hope this adds clarity to the issue). I also see this on two different GPUs on two different machines.

Anyway, maybe this will be useful to know for your future JAX endeavors. Thanks again.

stergiosba commented 3 days ago

I ended up fixing this by making the StackedEncoderModel a scanned version of what you initially had (a rough sketch of the pattern is below). It requires some minimal code changes for the end user, which maybe we can smooth over.
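Roughly, the change is to wrap a single S5 block in flax's nn.scan so XLA compiles the layer body once, with the layer parameters stacked along a new leading (n_layers,) axis. A minimal sketch of the pattern, with a placeholder block rather than the real S5 layer:

import flax.linen as nn
import jax
import jax.numpy as jnp

class S5Block(nn.Module):
    """Placeholder for a single S5 layer (dense + residual only)."""
    d_model: int

    @nn.compact
    def __call__(self, x, _):
        # nn.scan expects a (carry, input) -> (carry, output) signature;
        # the activations act as the carry and there is no per-step input.
        h = nn.gelu(nn.Dense(self.d_model)(x))
        return x + h, None

class ScannedStackedEncoder(nn.Module):
    """Stacks n_layers blocks via nn.scan instead of a Python for-loop,
    so compile time stays flat as depth grows."""
    d_model: int
    n_layers: int

    @nn.compact
    def __call__(self, x):
        ScanS5 = nn.scan(
            S5Block,
            variable_axes={"params": 0},   # stack params on a new leading axis
            variable_broadcast=False,
            split_rngs={"params": True},   # independent init per layer
            length=self.n_layers,
        )
        x, _ = ScanS5(self.d_model)(x, None)
        return x

# The extra leading (n_layers,) axis on the parameters is the "minimal code
# change for the end user" (e.g. when inspecting or loading checkpoints).
model = ScannedStackedEncoder(d_model=64, n_layers=20)
params = model.init(jax.random.PRNGKey(0), jnp.ones((8, 64)))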

Some compilation benchmarks:

  1. 1 S5 layer: 28 sec (old) vs 28 sec (new)
  2. 4 S5 layers: 90 sec (old) vs 29 sec (new)
  3. 20 S5 layers: 29 sec (new)
  4. 200 S5 layers: 29 sec (new) - an extreme case, just as a stress test

I also attach the results from setting jax.config.update("jax_log_compiles", True):

20-layer S5:

Finished jaxpr to MLIR module conversion jit(train) in 1.3401036262512207 sec
Finished XLA compilation of jit(train) in 29.01173186302185 sec

200-layer S5:

Finished jaxpr to MLIR module conversion jit(train) in 1.3581228256225586 sec
2024-06-29 18:11:22.686136: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 8.90GiB (9555618931 bytes) by rematerialization; only reduced to 10.13GiB (10874848652 bytes), down from 10.13GiB (10875108940 bytes) originally
Finished XLA compilation of jit(train) in 29.203452587127686 sec

Tests were run on a single NVIDIA RTX A5500:

Sat Jun 29 18:12:42 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTX A5500               Off |   00000000:01:00.0  On |                  Off |
| 30%   45C    P8             26W /  230W |     554MiB /  24564MiB |     22%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      2311      G   /usr/lib/xorg/Xorg                            182MiB |
|    0   N/A  N/A      2497      G   /usr/bin/gnome-shell                          293MiB |
+-----------------------------------------------------------------------------------------+

Interestingly, I do not get exactly the same learning performance. That could mean there is a bug somewhere; however, I did test on Cartpole-v1, Acrobot-v1, and Mountaincar-v0, and it successfully learns these envs.

Let me know if you are interested in a PR for this.