RadostW opened this issue 3 days ago
A couple of potentially helpful tips:
(1) You can use jax.live_arrays() (https://jax.readthedocs.io/en/latest/_autosummary/jax.live_arrays.html) to return all of the live arrays and check for potential leaks.
(2) Does manually running gc.collect() help with the problem?
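The two tips can be combined into one quick check: run the suspect code a few times, then count the live arrays before and after an explicit gc.collect(). This is a minimal sketch, assuming a recent JAX where jax.live_arrays() returns a list of jax.Array objects that expose nbytes. The names live_array_stats and workload are hypothetical; workload is only a placeholder for the leaking code (for example one pychastic integration) and is not part of any library.

```python
import gc

import jax
import jax.numpy as jnp


def live_array_stats():
    """Return (count, total bytes) of all jax.Arrays still alive on any device."""
    arrays = jax.live_arrays()
    return len(arrays), sum(a.nbytes for a in arrays)


def workload(n_steps=50):
    # Hypothetical stand-in for the code under suspicion; replace it with the
    # real solver call. Each step allocates a fresh array and drops the old one.
    x = jnp.zeros((1000,))
    for i in range(n_steps):
        x = x + jnp.sin(x + i)
    return x


history = []
for trial in range(3):
    result = workload()
    result.block_until_ready()  # make sure the computation has finished
    del result                  # drop our own reference to the output
    n_before_gc, _ = live_array_stats()
    gc.collect()                # tip (2): does an explicit collection free anything?
    n_after_gc, bytes_after_gc = live_array_stats()
    history.append((n_before_gc, n_after_gc, bytes_after_gc))
    print(f"trial {trial}: {n_before_gc} live arrays before gc.collect(), "
          f"{n_after_gc} after ({bytes_after_gc / 1e6:.3f} MB)")
```

If the count after gc.collect() keeps rising from trial to trial, something on the Python side is probably still holding references to device arrays. If it stays flat while process memory still grows, the memory is likely held somewhere this check cannot see.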
Description
When using the JAX-based package pychastic (an SDE solver), the JAX backend keeps eating memory indefinitely.
Output (abbreviated)
I apologize for the contrived code to reproduce the issue. I'd be happy to chase the leak further, but I'm unfamiliar with any tools that could help diagnose the issue. Is there some way to see what's taking up all this space?
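One way to see what is taking up the space is to group the live arrays by dtype and shape and sort the groups by total bytes. A large, steadily growing group usually points at the code that creates those arrays. This is a sketch, assuming jax.live_arrays() as in the tip above; summarize_live_arrays is a hypothetical helper, and the kept list only simulates a leak for illustration. JAX also documents jax.profiler.save_device_memory_profile, which writes a pprof-format device memory snapshot that can be inspected with the pprof tool.

```python
from collections import defaultdict

import jax
import jax.numpy as jnp


def summarize_live_arrays(top=10):
    """Group live jax.Arrays by (dtype, shape) and sort groups by total bytes."""
    groups = defaultdict(lambda: [0, 0])  # (dtype, shape) -> [count, total bytes]
    for arr in jax.live_arrays():
        key = (str(arr.dtype), tuple(arr.shape))
        groups[key][0] += 1
        groups[key][1] += arr.nbytes
    # Largest memory consumers first, limited to the `top` groups.
    rows = sorted(groups.items(), key=lambda kv: kv[1][1], reverse=True)[:top]
    for (dtype, shape), (count, nbytes) in rows:
        print(f"{count:6d} x {dtype}{list(shape)}  {nbytes / 1e6:10.3f} MB")
    return rows


# Simulated leak: keep references to 20 arrays of the same shape, then inspect.
kept = [jnp.ones((256, 3)) * k for k in range(20)]
rows = summarize_live_arrays()
```

Calling summarize_live_arrays() periodically inside the long-running loop and comparing the printed tables shows which group grows over time.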
System info (python version, jaxlib version, accelerator, etc.)
Also tested (same result) with