bd-j / prospector

Python code for Stellar Population Inference from Spectra and SEDs
http://prospect.readthedocs.io
MIT License

PR related to improvements on Issue #307 (#309)

Closed: noahfranz13 closed this PR 6 months ago

noahfranz13 commented 7 months ago

This is the pull request with improvements on issue #307.

After digging further into the multiprocessing-with-emcee issue, I have come to the tentative conclusion that it cannot be sped up enough to make multiprocessing worthwhile. That said, I have improved the multiprocessing with emcee so that, instead of hanging for hours, it finishes within the same order of magnitude as the serial implementation, in case someone really wants to use multiprocessing. I also added a warning so other users see this limitation and don't go through the same process I just went through!

Some notes on what I tried:

  1. Simply using partial, as is done with dynesty, did not work and produced no speedup. This is because emcee uses multiprocessing in a way that pickles the entire likelihood function, including the FSPS model. Since partial only stores the FSPS model inside the likelihood function, the model is still pickled every iteration, resulting in massive slowdowns (see the sketch after this list).

  2. I had the best luck defining global variables and a wrapper on the likelihood function that reads those globals in the ensemble.py file (as suggested in the emcee documentation). This way the FSPS model is pickled only once per process when it is sent to the workers, so with only 2-6 processes the runtime is only slightly longer than if the code were run in serial. With any more processes the overhead becomes too large.
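
For concreteness, here is a minimal sketch of the two approaches; every name in it is a hypothetical stand-in rather than prospector's actual API:

```python
# Hypothetical stand-ins, not prospector's actual API: HeavySPS plays the
# role of the FSPS model, lnprobfn the role of the likelihood.
from functools import partial
from multiprocessing import Pool

import numpy as np

class HeavySPS:
    """Stand-in for an FSPS model that is expensive to pickle."""
    def __init__(self):
        self.grid = np.zeros((2000, 2000))

def lnprobfn(theta, sps=None):
    return -0.5 * float(np.sum(np.asarray(theta) ** 2))

# Approach 1 (slow): the partial carries `sps` in its pickled state, so
# the heavy model is serialized along with the function on every map call.
slow_fn = partial(lnprobfn, sps=HeavySPS())

# Approach 2 (the emcee-docs pattern): keep the model in a module-level
# global; the wrapper pickles by reference, and each worker resolves GSPS
# in its own copy of the module, so nothing heavy ships per iteration.
GSPS = HeavySPS()

def lnprobfn_global(theta):
    return lnprobfn(theta, sps=GSPS)

if __name__ == "__main__":
    with Pool(2) as pool:
        print(pool.map(lnprobfn_global, [np.ones(3)] * 4))
```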

In general, the overhead of pickling the FSPS model when it is sent to the worker processes seems to be too large for multiprocessing with an emcee backend to be worth it in this case. But I have improved the code so that it at least finishes in a time comparable to the serial implementation, rather than taking multiple hours when the serial run finishes in a few minutes. I also added warnings in run_emcee_sampler and restart_emcee_sampler that are raised if a user passes in a pool object when using the emcee backend, to make sure they understand it will probably hurt, not help, their runtime.
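
The warnings are simple pool-detection checks; here is an illustrative sketch with a deliberately simplified signature (the PR's actual code and wording will differ):

```python
# Illustrative sketch only; the real run_emcee_sampler takes more
# arguments, and the PR's exact warning text may differ.
import warnings

def run_emcee_sampler(lnprobfn, initial, model, pool=None, **kwargs):
    if pool is not None:
        warnings.warn(
            "Using a pool with the emcee backend pickles the FSPS model "
            "and will likely be no faster than running in serial.",
            UserWarning,
        )
    # ... sampler setup and run continue here ...
```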

The global-variable solution is not the prettiest, but it does improve the usability of the code with emcee and multiprocessing. Note that I had to put the wrapper on the likelihood function in the ensemble.py file to make sure it could "see" the global variables when called in an individual process. I also had to construct the pool object inside the ensemble.py file so that the globals defined in that file, rather than the globals from the file where the pool was originally defined (presumably the parameter file), are the ones sent to the workers. A tiny illustration of this name-resolution point follows.
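
The underlying Python rule is that a function looks up global names in the module where it was defined, not the module it is called from; all names below are hypothetical:

```python
# mod_a.py -- hypothetical stand-in for ensemble.py
GSPS = None  # set from outside before any worker starts

def wrapped_lnprob(theta):
    # resolves GSPS in mod_a's own globals, even when unpickled and
    # called inside a worker that never saw the parameter file
    return (theta, GSPS)

# main.py -- hypothetical stand-in for the parameter file
# import mod_a
# mod_a.GSPS = "heavy sps object"    # set the global in the right module
# GSPS = "a global here is invisible to wrapped_lnprob"
# print(mod_a.wrapped_lnprob(1.0))   # -> (1.0, 'heavy sps object')
```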

Sadly, this is as far as I got; I hope you still consider it a worthwhile improvement to the code!

bd-j commented 7 months ago

Hi @noahfranz13, thanks for digging into this, I'm sorry the outcome was not more positive.

I seem to recall long ago getting good parallelization (with MPI) by having each process define its own lnprobfn using an sps object it builds itself. So, to avoid the issue in point 1, instead of partial you could do something similar using globals in main:


```python
# ... (earlier script setup; build_all and run_params come from the
# user's parameter file, and the pool is assumed to be schwimmbad's MPIPool)
import sys
from schwimmbad import MPIPool
from prospect.fitting import fit_model, lnprobfn

# every MPI process runs this script, so each builds its own sps object
obs, model, gsps, noise = build_all(**run_params)

def lnprobfn_global(*args, sps=None, **kwargs):
    # ignore any sps that is passed in; use the process-local gsps
    return lnprobfn(*args, sps=gsps, **kwargs)

pool = MPIPool()
if not pool.is_master():
    pool.wait()
    sys.exit(0)

nprocs = pool.size
output = fit_model(obs, model, None, noise, pool=pool, queue_size=nprocs,
                   lnprobfn=lnprobfn_global, **run_params)
```
But perhaps I've broken the API that allowed it, or I'm missing something and emcee will still try to pickle gsps. Do you think that might work at least for MPI? Otherwise happy to merge this PR.

noahfranz13 commented 6 months ago

Sorry, I did not mean to close this. It looks like I could not do that with the current API: because of how postkwargs was defined in fitting.run_emcee, it was passing in both sps=None and sps=gsps, which Python obviously does not like!
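
In miniature, the conflict was Python's standard duplicate-keyword error:

```python
# Miniature of the conflict: sps supplied both explicitly and via
# **postkwargs (names echo the discussion; the values are dummies).
def lnprobfn(theta, sps=None):
    return 0.0

postkwargs = {"sps": "gsps"}
lnprobfn(1.0, sps=None, **postkwargs)
# TypeError: lnprobfn() got multiple values for keyword argument 'sps'
```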

However, this way of doing it with the "external" global variables definitely seems more Pythonic and works better with the emcee style of multiprocessing. So I went to reset the commit history so I could change the API for this, and GitHub automatically closed the pull request (as it should, honestly). The fix was pretty quick, though, so I've committed again with it!

Thanks for the suggestion; I'm much happier with this fix than my original commit, haha!

bd-j commented 6 months ago

Great, glad it worked out. Thanks for the investigation and the PR @noahfranz13 !