Open garymm opened 4 months ago
Hello, we haven't specifically asked the JAX maintainers about this issue. However, it's important to note that Reverb and Stable Baselines do not store the buffer memory on the GPU, so it's not really a fair comparison; it would be better to look at the CPU times. If you find the GPU add times too slow and you're not running a fully jitted training loop, you can ensure that the flashbax buffer is stored on the CPU.
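For illustration, here is a minimal sketch of keeping buffer storage on the CPU in plain JAX. It does not use flashbax's own API: `storage`, `add`, and the shapes are hypothetical stand-ins for a buffer state and its add function. It relies on the fact that a jitted function runs on the device its committed inputs live on.

```python
import jax
import jax.numpy as jnp

# The CPU backend is available even when a GPU or TPU is the default backend.
cpu = jax.devices("cpu")[0]

# Hypothetical stand-in for a replay buffer's storage (not flashbax's state type).
# device_put commits the array to the CPU device.
storage = jax.device_put(jnp.zeros((1024, 8), dtype=jnp.float32), cpu)


def add(storage, batch, start):
    # Write `batch` into rows [start, start + batch.shape[0]) of `storage`.
    return jax.lax.dynamic_update_slice(storage, batch, (start, 0))


# Because the inputs are committed to the CPU, the jitted add runs there too.
add_on_cpu = jax.jit(add)

batch = jax.device_put(jnp.ones((4, 8), dtype=jnp.float32), cpu)
storage = add_on_cpu(storage, batch, 0)
```

The trade-off is that sampled batches then have to be transferred to the accelerator before each update, so this only makes sense when the training loop is not fully jitted.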
Thanks for the reply. Even compared to flashbax on TPU, adding is much slower on GPU, so it might be worth filing a bug about that with JAX. I'm assuming the source data is already on the GPU when you're adding?
Yes, I believe so. I did the benchmarks a while ago, but I'm sure I would have created the data on device.
Hi, I would like to ask for more details regarding where we stand in this situation. Have we tried setting `unique_indices` and `indices_are_sorted` during our `.at[].set()` over here, for example? I'm not sure if this would help, but the docs imply that it might. Unfortunately, I don't quite have time to test this right now.
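As a rough sketch of what that suggestion would look like (untested for performance, and not flashbax's actual add code), both flags are keyword arguments of `.at[].set()`. The function name `add_rows` and the shapes below are made up for the example.

```python
import jax
import jax.numpy as jnp


@jax.jit
def add_rows(storage, rows, indices):
    # Both flags are promises to the compiler, not checks: if the indices are
    # not actually sorted or unique, the result is undefined on some backends.
    return storage.at[indices].set(
        rows,
        indices_are_sorted=True,
        unique_indices=True,
    )


storage = jnp.zeros((8, 3))
indices = jnp.arange(2, 5)  # sorted and unique by construction
rows = jnp.ones((3, 3))
updated = add_rows(storage, rows, indices)
```

One caveat for a circular buffer: indices computed as `(start + arange(n)) % size` stay unique as long as `n <= size`, but they are no longer sorted once the write wraps around, so `indices_are_sorted=True` would only be safe when no wraparound can occur.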
The benchmarks show that adding to a buffer is very slow on GPU (13 ms vs. 0.4 ms for Reverb or Stable Baselines, over 30x slower). Has anyone filed a bug against JAX about this?