LucaMantani closed this issue 5 months ago.
How about defining a pure functional likelihood, and a Python wrapper likelihood which calls it but injects the data arguments? That should not add extra RAM, if I followed your explanation correctly.
Ah, this is a nice idea! I tried in a snippet by doing something like:
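```python
import ultranest

# pure_loglike(x, A), the data A, parameters and bayesian_prior are defined elsewhere.
class WrapLogLike:
    """Inject the data A into the pure likelihood (stores a reference, no copy)."""
    def __init__(self, A):
        self.A = A

    def __call__(self, x):
        return pure_loglike(x, self.A)  # data passed as an argument, not captured

log_likelihood = WrapLogLike(A)
sampler = ultranest.ReactiveNestedSampler(
    parameters,
    log_likelihood,
    bayesian_prior,
)
```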
and it does seem to have worked! Thanks!!
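A self-contained sketch of the same wrapper pattern with the pure likelihood compiled by `jax.jit`. The model `-0.5 * ||A x||^2` is a hypothetical stand-in, since the actual likelihood from the thread is not shown. Because `A` reaches the jitted function as an argument, JAX treats it as a runtime input rather than folding it into the compiled computation; the wrapper itself only keeps a reference to the array. The `float()` conversion is a convenience assumption so the caller receives a plain Python number.

```python
import jax
import jax.numpy as jnp


@jax.jit
def pure_loglike(x, A):
    """Pure log-likelihood: the data A is an explicit argument (hypothetical model)."""
    return -0.5 * jnp.sum((A @ x) ** 2)


class WrapLogLike:
    """Callable that injects the data into the pure likelihood."""

    def __init__(self, A):
        self.A = A  # a reference to the data, not a copy

    def __call__(self, x):
        # A is passed as an argument of the jitted function, so jit keeps it
        # as an input instead of capturing it from an enclosing scope.
        return float(pure_loglike(jnp.asarray(x), self.A))


A = jnp.arange(6.0).reshape(3, 2)  # toy data: [[0, 1], [2, 3], [4, 5]]
log_likelihood = WrapLogLike(A)
result = log_likelihood([1.0, -1.0])
print(result)  # -1.5
```

Note that it is the pure function that is jitted. As far as we understand jit tracing, wrapping the whole `WrapLogLike` instance in `jax.jit` would read `self.A` during tracing and bring back the closure behavior.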
Good!
I am using ultranest with JAX, which thrives in a pure functional programming paradigm.
Related to issue #51, ultranest does not currently accept extra arguments for the log-likelihood. This means that the log-likelihood cannot be defined as a pure function: for example, it cannot receive the data as an input and must instead rely on the data being defined in an outer scope.
In general this is not a problem: one can simply define a function that returns a function (a closure), as outlined in https://github.com/JohannesBuchner/UltraNest/issues/51#issuecomment-1103661731.
The issue appears when using jax.jit: because these external objects are not inputs, they get compiled into the function, i.e. copied in memory, which doubles the memory consumption. When dealing with really large external variables (multi-dimensional matrices), this results in a severe RAM overload.
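A minimal sketch of the closure pattern from issue #51 combined with `jax.jit`, using the same kind of hypothetical model `-0.5 * ||A x||^2`; the factory name `make_loglike` is illustrative and not taken from the thread. Because `A` is read from the enclosing scope rather than passed in, tracing sees it as a concrete array and JAX treats it as a constant of the compiled function, which is the mechanism behind the memory duplication reported above. How much extra memory this actually costs may depend on the JAX version and backend.

```python
import jax
import jax.numpy as jnp


def make_loglike(A):
    """Closure-based log-likelihood: a function that returns a function."""

    @jax.jit
    def loglike(x):
        # A is not an argument of `loglike`: during tracing it is a concrete
        # array from the enclosing scope, so jit captures it as a constant.
        return -0.5 * jnp.sum((A @ x) ** 2)

    return loglike


A = jnp.arange(6.0).reshape(3, 2)  # toy data: [[0, 1], [2, 3], [4, 5]]
loglike = make_loglike(A)
value = float(loglike(jnp.array([1.0, -1.0])))
print(value)  # -1.5
```

The numerical result is the same as with a pure function taking `A` as an argument; only the way the data enters the compiled computation differs.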
Do you see a way around this? Would it be easy to add support for extra_args?