instadeepai / flashbax

⚡ Flashbax: Accelerated Replay Buffers in JAX
https://instadeepai.github.io/flashbax/
Apache License 2.0
213 stars, 11 forks

add is slow on GPU #31

Open garymm opened 4 months ago

garymm commented 4 months ago

The benchmarks show that adding to a buffer is very slow on GPU (13 ms vs. 0.4 ms for Reverb or Stable Baselines, over 30x slower). Has anyone filed a bug against JAX about this?

EdanToledo commented 3 months ago

Hello, we haven't specifically asked the JAX maintainers about this issue. However, it's important to note that Reverb and Stable Baselines do not store their memory on the GPU, so it's not really a fair comparison; it would be better to look at the CPU times. If you find the GPU add times too slow and you're not running a fully jitted training loop, you can keep the Flashbax buffer on the CPU instead.
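A minimal sketch of keeping buffer storage on the CPU, assuming a hypothetical toy ring buffer (`init_buffer` and `add` are illustrative names, not Flashbax's API). Flashbax buffer states are pytrees of arrays, so applying `jax.device_put(state, cpu)` to them should work the same way, but that is an assumption not confirmed in this thread.

```python
import jax
import jax.numpy as jnp

# The host CPU device is always available, even on machines with a GPU.
CPU = jax.devices("cpu")[0]


def init_buffer(max_length: int, obs_dim: int) -> dict:
    # Hypothetical minimal ring buffer (not Flashbax's API): storage + write pointer.
    return {
        "obs": jnp.zeros((max_length, obs_dim), dtype=jnp.float32),
        "ptr": jnp.zeros((), dtype=jnp.int32),
    }


@jax.jit
def add(state: dict, obs: jax.Array) -> dict:
    # Write one timestep at the pointer, then advance it with wrap-around.
    max_length = state["obs"].shape[0]
    return {
        "obs": state["obs"].at[state["ptr"]].set(obs),
        "ptr": (state["ptr"] + 1) % max_length,
    }


# device_put on a pytree commits every leaf to the CPU device.
state = jax.device_put(init_buffer(1000, 4), CPU)
obs = jax.device_put(jnp.ones((4,), dtype=jnp.float32), CPU)
# Because the inputs are committed to the CPU, the jitted add also runs there.
state = add(state, obs)
```

The trade-off is that a GPU-resident training loop would then move every added timestep and every sampled batch across the host/device boundary, which is why this only makes sense when the loop is not fully jitted.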

garymm commented 3 months ago

Thanks for the reply. Even compared with Flashbax on TPU, add is much slower on GPU, so it might be worth filing a bug about that with JAX? I'm assuming the source data is already on the GPU when you add?

EdanToledo commented 3 months ago

Yes, I believe so. I ran the benchmarks a while ago, but I'm fairly sure I would have created the data on device.
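A minimal sketch of timing a single-timestep add with the source data already on the default device, using a hypothetical toy ring-buffer `add` (not Flashbax's implementation; `MAX_LENGTH` and `OBS_DIM` are arbitrary). It excludes JIT compilation with a warm-up call and waits on the result, because JAX dispatch is asynchronous and would otherwise make adds look faster than they are.

```python
import time

import jax
import jax.numpy as jnp

MAX_LENGTH, OBS_DIM = 100_000, 8  # hypothetical buffer sizes


@jax.jit
def add(storage, ptr, obs):
    # Write one timestep at the current pointer and advance it (ring buffer).
    storage = storage.at[ptr].set(obs)
    return storage, (ptr + 1) % storage.shape[0]


def time_add(n_iters=100):
    # Create buffer and input on the default device (GPU if one is available).
    storage = jnp.zeros((MAX_LENGTH, OBS_DIM), dtype=jnp.float32)
    ptr = jnp.array(0, dtype=jnp.int32)
    obs = jnp.ones((OBS_DIM,), dtype=jnp.float32)
    # Warm-up call so compilation time is excluded from the measurement.
    storage, ptr = add(storage, ptr, obs)
    storage.block_until_ready()
    start = time.perf_counter()
    for _ in range(n_iters):
        storage, ptr = add(storage, ptr, obs)
    storage.block_until_ready()  # wait for the last queued add to finish
    return (time.perf_counter() - start) / n_iters, storage, ptr


seconds_per_add, storage, ptr = time_add(n_iters=100)
print(f"{seconds_per_add * 1e3:.3f} ms per add on {jax.devices()[0].platform}")
```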

eadadi commented 2 months ago

Hi, I'd like to ask for more details on where we stand in this situation.

  1. Is adding single timesteps currently slow on GPU? Can we pinpoint where the delay happens? From this discussion, I understand that it is simply the JAX operation being used?
  2. When adding a batch of timesteps, we don't have these delays, right?
  3. Is there anything we can do to improve the situation?

sash-a commented 4 weeks ago
  1. Yes, it's likely in the underlying XLA that JAX compiles to.
  2. There are fewer delays, but it could likely still be better.
  3. One possible improvement is to tell JAX that our indices are unique and sorted by passing unique_indices and indices_are_sorted to our .at[].set() calls (over here, for example). I'm not sure whether this would help, but the docs imply that it might. Unfortunately, I don't have time to test this right now.
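A minimal sketch of the idea in point 3, applied to a batch add on a hypothetical ring buffer (`batch_indices`, `add_batch_plain`, `add_batch_hinted` and `optimized_hlo` are illustrative names, not Flashbax code). The JAX docs describe these flags as promises: if they are set to True and do not actually hold, the result is undefined. In a ring buffer, a batch that wraps past the end gets indices like `[998, 999, 0, 1]`, which are unique but not sorted, so only `unique_indices` is set here. The sketch does not measure whether the hint helps on GPU. `optimized_hlo` dumps the compiled XLA so the two variants (and the single-step add) can be compared, which may also help pin down where the time goes (point 1).

```python
import jax
import jax.numpy as jnp


def batch_indices(ptr, batch_size, max_length):
    # Slots for a batch add: unique while batch_size <= max_length,
    # but NOT sorted when the write wraps past the end of the buffer.
    return (ptr + jnp.arange(batch_size)) % max_length


def add_batch_plain(storage, ptr, batch):
    idx = batch_indices(ptr, batch.shape[0], storage.shape[0])
    return storage.at[idx].set(batch), (ptr + batch.shape[0]) % storage.shape[0]


def add_batch_hinted(storage, ptr, batch):
    idx = batch_indices(ptr, batch.shape[0], storage.shape[0])
    # unique_indices=True is safe here; indices_are_sorted=True would only be
    # valid if the write never wraps around, so it is deliberately left unset.
    new_storage = storage.at[idx].set(batch, unique_indices=True)
    return new_storage, (ptr + batch.shape[0]) % storage.shape[0]


def optimized_hlo(fn, *args):
    # Post-optimization HLO text, for inspecting which ops XLA actually emits.
    return jax.jit(fn).lower(*args).compile().as_text()


storage = jnp.zeros((1000, 4), dtype=jnp.float32)
ptr = jnp.array(998, dtype=jnp.int32)  # start near the end to force a wrap
batch = jnp.arange(16, dtype=jnp.float32).reshape(4, 4)

plain, ptr_plain = jax.jit(add_batch_plain)(storage, ptr, batch)
hinted, ptr_hinted = jax.jit(add_batch_hinted)(storage, ptr, batch)
hlo_text = optimized_hlo(add_batch_hinted, storage, ptr, batch)
```

Both variants must produce identical buffers; any speed difference would only show up when timed on the target GPU.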