stergiosba opened this issue 5 months ago
Interesting! I'm not sure why that would happen. However, I do think the flax
RNN documentation/structure may have changed after version 0.4.27, which could be why you're getting the error.
https://flax.readthedocs.io/en/latest/guides/converting_and_upgrading/rnncell_upgrade_guide.html
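For reference, the gist of that upgrade guide is roughly the following (a hedged sketch from memory of the guide, not code from this repo):

```python
import jax
import flax.linen as nn

x = jax.numpy.ones((8, 32))  # (batch, input features)

# Older API: the carry was built with a class method taking batch dims and size
# carry = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (8,), 64)
# new_carry, y = nn.LSTMCell()(carry, x)

# Newer API: the cell takes `features`, and initialize_carry is an instance
# method that takes the input shape instead
cell = nn.LSTMCell(features=64)
carry = cell.initialize_carry(jax.random.PRNGKey(0), x.shape)
variables = cell.init(jax.random.PRNGKey(1), carry, x)
new_carry, y = cell.apply(variables, carry, x)
```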
Well, there is no error, technically speaking. The code runs fine and the model trains fine. The only problem is that compilation on the XLA side of things takes longer. Version 7.4 of XLA (JAX 0.4.26) is much faster than 8.3 (JAX 0.4.28) at generating the code for a GPU device (hope this adds clarity to the issue). Also, I see this issue on two different GPUs on two different machines.
Anyway, maybe this will be useful to keep in mind for your future JAX endeavors. Thanks again.
I ended up fixing this by making the StackedEncoderModel a scanned version of what you initially had (rough sketch below).
There are some minimal code changes required from the end user, which maybe we can smooth over.
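Roughly, the change looks like this (a minimal sketch, not the actual S5 code; `EncoderBlock` here is a hypothetical stand-in for one sequence layer):

```python
import flax.linen as nn


class EncoderBlock(nn.Module):
    """Stand-in for one S5/sequence layer; the real block is more involved."""
    features: int

    @nn.compact
    def __call__(self, carry, _):
        # carry is the running activation; nn.scan expects (carry, x) -> (carry, y)
        h = nn.gelu(nn.Dense(self.features)(carry))
        return h, None


class ScannedStackedEncoder(nn.Module):
    features: int
    n_layers: int

    @nn.compact
    def __call__(self, x):
        # nn.scan compiles the layer body once and iterates it n_layers times,
        # instead of unrolling n_layers copies into the jaxpr/HLO.
        ScanStack = nn.scan(
            EncoderBlock,
            variable_axes={"params": 0},  # stack per-layer params on a leading axis
            split_rngs={"params": True},  # independent init RNG per layer
            length=self.n_layers,
        )
        x, _ = ScanStack(features=self.features)(x, None)
        return x
```

The user-visible change is mainly that the parameters now carry an extra leading n_layers axis, which I think is where the small end-user code changes come from.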
Some compilation benchmarks:
I also attach the results from setting jax.config.update("jax_log_compiles", True)
20 layers S5:
Finished jaxpr to MLIR module conversion jit(train) in 1.3401036262512207 sec
Finished XLA compilation of jit(train) in 29.01173186302185 sec

vs

200 layers S5:
Finished jaxpr to MLIR module conversion jit(train) in 1.3581228256225586 sec
2024-06-29 18:11:22.686136: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 8.90GiB (9555618931 bytes) by rematerialization; only reduced to 10.13GiB (10874848652 bytes), down from 10.13GiB (10875108940 bytes) originally
Finished XLA compilation of jit(train) in 29.203452587127686 sec
Tests were run on a single NVIDIA RTX A5500:
Sat Jun 29 18:12:42 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07 Driver Version: 550.90.07 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX A5500 Off | 00000000:01:00.0 On | Off |
| 30% 45C P8 26W / 230W | 554MiB / 24564MiB | 22% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 2311 G /usr/lib/xorg/Xorg 182MiB |
| 0 N/A N/A 2497 G /usr/bin/gnome-shell 293MiB |
+-----------------------------------------------------------------------------------------+
Interestingly, I do not get exactly the same learning performance, which could mean there is a bug somewhere; however, I did test on CartPole-v1, Acrobot-v1, and MountainCar-v0, and it successfully learns these envs.
Let me know if you are interested in a PR for this.
Hey, thanks for providing purejaxrl, it's pretty awesome.
I have used the experimental S5 code that you provide for a part of my research, and after version 0.4.27 (same for 0.4.28) of jaxlib I have been getting 5 times longer compilation times when I increase the n_layers of the S5. Any ideas why this might happen?
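For reference, this is roughly how one can isolate compilation time from run time when comparing jaxlib versions (a minimal sketch; `train_step` is a hypothetical stand-in for the actual purejaxrl train function):

```python
import time
import jax
import jax.numpy as jnp

# Log every XLA compilation (this is where the timings quoted above come from)
jax.config.update("jax_log_compiles", True)

@jax.jit
def train_step(params, batch):
    # hypothetical stand-in for the real purejaxrl/S5 train function
    return jax.tree_util.tree_map(lambda p: p - 1e-3 * jnp.mean(batch), params)

params = {"w": jnp.ones((256, 256))}
batch = jnp.ones((32, 256))

# Ahead-of-time lowering/compilation isolates compile time from run time
t0 = time.perf_counter()
compiled = train_step.lower(params, batch).compile()
print(f"XLA compilation took {time.perf_counter() - t0:.2f} s")
```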