JohannesBuchner / UltraNest

Fit and compare complex models reliably and rapidly. Advanced nested sampling.
https://johannesbuchner.github.io/UltraNest/

Memory leak when using UltraNest with tinygp #84

Closed · zairving closed this issue 1 year ago

zairving commented 1 year ago

Description

I'm trying to use UltraNest to perform model selection with tinygp models, but when I run my script it slowly consumes all the available memory on my machine (unless sampling converges before the system runs out of memory). I've found two stopgap solutions, but both are flawed:

1) JIT-compiling the likelihood function (i.e., with a @jax.jit decorator; see the first sketch after this list) seems to solve the issue, but it restricts me to jittable likelihood functions (which I can't use in all cases).

2) I've written a bash script which runs my script and checks how much memory it's occupying every 30 seconds; if the memory usage gets too high, it kills the process and restarts the script, resuming from log_dir (a Python equivalent is sketched after this list). This is also problematic because it requires the points in log_dir to be updated before the process is killed, which is not guaranteed when the sampling is slow and/or the efficiency is low.
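For reference, stopgap (1) amounts to something like this sketch, reusing the build_gp and y defined in the script below; the float() cast is my own addition, so that UltraNest receives a plain scalar rather than a JAX array:

import jax

# JIT-compile the raw likelihood once, then hand UltraNest a thin wrapper
loglike_jitted = jax.jit(lambda theta: build_gp(theta).log_probability(y))

def loglike(params):
    return float(loglike_jitted(params))

And here is a rough Python equivalent of the watchdog in (2), assuming psutil is installed, "run_sampler.py" stands in for the script below, and the sampler is created with resume="resume" so restarts pick up from log_dir:

import subprocess
import time

import psutil  # third-party process utility, assumed installed

MEM_LIMIT = 8 * 1024**3  # hypothetical cap: 8 GiB

while True:
    proc = subprocess.Popen(["python", "run_sampler.py"])
    killed = False
    while proc.poll() is None:  # sampler still running
        time.sleep(30)
        if psutil.Process(proc.pid).memory_info().rss > MEM_LIMIT:
            proc.kill()  # memory too high: kill and restart from log_dir
            killed = True
            break
    if not killed:
        break  # sampler exited on its own (hopefully converged)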

I've used both tinygp and UltraNest independently without issue, but together something seems to be going wrong.

What I Did

Here's an example script I adapted from a tutorial in tinygp's documentation which illustrates my problem:

import numpy as np
import matplotlib.pyplot as plt
import tinygp
import jax
from ultranest import ReactiveNestedSampler

jax.config.update("jax_enable_x64", True)  # enable double precision in JAX
jax.config.update("jax_platform_name", "cpu")  # force JAX to run on CPU

def kernel(theta):
    return (theta[0]**2 * tinygp.kernels.ExpSquared(theta[1]) * tinygp.kernels.Cosine(theta[2])
            + theta[3]**2 * tinygp.kernels.Matern32(theta[4]))

def build_gp(theta):
    return tinygp.GaussianProcess(kernel(theta), x, diag=yerr**2 + 1e-8, mean=theta[5])

def loglike(params):
    return build_gp(params).log_probability(y)

def transform(cube):

    priors = cube.copy()

    # uniform priors for all parameters

    lo, hi = 0, max(y) - min(y)
    priors[0] = cube[0]*(hi - lo) + lo  # ExpSquared amplitude
    priors[3] = cube[3]*(hi - lo) + lo  # Matern-3/2 amplitude

    lo, hi = 1e-1, 1e1
    priors[1] = cube[1]*(hi - lo) + lo  # ExpSquared length-scale
    priors[2] = cube[2]*(hi - lo) + lo  # cosine period

    lo, hi = 1, 100
    priors[4] = cube[4]*(hi - lo) + lo  # Matern-3/2 length-scale

    lo, hi = min(y), max(y)
    priors[5] = cube[5]*(hi - lo) + lo  # mean

    return priors

# generate some data
random = np.random.default_rng(42)
x = np.sort(
    np.append(
        random.uniform(0, 3.8, 28),
        random.uniform(5.5, 10, 18),
    )
)
yerr = random.uniform(0.08, 0.22, len(x))
y = (
    0.2 * (x - 5)
    + np.sin(3 * x + 0.1 * (x - 5) ** 2)
    + yerr * random.normal(size=len(x))
)
true_x = np.linspace(0, 10, 100)
true_y = 0.2 * (true_x - 5) + np.sin(3 * true_x + 0.1 * (true_x - 5) ** 2)

# create figure
fig, ax = plt.subplots(tight_layout=True, dpi=200)

# plot data
ax.errorbar(x, y, yerr, marker="x", capsize=2, color="black", linestyle="None", label="data")
ax.plot(true_x, true_y, "k-", alpha=.5, label="true")

# define sampler
sampler = ReactiveNestedSampler(param_names=["ExpSquared amp", "ExpSquared length-scale", "period",
                                             "Matern-3/2 amplitude", "Matern-3/2 length-scale", "mean"],
                                loglike=loglike,
                                transform=transform,
                                log_dir="/change/me/",
                                resume="overwrite")

# run sampler
results = sampler.run()
sampler.print_results()

# get medians of posterior distributions
theta = results["posterior"]["median"]

# create GP
gp = build_gp(theta)

# condition GP
cond_gp = gp.condition(y, true_x).gp
mu, std = cond_gp.mean, np.sqrt(cond_gp.variance)

# plot GP predictive mean and 2 sigma confidence interval
ax.plot(true_x, mu, "r-", label="GP prediction")
ax.fill_between(true_x, mu+2*std, mu-2*std, color="lightgrey", alpha=.5, label="$2 \\sigma$ confidence")

# label plot
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend()

plt.show()
JohannesBuchner commented 1 year ago

Can you try sampling 100000 points from the prior, and evaluating their likelihood?

import numpy as np  # if running outside the script above

ndim = 6  # number of model parameters
for i in range(1000):  # 1000 batches of 100 = 100000 points
    p = [transform(np.random.uniform(size=ndim)) for j in range(100)]
    L = [loglike(pi) for pi in p]
    print(max(L))

see https://johannesbuchner.github.io/UltraNest/debugging.html#Finding-model-bugs

zairving commented 1 year ago

Hi Johannes,

Thanks for the quick and helpful reply!

I did as you suggested, and it seems the problem lies with tinygp: without jitting loglike I got the same runaway memory usage, but jitting seems to plug the leak. I'll open an issue on the tinygp repository instead.