Joshuaalbert / jaxns

Probabilistic Programming and Nested sampling in JAX
https://jaxns.readthedocs.io/

RAM usage building up when sampling is conducted multiple times sequentially. #108

Closed Vysakh13579 closed 10 months ago

Vysakh13579 commented 10 months ago

Describe the bug As part of my work, I need to run nested sampling for a model multiple times in a loop. I observed a persistent increase in RAM usage after each run, even though I deleted every variable created inside the loop. To check whether the issue came from my own code or was a broader problem, I ran the example from the JAXNS documentation with a constant likelihood function. Surprisingly, I observed the same behaviour: RAM usage kept rising with each iteration of the loop. I would appreciate it if you could look into this and suggest potential solutions or insights.

In version 2.2.6, I attempted to pinpoint the source of a memory leak using a memory profiler. Through this investigation, I traced the issue back to the fresh_run call within the call function for approximate nested sampling. Upon further analysis, it seems that the problem may be rooted in the StaticNestedSampler class, although I cannot confirm this with absolute certainty. I wanted to bring this to your attention in the hope that it helps in identifying and resolving the memory leak.

I have provided below the code I used to monitor RAM usage for the constant-likelihood example, run against version 2.3.0. I first stumbled upon this issue in version 2.2.6. I have also attached a plot showing the increase in RAM usage after each iteration.

import os

import matplotlib.pyplot as plt
import psutil
import tensorflow_probability.substrates.jax as tfp
from jax import random

from jaxns import DefaultNestedSampler, Model, Prior

tfpd = tfp.distributions


def log_likelihood(theta):
    # Constant likelihood, as in the JAXNS documentation example
    return 0.


def prior_model():
    x = yield Prior(tfpd.Uniform(0., 1.), name='x')
    return x


def nested_sampling():
    model = Model(prior_model=prior_model,
                  log_likelihood=log_likelihood)

    exact_ns = DefaultNestedSampler(model=model, max_samples=1e4)

    termination_reason, state = exact_ns(random.PRNGKey(42))
    results = exact_ns.to_results(termination_reason=termination_reason, state=state)

    # Explicitly drop everything created in this run
    del model, exact_ns, termination_reason, state, results


if __name__ == '__main__':
    python_process = psutil.Process(os.getpid())

    ram_py = [python_process.memory_info().rss / 1024 ** 3]  # resident set size in GB
    for i in range(20):
        nested_sampling()
        ram_py.append(python_process.memory_info().rss / 1024 ** 3)

    plt.plot(ram_py, 'k.-')
    plt.xlabel('runs', fontsize=12)
    plt.ylabel('python RAM usage (GB)', fontsize=12)
    plt.show()

JAXNS version 2.3.0

[Attached plot "jaxnsRAM": RAM usage rising after each run]

Joshuaalbert commented 10 months ago

Thanks for finding this and for the example. Very interesting; I have some hypotheses and will see if I can sort it out soon.

Joshuaalbert commented 10 months ago

Reproduced locally.

Joshuaalbert commented 10 months ago

I'll dig into this tomorrow, or as soon as possible. However, the problem disappears if you JIT or AOT compile the function:

import jax
...
def nested_sampling(key):
    ...

# AOT compile once for this input structure
ns_compiled = jax.jit(nested_sampling).lower(random.PRNGKey(0)).compile()

ram_py = []
ram_py.append(python_process.memory_info().rss / 1024 ** 3)
for i in range(5):
    ns_compiled(random.PRNGKey(i))
    ram_py.append(python_process.memory_info().rss / 1024 ** 3)
    print(ram_py[-1])  # Stays the same

It might be related to caching in jaxns.framework; however, since this was already present in JAXNS 2.2, I wonder whether it's something else. Anyway, great spot @Vysakh13579!

Joshuaalbert commented 10 months ago

Resolved. The growing memory usage is JAX's compilation caching. The cache can grow because many operations carry an implicit JIT compile, e.g. while loops. So calling JAXNS code without compiling at the outer level causes these internal implicit compilations to occur over and over, which is both slower and a source of memory growth. The JAX team recently introduced a way to clear the caches.

Sticking a jax.clear_caches() after running your code resolves the growing memory issue.
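For reference, a minimal self-contained sketch of that pattern, with a toy jitted function standing in for the JAXNS run (assumes a JAX version recent enough to provide jax.clear_caches()):

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Toy stand-in for a JAXNS run; the first call triggers a compile
    return jnp.sum(x ** 2)


for i in range(3):
    _ = step(jnp.arange(4.0))  # each iteration hits the compilation cache
    jax.clear_caches()         # drop JAX's compilation caches, freeing that memory
```

Clearing the caches trades memory for speed: the next call after jax.clear_caches() recompiles from scratch.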

Also, as noted above, if you JIT or AOT compile the run, it will only compile once.
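A rough sketch of that AOT pattern, with a hypothetical toy function in place of the actual nested sampling run (the lower/compile calls are the same jax.jit AOT API used in the earlier snippet):

```python
import jax
import jax.numpy as jnp
from jax import random


def run(key):
    # Hypothetical stand-in for the nested sampling run; any traceable function works
    return jnp.sum(random.uniform(key, (1000,)))


# Lower and compile once, ahead of time, for this input structure
run_compiled = jax.jit(run).lower(random.PRNGKey(0)).compile()

for i in range(5):
    out = run_compiled(random.PRNGKey(i))  # reuses the compiled executable, no recompiles
```

The compiled executable accepts any argument with the same shape and dtype structure as the one passed to lower(), so different PRNG keys in the loop are fine.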

Joshuaalbert commented 10 months ago

Removed bug label because it's not a bug, but a "feature" of JAX.