joshspeagle / dynesty

Dynamic Nested Sampling package for computing Bayesian posteriors and evidences
https://dynesty.readthedocs.io/
MIT License
347 stars 76 forks source link

Special multiprocessing pool to avoid pickle overhead #393

Closed segasai closed 1 year ago

segasai commented 1 year ago

When using dynesty with multiprocessing and a fast likelihood function that carries a lot of data, it is easy to be hit by significant pickling overhead, which defeats the purpose of multiprocessing. This is easy to work around with a simple wrapper that avoids sending the args/kwargs and functions to the worker processes over and over.

In the toy example below, this speeds things up by a factor of 100 or so (try `caching=True`/`caching=False`).

import dynesty
import numpy as np
from numpy import linalg
import multiprocessing as mp
import xpool
import time

class Model:
    """Toy 2-D standard-normal log-likelihood that carries a large array.

    ``arr`` is never read by the likelihood or the prior transform; it
    appears to exist only to make the instance large (and therefore
    expensive to send to worker processes).

    Parameters
    ----------
    s : tuple of int
        Shape of the zero-filled payload array ``arr``.
    """

    def __init__(self, s=(1000, 1000)):
        ndim = 2
        cov = np.identity(ndim)
        self.ndim = ndim
        self.arr = np.zeros(s)
        self.C2 = cov
        self.Cinv2 = linalg.inv(cov)
        # Log normalisation constant of a multivariate normal:
        # -0.5 * (ndim * log(2*pi) + log|C|)
        log_2pi_term = np.log(2 * np.pi) * ndim
        log_det_term = np.log(linalg.det(cov))
        self.lnorm2 = -0.5 * (log_2pi_term + log_det_term)

    def __call__(self, x):
        """Return the multivariate normal log-likelihood evaluated at ``x``."""
        # Artificial delay standing in for a (cheap) likelihood computation.
        time.sleep(0.0001)
        projected = np.dot(self.Cinv2, x)
        quad_form = np.dot(x, projected)
        return -0.5 * quad_form + self.lnorm2

    def prior_transform(self, u):
        """Map unit-cube samples ``u`` linearly onto the box [-10, 10]."""
        centred = 2. * u - 1.
        return 10. * centred

if __name__ == '__main__':
    # Demo comparing the caching pool (xpool.Pool) with a plain
    # multiprocessing.Pool; flip ``caching`` to compare run times.
    maxiter = 2000
    M = Model()
    caching = True
    nthreads = 12
    prior_transform = M.prior_transform
    like = M
    rstate = np.random.default_rng(1)
    if caching:
        # The likelihood and prior transform are installed in each worker
        # once via the pool initializer; pool.like / pool.prior_transform
        # are thin module-level wrappers that look them up there.
        with xpool.Pool(nthreads, like, prior_transform) as pool:
            dsampler2 = dynesty.DynamicNestedSampler(pool.like,
                                                     pool.prior_transform,
                                                     nlive=50,
                                                     ndim=M.ndim,
                                                     bound='single',
                                                     sample='rslice',
                                                     rstate=rstate,
                                                     pool=pool)
            dsampler2.run_nested(maxiter=maxiter, use_stop=False)
    else:
        # Plain pool: ``like`` (which holds the large M.arr) has to be sent
        # to the workers by dynesty itself -- the overhead being demonstrated.
        if nthreads > 1:
            pool = mp.Pool(nthreads)
        else:
            pool = None
            nthreads = None
        try:
            dsampler2 = dynesty.DynamicNestedSampler(like,
                                                     prior_transform,
                                                     nlive=50,
                                                     ndim=M.ndim,
                                                     bound='single',
                                                     sample='rslice',
                                                     rstate=rstate,
                                                     pool=pool,
                                                     queue_size=nthreads)

            dsampler2.run_nested(maxiter=maxiter, use_stop=False)
        finally:
            # Always shut the worker processes down, even if sampling fails,
            # so they are not leaked.
            if pool is not None:
                pool.close()
                pool.join()

Here's the custom Pool wrapper (the `xpool` module imported by the example above):

import multiprocessing as mp

class FunctionCache:
    """Per-process holder for the objects installed by ``initializer``.

    Every attribute starts out as ``None``; ``initializer`` overwrites them
    when a worker process of the pool starts.
    """

    # User callables: log-likelihood and prior transform.
    like = prior_transform = None
    # Extra positional / keyword arguments for the log-likelihood.
    logl_args = logl_kwargs = None
    # Extra positional / keyword arguments for the prior transform.
    ptform_args = ptform_kwargs = None

def initializer(like, prior_transform, logl_args, logl_kwargs, ptform_args,
                ptform_kwargs):
    """Store the user callables and their extra arguments on FunctionCache.

    Used as the ``multiprocessing.Pool`` initializer, so the objects are
    transferred to each worker once rather than with every task.
    """
    cached = {
        'like': like,
        'prior_transform': prior_transform,
        'logl_args': logl_args,
        'logl_kwargs': logl_kwargs,
        'ptform_args': ptform_args,
        'ptform_kwargs': ptform_kwargs,
    }
    for attr_name, value in cached.items():
        setattr(FunctionCache, attr_name, value)

def like_cache(x):
    """Evaluate the cached log-likelihood at ``x`` with its cached arguments."""
    cache = FunctionCache
    extra_args = cache.logl_args
    extra_kwargs = cache.logl_kwargs
    return cache.like(x, *extra_args, **extra_kwargs)

def prior_transform_cache(x):
    """Apply the cached prior transform to ``x`` with its cached arguments."""
    cache = FunctionCache
    extra_args = cache.ptform_args
    extra_kwargs = cache.ptform_kwargs
    return cache.prior_transform(x, *extra_args, **extra_kwargs)

class Pool:
    """Multiprocessing pool that installs the likelihood and prior in workers once.

    The likelihood, prior transform and their extra arguments are handed to
    each worker process through the pool initializer instead of being sent
    along with every task. ``self.like`` and ``self.prior_transform`` are
    module-level wrapper functions (picklable by reference) that look the
    cached objects up inside the worker.

    Must be used as a context manager: the underlying
    ``multiprocessing.Pool`` is created in ``__enter__`` and shut down in
    ``__exit__``.

    Parameters
    ----------
    njobs : int
        Number of worker processes (also reported by ``size``).
    like : callable
        Log-likelihood, called as ``like(x, *logl_args, **logl_kwargs)``.
    prior_transform : callable
        Prior transform, called as
        ``prior_transform(x, *ptform_args, **ptform_kwargs)``.
    logl_args, logl_kwargs, ptform_args, ptform_kwargs : optional
        Extra positional / keyword arguments; ``None`` means none.
    """

    def __init__(self,
                 njobs,
                 like,
                 prior_transform,
                 logl_args=None,
                 logl_kwargs=None,
                 ptform_args=None,
                 ptform_kwargs=None):
        self.logl_args = logl_args
        self.logl_kwargs = logl_kwargs
        self.ptform_args = ptform_args
        self.ptform_kwargs = ptform_kwargs
        self.njobs = njobs
        # Original user callables; shipped to the workers in __enter__.
        self.like_0 = like
        self.prior_transform_0 = prior_transform
        # Cheap wrappers handed to the sampler in place of the originals.
        self.like = like_cache
        self.prior_transform = prior_transform_cache
        # Created lazily in __enter__.
        self.pool = None

    def __enter__(self):
        initargs = (self.like_0,
                    self.prior_transform_0,
                    self.logl_args or (),
                    self.logl_kwargs or {},
                    self.ptform_args or (),
                    self.ptform_kwargs or {})
        self.pool = mp.Pool(self.njobs, initializer, initargs)
        return self

    def map(self, F, x):
        """Apply ``F`` to every element of ``x`` in the worker processes."""
        pool = getattr(self, 'pool', None)
        if pool is None:
            raise RuntimeError(
                'Pool is not running; use it inside a "with" statement')
        return pool.map(F, x)

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Shut down the worker pool; exceptions from the body propagate."""
        pool = getattr(self, 'pool', None)
        if pool is None:
            return
        self.pool = None
        try:
            if exc_type is None:
                # Normal exit: let workers finish, then stop them.
                pool.close()
            else:
                # Error exit: don't wait for queued tasks that may never
                # be collected -- kill the workers right away.
                pool.terminate()
        except ValueError:
            # close() raises ValueError if the pool is no longer running.
            pass
        pool.join()

    @property
    def size(self):
        """Number of worker processes."""
        return self.njobs

I am wondering whether it is worth including this as `dynesty.pool` or something similar.

joshspeagle commented 1 year ago

I think adding a `dynesty.pool` like this, which most general users could benefit from, would be great! I'd be in favour of adding it.