google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

CUDA OOM with jax/pytorch notebook #489

Closed P-Schumacher closed 1 day ago

P-Schumacher commented 4 weeks ago

Hi, the notebook in the JAX + PyTorch tutorial is very nice and useful for me, but it sets a particular flag: `os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"`

I understand this flag prevents a CUDA OOM issue, but the JAX team has mentioned that it also significantly slows down computation: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html

I tried removing it and solving the memory issues in other ways, but I haven't been successful so far. Is there any update from your team on this, or at least a current guess as to where this memory leak originates? Any kind of info would be very helpful.
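For reference, the allocator options described on that JAX documentation page are all plain environment variables that must be set before JAX is first imported. A minimal sketch of the alternatives (the variable names come from the linked JAX docs, not from the Brax notebook):

```python
import os

# All of these must be set BEFORE `import jax`, because the JAX backend
# reads them once at initialization.

# The setting used by the tutorial notebook: allocate and free GPU memory
# on demand. Plays nicely with PyTorch sharing the GPU, but can be slow.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

# Alternatives from the JAX GPU memory docs: keep the default (fast)
# allocator but disable up-front preallocation, or cap the fraction of
# GPU memory it may grab.
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

# import jax  # only import after the variables above are set
```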

erikfrey commented 1 day ago

Hi @P-Schumacher - back when we wrote this colab, we did not see a significant change in performance with or without this particular flag. Our training workloads spend >99% of their time on device after initial setup, and they don't release their DeviceArray buffers in a way that would cause deallocations / thrashing during training.

But I could be wrong! I'm going to close this for now, but if you find evidence that this flag is significantly impacting performance, please let us know.
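For anyone who wants to gather that evidence: since the allocator variable is read once at JAX startup, the comparison has to be two separate processes, one run with `XLA_PYTHON_CLIENT_ALLOCATOR=platform` and one without. A small hedged timing helper (the step function is a placeholder for your actual jitted training step, which should block on its result, e.g. via `.block_until_ready()`, so device time is really measured):

```python
import statistics
import time


def time_steps(step_fn, n_steps=100, warmup=10):
    """Return the mean wall-clock seconds per call to step_fn.

    Runs `warmup` untimed calls first, so JIT compilation and the first
    allocations don't pollute the measurement. For JAX, step_fn should
    block on its output (e.g. .block_until_ready()), otherwise only
    dispatch time is measured.
    """
    for _ in range(warmup):
        step_fn()
    samples = []
    for _ in range(n_steps):
        t0 = time.perf_counter()
        step_fn()
        samples.append(time.perf_counter() - t0)
    return statistics.mean(samples)
```

Run the same training script twice, e.g. `XLA_PYTHON_CLIENT_ALLOCATOR=platform python train.py` versus plain `python train.py`, and compare the reported means; a large gap would be the evidence asked for above.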