Closed: noahfranz13 closed this issue 6 months ago.
I don't have much current experience with shared-memory multiprocessing and emcee/prospector. I think it has been used with dynesty (tagging @mjpark-astro), but this might work because of the use of `partial` here: https://github.com/bd-j/prospector/blob/main/prospect/fitting/fitting.py#L525C16-L525C16
Have you looked at the MPI + dynesty + prospector demo at https://github.com/bd-j/prospector/blob/main/demo/demo_mpi_params.py ?
Anyway, I think you are probably right: if it is trying to pickle the sps object and communicate it, that will cause substantial slowdowns, since that is generally a very large in-memory object.
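For a rough sense of scale, here is a self-contained sketch of that overhead (the array below is only a stand-in for an sps object, which is typically far larger):

```python
import pickle
import time

import numpy as np

# Stand-in for a large sps-like object (~40 MB); real stellar population
# libraries are usually much bigger than this.
big = np.random.rand(5_000_000)

t0 = time.time()
for _ in range(200):  # pretend each iteration is one posterior evaluation
    pickle.dumps(big)
print("200 serializations took %.2f s" % (time.time() - t0))
```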
I have seen the tutorial on using MPI with dynesty and it was very useful! Sadly, for my case (probably a very niche use case for prospector) I do need to use emcee for the backend.
As far as I can tell, using `partial` to remove the kwargs from `lnprob` should work for emcee too. I can try to implement it and open a PR sometime next week if you'd like?
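For concreteness, here is a minimal, runnable sketch of the pattern I mean. All names below are placeholders: in prospector the wrapped function would be `prospect.fitting.lnprobfn` with `obs`, `model`, `sps`, and `noise` bound in via `partial`, rather than handing them to the sampler through `kwargs`.

```python
from functools import partial
from multiprocessing import Pool

import numpy as np
import emcee


def lnprob(theta, data=None):
    # Placeholder for lnprobfn(theta, obs=..., model=..., sps=..., noise=...)
    return -0.5 * np.sum((theta - data.mean()) ** 2)


if __name__ == "__main__":
    data = np.random.rand(1_000)               # stand-in for the fixed inputs
    lnprob_fixed = partial(lnprob, data=data)  # bind everything except theta

    ndim, nwalkers = 3, 16
    p0 = np.random.rand(nwalkers, ndim)

    with Pool() as pool:
        # No kwargs= passed to the sampler; the fixed data is carried by
        # lnprob_fixed instead.
        sampler = emcee.EnsembleSampler(nwalkers, ndim, lnprob_fixed, pool=pool)
        sampler.run_mcmc(p0, 100)
```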
Hi @noahfranz13 if you are set up to test and time that implementation, then yes a PR would be very helpful. Let me know if any questions arise. And thank you for the detailed information in your original comment.
@noahfranz13 I think we can close this now thanks to your PR. Much appreciated! But feel free to reopen if something is still not resolved.
Description
I've been finding that when using multiprocessing with emcee as the backend MCMC module there are significant slowdowns rather than speedups. I know this is a known issue (feature?) with emcee (https://emcee.readthedocs.io/en/stable/tutorials/parallel/#pickling-data-transfer-arguments), and from a quick look at the source code the issue does seem to be passing inputs into `kwargs` in `emcee.EnsembleSampler` (https://github.com/bd-j/prospector/blob/30d2babb64c46137a7fcd504028db0ec7cf5a9ca/prospect/fitting/ensemble.py#L102C5-L102C5). So, if I'm understanding the prospector code correctly, when the `pool` argument and `emcee=True, dynesty=False` are passed to `fit_model`, both `kwargs` and `pool` are passed to `emcee.EnsembleSampler`, and the slowdown described in the emcee docs follows. Also, from a quick look at the prospector source an easy fix is not immediately apparent to me, since user-defined global variables would need to be shared not just throughout a single file but across the entire module.

Is this a known issue with prospector? Or could I be doing something incorrectly with prospector to create this slowdown? And, if not, are there thoughts (or plans already) to fix this?
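(For reference, the workaround the linked emcee tutorial points toward is roughly the pattern below, which keeps the large object out of the per-call arguments entirely; the difficulty I mean above is that wiring something like this into prospector would require module-level state rather than anything local to one file. All names here are illustrative only.)

```python
from multiprocessing import Pool

import numpy as np

_big = None  # module-level slot for the large, fixed object


def _init_worker(big):
    # Runs once in each worker process and stashes the large object there,
    # so it is not pickled into the per-call arguments.
    global _big
    _big = big


def lnprob_global(theta):
    # Only theta is passed per call; the large object is read from the global.
    return -0.5 * np.sum((theta - _big.mean()) ** 2)


if __name__ == "__main__":
    big = np.random.rand(2_000_000)  # stand-in for an sps-like object
    with Pool(initializer=_init_worker, initargs=(big,)) as pool:
        print(pool.map(lnprob_global, [np.zeros(3), np.ones(3)]))
```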
Timing Results
As a MWE I modified the demo_params.py script to include the option for multiprocessing. The code for this is in the collapsible section at the bottom of this issue and includes just the code after `if __name__ == '__main__':`, since that is all I modified in `demo_params.py`. Using Python version 3.11.4, prospector version 1.2.0, and emcee version 3.1.4 I found the following timing results:

With multiprocessing:
time python3 demo_params.py --mp --emcee --objid=0
prospect/models/priors.py:117: RuntimeWarning: divide by zero encountered in log, lnp = np.log(p)
This might point to an issue on my end or to a deeper issue of some sort? Although, when I kill the process, prospector/emcee does seem to be stuck at the multiprocessing step.

Without multiprocessing:
time python3 demo_params.py --emcee --objid=0
MWE Code
```python
if __name__ == '__main__':

    # - Parser with default arguments -
    parser = prospect_args.get_parser()

    # - Add custom arguments -
    parser.add_argument('--object_redshift', type=float, default=0.0,
                        help=("Redshift for the model"))
    parser.add_argument('--add_neb', action="store_true",
                        help="If set, add nebular emission in the model (and mock).")
    parser.add_argument('--add_duste', action="store_true",
                        help="If set, add dust emission to the model.")
    parser.add_argument('--luminosity_distance', type=float, default=1e-5,
                        help=("Luminosity distance in Mpc. Defaults to 10pc "
                              "(for case of absolute mags)"))
    parser.add_argument('--phottable', type=str, default="demo_photometry.dat",
                        help="Names of table from which to get photometry.")
    parser.add_argument('--objid', type=int, default=0,
                        help="zero-index row number in the table to fit.")
    parser.add_argument('--mp', dest='mp', action='store_true')
    parser.set_defaults(mp=False)

    args = parser.parse_args()
    run_params = vars(args)
    obs, model, sps, noise = build_all(**run_params)

    run_params["sps_libraries"] = sps.ssp.libraries
    run_params["param_file"] = __file__

    print(model)

    if args.mp:
        pool = Pool()
    else:
        pool = None

    if args.debug:
        sys.exit()

    #hfile = setup_h5(model=model, obs=obs, **run_params)
    ts = time.strftime("%y%b%d-%H.%M", time.localtime())
    hfile = "{0}_{1}_result.h5".format(args.outfile, ts)

    output = fit_model(obs, model, sps, noise, pool=pool, **run_params)

    print("writing to {}".format(hfile))
    writer.write_hdf5(hfile, run_params, model, obs,
                      output["sampling"][0], output["optimization"][0],
                      tsample=output["sampling"][1],
                      toptimize=output["optimization"][1],
                      sps=sps)

    try:
        hfile.close()
    except(AttributeError):
        pass
```