nespinoza / juliet

A versatile modelling tool for transiting and non-transiting (single and multiple) exoplanetary systems
MIT License

Multiprocessing error with dynesty #105

Closed: LucaNap closed this issue 7 months ago

LucaNap commented 7 months ago

Hi. I have installed the new version of juliet (2.2.3) together with dynesty (2.1.3). However, I noticed that I can no longer use multiple threads with dynesty because of a recursion-depth error:

File ~\.conda\envs\julietto2\lib\site-packages\juliet\fit.py:1729, in fit.__init__(self, data, sampler, n_live_points, nwalkers, nsteps, nburnin, emcee_factor, ecclim, pl, pu, ta, nthreads, use_ultranest, use_dynesty, dynamic, dynesty_bound, dynesty_sample, dynesty_nthreads, dynesty_n_effective, dynesty_use_stop, dynesty_use_pool, **kwargs)
   1726 # Now run all with multiprocessing:
   1727 with contextlib.closing(Pool(processes=self.nthreads -
   1728                              1)) as executor:
-> 1729     sampler = DynestySampler(self.loglike,
   1730                              self.prior_transform_r,
   1731                              self.data.nparams,
   1732                              pool=executor,
   1733                              queue_size=self.nthreads,
   1734                              **d_args)
   1735     sampler.run_nested(**ds_args)
   1736     results = sampler.results

File ~\.conda\envs\julietto2\lib\site-packages\dynesty\dynesty.py:677, in NestedSampler.__new__(cls, loglikelihood, prior_transform, ndim, nlive, bound, sample, periodic, reflective, update_interval, first_update, npdim, rstate, queue_size, pool, use_pool, live_points, logl_args, logl_kwargs, ptform_args, ptform_kwargs, gradient, grad_args, grad_kwargs, compute_jac, enlarge, bootstrap, walks, facc, slices, fmove, max_move, update_func, ncdim, blob, save_history, history_filename)
    674     kwargs['grad'] = grad
    675     kwargs['compute_jac'] = compute_jac
--> 677 live_points, logvol_init, init_ncalls = _initialize_live_points(
    678     live_points,
    679     ptform,
    680     loglike,
    681     M,
    682     nlive=nlive,
    683     npdim=npdim,
    684     rstate=rstate,
    685     blob=blob,
    686     use_pool_ptform=use_pool.get('prior_transform', True))
    688 # Initialize our nested sampler.
    689 sampler = super().__new__(_SAMPLERS[bound])

File ~\.conda\envs\julietto2\lib\site-packages\dynesty\dynamicsampler.py:438, in _initialize_live_points(live_points, prior_transform, loglikelihood, M, nlive, npdim, rstate, blob, use_pool_ptform)
    436     cur_live_v = map(prior_transform, np.asarray(cur_live_u))
    437 cur_live_v = np.array(list(cur_live_v))
--> 438 cur_live_logl = loglikelihood.map(np.asarray(cur_live_v))
    439 if blob:
    440     cur_live_blobs = np.array([_.blob for _ in cur_live_logl])

File ~\.conda\envs\julietto2\lib\site-packages\dynesty\utils.py:177, in LogLikelihood.map(self, pars)
    171     ret = list([
    172         LoglOutput(_, self.blob) for _ in map(self.loglikelihood, pars)
    173     ])
    174 else:
    175     ret = [
    176         LoglOutput(_, self.blob)
--> 177         for _ in self.pool.map(self.loglikelihood, pars)
    178     ]
    179 if self.save:
    180     self.history_append([_.val for _ in ret], pars)

File ~\.conda\envs\julietto2\lib\multiprocessing\pool.py:364, in Pool.map(self, func, iterable, chunksize)
    359 def map(self, func, iterable, chunksize=None):
    360     '''
    361     Apply `func` to each element in `iterable`, collecting the results
    362     in a list that is returned.
    363     '''
--> 364     return self._map_async(func, iterable, mapstar, chunksize).get()

File ~\.conda\envs\julietto2\lib\multiprocessing\pool.py:771, in ApplyResult.get(self, timeout)
    769     return self._value
    770 else:
--> 771     raise self._value

RecursionError: maximum recursion depth exceeded while calling a Python object

nespinoza commented 7 months ago

Hi @LucaNap,

Can you provide a minimal working example for this problem? I've been running fits with the new version using multi-threading with dynesty and have seen no issues so far. This is with the latest version of juliet and dynesty 2.1.3.

N.

LucaNap commented 7 months ago

Hey @nespinoza. Thanks for the quick response.

For consistency, I used the tutorial code ("Joint transit and radial-velocity fits"). With that code, setting nthreads > 1 does indeed work. However, the error reappears if I use lightkurve to load the light curve:

import lightkurve
# times, fluxes and fluxes_error are the input dictionaries from the tutorial
lc = lightkurve.search_lightcurve('TOI 141')[0].download().normalize()
lc = lc[lc.flux > 0]
times['TESS'], fluxes['TESS'], fluxes_error['TESS'] = lc.time.value + 2457000, lc.flux.value, lc.flux_err.value

EDIT: The error goes away if I wrap both lc.flux.value and lc.flux_err.value in np.array()! I suspect it is related to the "new" MaskedNDArray flux arrays that lightkurve now returns.
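A minimal sketch of that workaround, assuming the goal is to hand juliet plain float arrays. One caveat: `np.array()` (with its default `subok=False`) strips ndarray subclasses such as masked arrays, but it also silently discards their mask, so masked points are dropped first. The helper `to_plain_arrays` is hypothetical (not part of juliet or lightkurve), and the demo uses numpy's own `MaskedArray` as a stand-in for astropy's `MaskedNDArray`, since both are ndarray subclasses.

```python
import numpy as np


def to_plain_arrays(time, flux, flux_err):
    """Hypothetical helper: return time, flux and flux_err as plain float64
    numpy.ndarray objects, dropping masked or non-finite points.

    np.array() strips ndarray subclasses (e.g. masked arrays) but discards
    their mask, so the mask is collected before converting.
    """
    bad = np.zeros(len(flux), dtype=bool)
    for arr in (flux, flux_err):
        m = getattr(arr, "mask", None)  # plain ndarrays have no .mask
        if m is not None:
            bad |= np.asarray(m, dtype=bool)  # a scalar False broadcasts
    t = np.array(time, dtype=float)
    f = np.array(flux, dtype=float)
    e = np.array(flux_err, dtype=float)
    keep = ~bad & np.isfinite(t) & np.isfinite(f) & np.isfinite(e)
    return t[keep], f[keep], e[keep]


# numpy's MaskedArray stands in for astropy's MaskedNDArray here.
flux = np.ma.array([1.0, 0.99, 1.01], mask=[False, True, False])
flux_err = np.ma.array([0.01, 0.01, 0.01])
t, f, e = to_plain_arrays(np.array([1.0, 2.0, 3.0]), flux, flux_err)
print(type(f).__name__, t, f)
```

With lightkurve, the same call would take `lc.time.value + 2457000`, `lc.flux.value` and `lc.flux_err.value`, and its three outputs would fill `times['TESS']`, `fluxes['TESS']` and `fluxes_error['TESS']`.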

P.S. This is on a Windows 11 machine with Python 3.8 in a conda environment that contains only juliet and lightkurve (latest versions).

nespinoza commented 7 months ago

Interesting! Perhaps lightkurve is not closing some pools on Windows? In any case, since this is not directly a juliet issue (at least for now), I will close this issue; but if you or anyone else finds that it is a juliet issue, please feel free to bring it back to my attention!
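For anyone hitting a similar error: `multiprocessing.Pool` sends the mapped function and its inputs to worker processes by pickling them, and the traceback shows juliet passing the bound method `self.loglike`, so data attached to the fit object travels along with it. Whether the RecursionError above actually arises at that step is not confirmed in this thread. The sketch below is a hypothetical pre-flight check (`check_picklable` is not part of juliet or dynesty) that round-trips objects through pickle in the main process before a multi-threaded fit is launched.

```python
import pickle

import numpy as np


def check_picklable(**objects):
    """Hypothetical helper: round-trip each object through pickle, as
    multiprocessing.Pool does with tasks sent to worker processes.

    Returns a list of (name, exception type name) for objects that fail.
    """
    failures = []
    for name, obj in objects.items():
        try:
            pickle.loads(pickle.dumps(obj))
        except Exception as exc:  # includes RecursionError and PicklingError
            failures.append((name, type(exc).__name__))
    return failures


# Plain float arrays round-trip without problems.
print(check_picklable(times=np.linspace(0.0, 1.0, 5), fluxes=np.ones(5)))
# prints []
```

An empty list only means the objects pickle in the main process; it does not prove that a given multiprocessing fit will succeed.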

N.