JohannesBuchner / UltraNest

Fit and compare complex models reliably and rapidly. Advanced nested sampling.
https://johannesbuchner.github.io/UltraNest/

Vectorised sampling and memory consumption #114

Closed LucaMantani closed 3 months ago

LucaMantani commented 9 months ago

Description

I am using ultranest to sample from a posterior which is essentially a multidimensional Gaussian (~20-30 parameters), defined in jax and jax.jit-compiled. The code works well when vectorisation is off. When I turn it on, sampling gets much faster, but I observe a rather severe increase in memory consumption: around 14 GB when I run with 20 parameters. When I increase the dimensionality of the problem to 30 parameters, the memory consumption becomes excessive (around 60 GB).

None of this happens when vectorised=False.

What I Did

I tried setting a maximum number of draws with ndraw_max = 500. This tames the issue somewhat, but the memory consumption still grows with the number of likelihood calls (reaching around 30 GB).

sampler = ultranest.ReactiveNestedSampler(
    parameters,
    log_likelihood_vectorised,
    weight_minimization_prior,
    vectorized=True,
    ndraw_max=500,
)
comane commented 6 months ago

@JohannesBuchner I observed the same problem. Running with vectorized=True on a multicore machine led to excessive memory use. In my case, a fit with 30 parameters on a multicore machine reached over 200 GB of memory consumption.

JohannesBuchner commented 6 months ago

Isn't that your likelihood though? You can test with something like:

import numpy as np

ndraw_max = 500
for i in range(100):
    us = np.random.uniform(size=(ndraw_max, ndim))
    ps = prior_transform(us)
    Ls = log_likelihood(ps)

(from here)

LucaMantani commented 6 months ago

I ran the tests:

ndraw_max = 500
us = np.random.uniform(size=(ndraw_max, 15))

%timeit prior_transform(us)

6.24 µs ± 21.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

ndraw_max = 500
us = np.random.uniform(size=(ndraw_max, 15))

p = prior_transform(us)

%timeit log_likelihood(p)

41.4 ms ± 656 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Then I ran the suggested loop:

ndraw_max = 500
for i in range(100000):
    us = np.random.uniform(size=(ndraw_max, 15))
    ps = prior_transform(us)
    Ls = log_likelihood(ps)

Here I do not see a problem with memory: the python process needs ~1 GB, and the memory usage does not increase over time, contrary to what I observe when I run ultranest.

JohannesBuchner commented 6 months ago

I would guess, however, that running a vectorized Gaussian likelihood with ultranest would not show this extreme memory usage.

So maybe there is a memory leak somewhere. You may need to use some Python memory tracing tools.
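
A minimal sketch of such a trace, using the stdlib tracemalloc module with a stand-in vectorized Gaussian likelihood (not the reporter's jax model); note that tracemalloc only sees allocations made through Python's allocator, so native buffers held by jax/XLA would not appear in it:

```python
# Sketch: check whether Python-level allocations grow across repeated
# vectorized likelihood calls. The likelihood is a stand-in Gaussian.
import tracemalloc
import numpy as np

def log_likelihood(ps):
    # stand-in vectorized Gaussian log-likelihood, one value per row
    return -0.5 * np.sum(ps ** 2, axis=1)

tracemalloc.start()
for i in range(1000):
    us = np.random.uniform(size=(500, 15))
    Ls = log_likelihood(us)
    if i in (0, 999):
        # if `current` keeps climbing between iterations, something is
        # retaining Python objects across likelihood calls
        current, peak = tracemalloc.get_traced_memory()
        print(f"iteration {i}: current={current / 1e6:.2f} MB, "
              f"peak={peak / 1e6:.2f} MB")
tracemalloc.stop()
```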

JohannesBuchner commented 3 months ago

Please reopen if you can reproduce this issue with a non-jax toy likelihood function.
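
A minimal non-jax toy along those lines might look as follows (dimensionality, prior range, and parameter names are illustrative choices; the two functions are shaped to drop into ReactiveNestedSampler exactly as in the snippet at the top of this thread, with vectorized=True and ndraw_max=500):

```python
# Sketch of a non-jax toy for reproduction attempts: a vectorized
# standard Gaussian likelihood and uniform prior transform in plain
# NumPy. ndim and the [-5, 5] prior range are illustrative.
import numpy as np

ndim = 15
parameters = [f"p{i}" for i in range(ndim)]

def prior_transform(cube):
    # map the unit hypercube, shape (ndraw, ndim), to [-5, 5] per axis
    return 10.0 * cube - 5.0

def log_likelihood(params):
    # params has shape (ndraw, ndim) when vectorized=True;
    # return one log-likelihood per row
    return -0.5 * np.sum(params ** 2, axis=1)

# quick shape sanity check mimicking ultranest's vectorized calls
us = np.random.uniform(size=(500, ndim))
Ls = log_likelihood(prior_transform(us))
assert Ls.shape == (500,)

# These would then be passed to ultranest.ReactiveNestedSampler(
#     parameters, log_likelihood, prior_transform,
#     vectorized=True, ndraw_max=500)
# while watching the process's memory over the run.
```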

This page suggests you can use ulimit or prlimit to limit the memory allowance of a program.
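
For reference, a similar cap can also be applied from inside the Python process itself via the stdlib resource module (POSIX only); this is a generic sketch, not something from the linked page, and the 8 GB figure is an illustrative choice:

```python
# Sketch (POSIX only): capping the process's own address space via the
# stdlib `resource` module, the in-process analogue of ulimit/prlimit.
import resource

soft, hard = resource.getrlimit(resource.RLIMIT_AS)
limit_bytes = 8 * 1024 ** 3  # illustrative 8 GB cap
if hard != resource.RLIM_INFINITY:
    # the soft limit may never exceed the hard limit
    limit_bytes = min(limit_bytes, hard)
resource.setrlimit(resource.RLIMIT_AS, (limit_bytes, hard))

# Allocations beyond the cap now raise MemoryError (or fail inside the
# allocating library) instead of pushing the machine into swap.
```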