JohannesBuchner / UltraNest

Fit and compare complex models reliably and rapidly. Advanced nested sampling.
https://johannesbuchner.github.io/UltraNest/

Pure functions not allowed #139

Closed: LucaMantani closed this issue 5 months ago

LucaMantani commented 5 months ago

I am using ultranest with JAX, which thrives in a pure functional programming paradigm.

Related to issue #51, ultranest does not currently accept extra arguments for the log-likelihood. This means the loglike cannot be defined as a pure function: for example, it cannot receive the data as input and must instead rely on data defined in an external scope.

In general this is not a problem: one simply defines a function that returns a function (a closure over the data), as outlined in https://github.com/JohannesBuchner/UltraNest/issues/51#issuecomment-1103661731.
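A minimal sketch of that closure pattern, assuming a simple Gaussian likelihood for the mean of 1-D data. The names `make_loglike`, `data` and `sigma` are illustrative only and are not part of UltraNest. Note that `data` never appears in the signature of the returned `loglike`; it is captured from the enclosing scope, which is exactly the kind of external object `jax.jit` treats as a constant rather than as an input.

```python
import numpy as np


def make_loglike(data, sigma=1.0):
    """Return a log-likelihood that closes over `data` (hypothetical example)."""
    def loglike(theta):
        # `data` and `sigma` are not arguments: they are captured by the closure
        mu = theta[0]
        resid = (data - mu) / sigma
        return -0.5 * np.sum(resid ** 2)
    return loglike


data = np.array([1.0, 2.0, 3.0])
loglike = make_loglike(data)
print(loglike(np.array([2.0])))  # -1.0
```

The sampler only ever calls `loglike(theta)`, so this works with the current single-argument interface.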

The problem appears when using jax.jit: because these external objects are not inputs, they get compiled into the function as constants, i.e. copied in memory, which doubles the memory consumption. When dealing with really large external variables (multi-dimensional matrices), this results in a severe RAM overload.

Do you see a way around this? Would it be easy to add the extra_args?

JohannesBuchner commented 5 months ago

How about defining a pure functional likelihood, and a Python wrapper likelihood which calls it but injects the data arguments? That should not add extra RAM, if I followed your explanation correctly.
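One way to realize this suggestion is sketched below, again assuming a Gaussian likelihood; `pure_loglike`, `data` and `sigma` are hypothetical names. The pure function takes the data as an explicit argument, and `functools.partial` acts as the thin Python wrapper that injects it, so the sampler still sees a function of `theta` only. With JAX one would presumably apply `jax.jit` to `pure_loglike` itself, so the data reaches the compiled function as an argument instead of being baked in as a constant.

```python
import functools

import numpy as np


def pure_loglike(theta, data, sigma=1.0):
    """Pure function: every input, including the data, is an explicit argument."""
    resid = (data - theta[0]) / sigma
    return -0.5 * np.sum(resid ** 2)


data = np.array([1.0, 2.0, 3.0])

# Thin Python wrapper that injects the data. partial stores a reference
# to the same array object; no copy of `data` is made here.
loglike = functools.partial(pure_loglike, data=data)

print(loglike(np.array([2.0])))  # -1.0
```

A small class with a `__call__` method does the same job; `partial` is simply the shortest standard-library way to bind the extra argument.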

LucaMantani commented 5 months ago

Ah, this is a nice idea! I tried it in a snippet along these lines:

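import ultranest

# pure_loglike(x, A) is the pure (e.g. jax.jit-compiled) likelihood;
# A, parameters and bayesian_prior are defined elsewhere.
class WrapLoglike:
    def __init__(self, A):
        # keep a reference to the data; nothing is copied here
        self.A = A

    def __call__(self, x):
        # inject the data as an explicit argument of the pure function
        return pure_loglike(x, self.A)

log_likelihood = WrapLoglike(A)

sampler = ultranest.ReactiveNestedSampler(
    parameters,
    log_likelihood,
    bayesian_prior,
)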

and it does indeed seem to work! Thanks!!

JohannesBuchner commented 5 months ago

Good!