joshspeagle / dynesty

Dynamic Nested Sampling package for computing Bayesian posteriors and evidences
https://dynesty.readthedocs.io/
MIT License

first attempt to be able to deal with plateaus #414

Closed segasai closed 1 year ago

segasai commented 1 year ago

The issue raised in #412 is at least partially related to how we deal with likelihood plateaus and/or points that are present several times (like when rwalk proposals are not accepted).

This PR is a first attempt to deal with this. The idea here is that, in the static sampling case say, if we ever hit a point where all the logl values are the same, the right thing to do is to stop iterating, execute add_live() on the leftover live points (assuming they are uniformly distributed within the remaining volume), and issue a warning.

I.e. for this test case of a likelihood of a uniform ball inside a cube, the current dynesty will either run in an infinite loop or throw an error, while with this PR you get a warning ('We have reached the plateau in the likelihood we are stopping sampling') and a correct logz value. In fact, even dynamic sampling works correctly in this case (despite emitting the warning many times):

import numpy as np
import dynesty
import scipy.special

nlive = 100

S = 3
R = 1

ndim = 2
A0 = 1
A1 = 10

# likelihood that has value A1 inside a sphere of radius R (an exact plateau)
# and value ~A0 outside it (with a tiny radial slope to avoid a plateau there)
def loglike_inf(x):
    r = np.sqrt(np.sum(x**2))
    if r < R:
        ret = np.log(A1)  #- 1e-6 * r
    else:
        ret = np.log(A0) - 1e-6 * r
    # print(ret, r)
    return ret

# true value of the integral:
# Z = A0 + (A1 - A0) * Vball / Vcube, where Vball = pi^(d/2) R^d / Gamma(d/2 + 1)
# and Vcube = (2 S)^d
LOGZ_TRUE = np.log(A0 + np.pi**(ndim / 2.) /
                   scipy.special.gamma(ndim / 2. + 1) * R**ndim * (A1 - A0) /
                   ((2 * S)**ndim))

def prior_transform(x):
    return (2 * x - 1) * S

def testx():
    rstate = np.random.default_rng(1)
    sample = 'rslice'
    sampler = dynesty.NestedSampler(loglike_inf,
                                    prior_transform,
                                    ndim,
                                    nlive=nlive,
                                    rstate=rstate,
                                    bound='none',
                                    sample=sample)
    sampler.run_nested(print_progress=True)
    res = sampler.results
    print(res.logz[-1], res.logzerr[-1])
    print('True logz', LOGZ_TRUE)
    return res

if __name__ == '__main__':
    res = testx()

This particular example deals with the case where we hit a plateau at the peak of the likelihood. If I change the likelihood to have a plateau at the beginning of the sampling, I get the wrong logz, for the same reason as #412. But it looks like the solution is easy: we just need to use the right delta log-volume in this case, similar to what add_live() does. I.e. if we have just started sampling and n points out of N live points have the same (very low) likelihood, the right volume associated with each of those points is 1/N.
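
To make the difference concrete, here is a small numeric illustration (my own sketch, not code from the patch) of the two shrinkage rules applied to an initial plateau:

import numpy as np

# N live points, n of which start on the same (lowest) likelihood value
N, n = 100, 30

# standard exponential compression after removing n dead points
V_exp = np.exp(-n / N)  # ~0.741

# linear compression, each plateau point carrying a volume share of 1/N
V_lin = 1 - n / N  # 0.700

print(V_exp, V_lin)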

segasai commented 1 year ago

I believe the current patches correctly implement the logic for the static sampler and are able to correctly estimate the evidence for this case of a uniform ball inside a cube:

import numpy as np
import dynesty
import scipy.special

nlive = 1000

S = 3
R = 1

ndim = 2
A0 = 1
A1 = 10

# likelihood that has value A1 inside a sphere of radius R
# and value A0 outside it (now an exact plateau on both sides)
def loglike_inf(x):
    r = np.sqrt(np.sum(x**2))
    if r < R:
        ret = np.log(A1)  # - 1e-6 * r
    else:
        ret = np.log(A0)  # - 1e-6 * r
    # print(ret, r)
    return ret

# true value of the integral
LOGZ_TRUE = np.log(A0 + np.pi**(ndim / 2.) /
                   scipy.special.gamma(ndim / 2. + 1) * R**ndim * (A1 - A0) /
                   ((2 * S)**ndim))

def prior_transform(x):
    return (2 * x - 1) * S

def testx():
    rstate = np.random.default_rng(1)
    sample = 'rslice'
    sampler = dynesty.NestedSampler(loglike_inf,
                                    prior_transform,
                                    ndim,
                                    nlive=nlive,
                                    rstate=rstate,
                                    bound='none',
                                    sample=sample)
    sampler.run_nested(print_progress=True)
    res = sampler.results
    print(res.logz[-1], res.logzerr[-1])
    print('True logz', LOGZ_TRUE)
    return res

if __name__ == '__main__':
    res = testx()

We'll see whether all the tests pass.

If I just use the DynamicNestedSampler I get a wrong result; this is presumably caused by incorrect merging (_merge_two), which needs to learn how to deal with points with the same logl values.

coveralls commented 1 year ago

Pull Request Test Coverage Report for Build 3782157321


Changes Missing Coverage:

| File | Covered Lines | Changed/Added Lines | % |
| --- | --- | --- | --- |
| py/dynesty/dynamicsampler.py | 71 | 74 | 95.95% |
| Total | 144 | 147 | 97.96% |

Totals Coverage Status:
- Change from base Build 3596723915: 0.2%
- Covered Lines: 4150
- Relevant Lines: 4516

💛 - Coveralls
joshspeagle commented 1 year ago

Just leaving this here as a reference since it's related to this issue: https://arxiv.org/abs/2010.13884

segasai commented 1 year ago

Yes, I've seen this paper in the past. I didn't find the way they talk about it very helpful (for me at least). The way I treated this is purely as fixing the map between point i and the volume V_i associated with it; no other changes. In fact, we don't need to talk about likelihoods, weights, etc. All of that stays the same.

And the fix for the volumes is to just switch to linear volume decrease throughout the plateau and then return to normal.
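
In code, that mapping could look roughly like the sketch below (a schematic of the idea with made-up names, not the actual patch):

import numpy as np

def assign_logvols(logls, nlive):
    # Schematic bookkeeping for a static run: logls is the increasing
    # sequence of dead-point logl values. Unique values get the usual
    # exponential compression; a run of m equal values (a plateau) gets
    # a linear volume decrease, each tied point carrying vol0 / nlive.
    # (Assumes plateaus smaller than nlive; real code must guard edge cases.)
    logls = np.asarray(logls)
    logvols = []
    vol = 1.0
    i = 0
    while i < len(logls):
        m = int((logls[i:] == logls[i]).sum())  # plateau size at this logl
        if m == 1:
            vol *= nlive / (nlive + 1.0)  # expected exponential shrinkage
            logvols.append(np.log(vol))
        else:
            vol0 = vol
            for k in range(1, m + 1):  # linear decrease across the plateau
                vol = vol0 * (1.0 - k / nlive)
                logvols.append(np.log(vol))
        i += m
    return np.array(logvols)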

segasai commented 1 year ago

One thing I have just checked that is unfortunate, and maybe a flaw in the approach, is the behaviour when some points are repeated, i.e. when doing random walks, if we don't accept the proposal we just add the starting point again.
I never quite worked out what should happen with the volumes in this case, but I've just done a toy model (below) of sampling an n-d ball, where occasionally, instead of proposing a new point, I choose one of the existing points at random (thus creating a duplicate mimicking an unsuccessful random walk). It seems that in that case the volume shrinkage per point stays the same, while with the current patch it will be treated as a plateau and will get a linear shrinkage.

import numpy as np
import scipy.special

def logvol(r, ndim):
    # log of the volume of an ndim-dimensional ball of radius r
    return ndim / 2. * np.log(
        np.pi) + ndim * np.log(r) - scipy.special.gammaln(ndim / 2. + 1)

def sample_ball(n, ndim, R, rstate):
    # draw radii by inverse-CDF sampling: P(radius < r) = (r/R)^ndim
    ys = rstate.uniform(size=n)
    rs = ys**(1. / ndim) * R
    # isotropic directions from normalized Gaussian draws
    xs = rstate.normal(size=(n, ndim))
    xs = xs / np.sqrt((xs**2).sum(axis=1))[:, None]
    return xs * rs[:, None]

def sampler(nit, ndim, nlive, rstate, pbrepeat=0.2):
    # toy nested sampling of a uniform unit ball: at each iteration the
    # point with the largest radius is discarded and replaced
    pts0 = sample_ball(nlive, ndim, 1, rstate)
    res = []
    for i in range(nit):
        rad = np.sum(pts0**2, axis=1)**.5
        worst = np.argmax(rad)
        res.append(pts0[worst] * 1.)  # copy, since pts0 is modified in place
        if rstate.uniform() < pbrepeat:
            # repeat one of the surviving points at random, creating a
            # duplicate that mimics an unsuccessful random walk
            newx = pts0[rstate.choice(np.nonzero(rad < rad[worst])[0])]
        else:
            # use a genuinely new point from within the shrunken ball
            newx = sample_ball(1, ndim, rad[worst], rstate)[0]
        pts0[worst] = newx
    return np.array(res)
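
A hypothetical way to quantify this (the check below is mine, not part of the original comment): fit the per-iteration slope of the log-volume enclosed at the radius of each discarded point; if duplicates leave the compression rate unchanged, it should stay close to -1/nlive.

if __name__ == '__main__':
    rstate = np.random.default_rng(1)
    nlive, nit, ndim = 500, 3000, 2
    removed = sampler(nit, ndim, nlive, rstate, pbrepeat=0.2)
    rads = np.sqrt((removed**2).sum(axis=1))
    # slope of log-volume vs iteration; standard nested sampling gives -1/nlive
    slope = np.polyfit(np.arange(nit), logvol(rads, ndim), 1)[0]
    print('measured slope:', slope, 'expected:', -1 / nlive)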

I don't quite know how serious this is, because treating duplicates would require some additional tracking.

segasai commented 1 year ago

I've also tested the current patch on the wedding cake function from Fowlie's paper, and it works fine:

import numpy as np
import dynesty
import pytest
import scipy.special
from utils import get_rstate, get_printing  # helpers from dynesty's test suite

printing = get_printing()

sig = 0.2
alpha = .7

# Wedding cake likelihood from Fowlie et al. 2020 (arXiv:2010.13884):
# piecewise constant on nested hypercube shells, so it consists
# entirely of plateaus
def loglike_inf(x):
    D = len(x)
    # max-norm distance from the centre of the unit cube
    r = np.max(np.abs(x - 0.5))
    # index of the shell containing the point
    i = (D * np.log(2 * r) / np.log(alpha)).astype(int)
    logp = -(alpha**(2 * i / D)) / (8 * sig**2)
    return logp

# true value of the integral: shell i has prior volume alpha^i * (1 - alpha)
# and likelihood exp(-alpha^(2 i / ndim) / (8 sig^2)); the sum is truncated
# at 100 shells
ndim = 2
LOGZ_TRUE = scipy.special.logsumexp(-alpha**(2 * np.arange(100) / ndim) /
                                    (8 * sig**2) +
                                    np.arange(100) * np.log(alpha) +
                                    np.log((1 - alpha)))

def prior_transform(x):
    return x

# here we are trying to test different stages of the plateau,
# probing with different dlogz's
@pytest.mark.parametrize('sample,dlogz', [('unif', 1), ('rwalk', 1),
                                          ('rslice', 1), ('unif', .01),
                                          ('rwalk', .01), ('rslice', .01)])
def test_static(sample, dlogz):
    nlive = 1000
    rstate = get_rstate()
    sampler = dynesty.NestedSampler(loglike_inf,
                                    prior_transform,
                                    ndim,
                                    nlive=nlive,
                                    rstate=rstate,
                                    bound='none',
                                    sample=sample)
    sampler.run_nested(print_progress=printing, dlogz=dlogz)
    res = sampler.results
    THRESH = 3
    print(res.logz[-1], LOGZ_TRUE)
    assert np.abs(res.logz[-1] - LOGZ_TRUE) < THRESH * res.logzerr[-1]

@pytest.mark.parametrize('sample,', ['unif', 'rslice', 'rwalk'])
def test_dynamic(sample):
    rstate = get_rstate()
    nlive = 100
    sampler = dynesty.DynamicNestedSampler(loglike_inf,
                                           prior_transform,
                                           ndim,
                                           nlive=nlive,
                                           rstate=rstate,
                                           bound='none',
                                           sample=sample)
    sampler.run_nested(print_progress=printing)
    res = sampler.results
    THRESH = 3
    print(res.logz[-1], LOGZ_TRUE)
    assert np.abs(res.logz[-1] - LOGZ_TRUE) < THRESH * res.logzerr[-1]

if __name__ == '__main__':
    test_static('unif', 0.1)

@joshspeagle do you have some time to take a look at this? I think this is a big change (although the patch is actually very small), so I would like to have another opinion before pushing this.

joshspeagle commented 1 year ago

I'll take a look at this later today. I expect the default behaviour is probably fine, but I will need to see if any additional logic needs to be introduced into the random prior-volume sampling functions to match, if that hasn't been added yet.

segasai commented 1 year ago

Thanks! I agree that the initial sampling needs to be updated as well (specifically for #412), but that's a pretty small change compared to the rest.

joshspeagle commented 1 year ago

Okay, I think the proposed changes are good and theoretically motivated (i.e. switching from exponential to uniform compression as we randomly sample within the plateau). There might be some small hiccups (as mentioned above), but from my inspection the code changes look good to merge.

segasai commented 1 year ago

Thanks for taking a look. Given the go-ahead, I'll see if I can add an initial-sampling fix. I'm now realizing that may be a bit more tricky, as I think there are many bits of code that assume logvol=0 at the very beginning.
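
For example (my illustration, not code from the PR): if n of the N initial live points share the lowest likelihood, then after those points are assigned their linear volume shares, the remaining run should start from logvol0 = log(1 - n/N) rather than 0.

import numpy as np

N, n = 100, 30  # live points; size of the initial plateau
logvol0 = np.log(1 - n / N)  # ~ -0.357 instead of the usual 0.0
print(logvol0)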

joshspeagle commented 1 year ago

Yea...

I'm also not sure whether there need to be any logic changes in the utility functions for jittering/resampling/merging runs to capture the different behaviour in the plateau regions. The existing logic already has some checks to handle the expected dynamic sampling behaviour, but there might need to be a few more tags/checks to deal with the plateaus.

segasai commented 1 year ago

I have already patched the run merging, so I think it is fine with plateaus; I even added a test specifically for that. Maybe things like unravel need to be adjusted though. I was also wondering whether initialize_points could just become part of the regular sampler, since we start by sampling uniformly within a cube anyway; otherwise we have to compute logvol0 there and pass it when initializing the samplers.

segasai commented 1 year ago

The latest patches fix #412.