Vysakh13579 closed 10 months ago
Thanks for finding this and for the example. Very interesting. I have a few hypotheses and will see if I can sort it out soon.
Reproduced locally.
I'll dig into this tomorrow, or as soon as possible. However, the problem disappears if you JIT or AOT compile the function:
...
def nested_sampling(key):
    ...

# AOT-compile once, then reuse the compiled executable on every iteration
ns_compiled = jax.jit(nested_sampling).lower(random.PRNGKey(0)).compile()
ram_py = []
ram_py.append(python_process.memory_info()[0] / (1024 ** 3))  # RSS in GiB
for i in range(5):
    ns_compiled(random.PRNGKey(i))
    ram_py.append(python_process.memory_info()[0] / (1024 ** 3))
    print(ram_py[-1])  # Stays the same
It might be related to caching in jaxns.framework; however, if this was already present in JAXNS 2.2, that makes me wonder whether it's something else. Anyway, great spot @Vysakh13579!
Resolved. The growing memory usage comes from JAX's caching. The cache can grow because many operations, e.g. while loops, are implicitly JIT-compiled. Calling JAXNS code without compiling at the outer level therefore causes the internal implicit compilations to happen over and over, which is slower and makes memory grow. The JAX team recently introduced a way to clear the cache.
Adding a call to jax.clear_caches() after running your code resolves the growing memory issue.
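A minimal sketch of that workaround, for illustration only. The function `run_once` is a hypothetical stand-in for an un-jitted JAXNS run (it is not JAXNS code); it uses `jax.lax.while_loop`, which JAX compiles implicitly when it is called outside of `jax.jit`. The assumption here is that clearing the caches after each iteration is acceptable, at the cost of recompiling on the next call.

```python
import jax
from jax import random


def run_once(key):
    # Hypothetical stand-in for an un-jitted sampler run: lax.while_loop
    # is compiled implicitly when called outside of jax.jit.
    x0 = random.uniform(key, ())

    def cond(state):
        i, _ = state
        return i < 10

    def body(state):
        i, x = state
        return i + 1, x * 0.5

    _, x = jax.lax.while_loop(cond, body, (0, x0))
    return x


results = []
for i in range(5):
    results.append(float(run_once(random.PRNGKey(i))))
    # Drop JAX's compilation caches so they cannot accumulate across iterations.
    jax.clear_caches()
```

Clearing the caches trades memory for time: anything cleared has to be traced and compiled again the next time it is needed.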
Also, as noted above, if you JIT or AOT compile the run, it will only be compiled once.
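To illustrate the "compiled once" point, here is a small sketch that counts how often JAX traces a function wrapped in `jax.jit` at the outer level. The function `run` and the counter `trace_count` are hypothetical names introduced for this example, not part of JAXNS. The counter is a Python side effect, so it only increments while JAX traces the function, not on each execution of the compiled code.

```python
import jax
from jax import random

trace_count = 0


def run(key):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    x0 = random.normal(key, (3,))
    # The inner loop is compiled as part of the outer jitted function.
    return jax.lax.fori_loop(0, 10, lambda i, x: x * 0.5, x0)


run_jit = jax.jit(run)

# Five calls with different keys of the same shape and dtype reuse one compilation.
outputs = [run_jit(random.PRNGKey(i)) for i in range(5)]
print(trace_count)
```

Because every key has the same shape and dtype, the later calls reuse the first compilation instead of adding new cache entries.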
Removed bug label because it's not a bug, but a "feature" of JAX.
Describe the bug
As part of my work, I found it necessary to execute nested sampling for a model within a loop multiple times. However, I observed a persistent increase in RAM usage after each run, even though I deleted any variables created within the loop. To investigate whether this issue stemmed from my own code or was a broader concern, I conducted a test using the example provided in the JAXNS documentation for the constant likelihood function. Surprisingly, the same behaviour was observed – the RAM usage continued to rise with each iteration of the loop. I would appreciate it if you could take a look into this matter and provide guidance on potential solutions or insights.
In version 2.2.6, I attempted to pinpoint the source of a memory leak using a memory profiler. Through this investigation, I traced the issue back to the fresh_run call within the call function for approximate nested sampling. Upon further analysis, it seems that the problem may be rooted in the StaticNestedSampler class, although I cannot confirm this with absolute certainty. I wanted to bring this to your attention in the hope that it helps in identifying and resolving the memory leak.
Below is the code I used to monitor RAM usage for the constant likelihood example in version 2.3.0. I first came across this issue in version 2.2.6. I have also attached a plot showing the increase in RAM usage after each iteration.
JAXNS version 2.3.0