jmschrei / pomegranate

Fast, flexible and easy to use probabilistic modelling in Python.
http://pomegranate.readthedocs.org/en/latest/
MIT License

BayesianNetwork: n_jobs > 1 brings no speed up, and arguments being passed incorrectly #962

Closed c-yeh closed 1 year ago

c-yeh commented 2 years ago

Hi,

First of all, thanks for the great library. While applying pomegranate to my high-dimensional BayesianNetwork use case (hundreds of states), I've found that n_jobs=1 is consistently faster than n_jobs>1, and I suspect Python's GIL is the cause. While trying to test this hypothesis properly, however, I ran into another problem, so allow me to discuss that "sub" problem first.

If it turns out that I don't know what I'm talking about, or it's already been discussed, then please forgive my ignorance.


Test environment

- EC2 Amazon Linux (m5.2xlarge)
- Python 3.8.8 (Cython)
- pomegranate 0.14.8


The Sub-Problem: Parallel calls are not passing kwargs

In the code linked below (just one example out of many places), the call to the parallel function passes only the data portion (a chunk of X) and none of the parent function's other arguments. This means that arguments such as max_iterations are never explicitly forwarded to the parallel child calls, so the threads fall back to the default values! https://github.com/jmschrei/pomegranate/blob/0652e955c400bc56df5661db3298a06854c7cce8/pomegranate/BayesianNetwork.pyx#L615-L616

This is a problem because, whenever n_jobs > 1, the parallel calls end up running with default argument values (e.g. max_iterations is always 100), unintentionally and with no way for the caller to control them. I have confirmed this behavior by subclassing BayesianNetwork with an exact copy of predict_proba() plus inserted print statements. I will skip those test details here unless they are needed to convince you.

Solution to sub-problem

Pass every parent-function argument other than the data (X) through to the f(...) call as well. The reason can be seen here: https://github.com/jmschrei/pomegranate/blob/0652e955c400bc56df5661db3298a06854c7cce8/pomegranate/utils.pyx#L456-L460


The Main Problem: Using 'threading' in joblib.Parallel brings no speedup because of Python's GIL

After fixing the sub-problem above, n_jobs=1 and n_jobs>1 can be compared fairly, and it turns out that there is no speedup at all. The simplest timing examples:

model.predict_proba(X_test.iloc[:256].to_dict('records'), check_input=False, max_iterations=1, n_jobs=1)

is ~12.3s (5 runs average)

model.predict_proba(X_test.iloc[:256].to_dict('records'), check_input=False, max_iterations=1, n_jobs=2)

is ~13.2s (5 runs average)

model.predict_proba(X_test.iloc[:256].to_dict('records'), check_input=False, max_iterations=1, n_jobs=4)

is ~14.2s (5 runs average)
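To illustrate why a threading backend cannot help when the work never releases the GIL, the sketch below times a pure-Python, CPU-bound function run serially and through a thread pool, averaging 5 runs as in the measurements above. busy_work is a hypothetical stand-in for GIL-holding code, not pomegranate's predict_proba; on a standard (GIL-enabled) CPython build the threaded run is expected to be no faster, though exact numbers depend on the machine.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def busy_work(n):
    # Pure-Python loop: holds the GIL for its entire duration.
    total = 0
    for i in range(n):
        total += i * i
    return total


def mean_runtime(fn, repeats=5):
    # Average wall-clock time of fn() over several runs.
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return sum(times) / len(times)


def serial(tasks):
    return [busy_work(n) for n in tasks]


def threaded(tasks, n_jobs):
    with ThreadPoolExecutor(max_workers=n_jobs) as pool:
        return list(pool.map(busy_work, tasks))


tasks = [100_000] * 8
assert serial(tasks) == threaded(tasks, 4)  # identical results either way
print(f"serial:    {mean_runtime(lambda: serial(tasks)):.3f}s")
print(f"4 threads: {mean_runtime(lambda: threaded(tasks, 4)):.3f}s")
```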

Solution to Main Problem

Use backend='multiprocessing' instead of 'threading', as joblib's documentation suggests for code that does not release the GIL.

One example location that needs this fix: https://github.com/jmschrei/pomegranate/blob/0652e955c400bc56df5661db3298a06854c7cce8/pomegranate/BayesianNetwork.pyx#L613-L616

By changing the above to

with Parallel(n_jobs=n_jobs, backend='multiprocessing') as parallel: 
    ...

I get

model.predict_proba(X_test.iloc[:256].to_dict('records'), check_input=False, max_iterations=1, n_jobs=1)

is ~12.3s (5 runs average)

model.predict_proba(X_test.iloc[:256].to_dict('records'), check_input=False, max_iterations=1, n_jobs=2)

is ~8.4s (5 runs average)

model.predict_proba(X_test.iloc[:256].to_dict('records'), check_input=False, max_iterations=1, n_jobs=4)

is ~5.6s (5 runs average)
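The reported averages can be expressed as speedup (T1 / Tn) and parallel efficiency (speedup / n_jobs). The snippet below only re-derives these ratios from the numbers quoted in this issue; it does not measure anything itself.

```python
# Mean runtimes in seconds (5-run averages) as reported above, keyed by n_jobs.
threading_s = {1: 12.3, 2: 13.2, 4: 14.2}
multiprocessing_s = {1: 12.3, 2: 8.4, 4: 5.6}


def speedup_table(timings):
    # n_jobs -> (speedup relative to n_jobs=1, efficiency = speedup / n_jobs)
    base = timings[1]
    return {n: (round(base / t, 2), round(base / t / n, 2)) for n, t in timings.items()}


print(speedup_table(threading_s))        # {1: (1.0, 1.0), 2: (0.93, 0.47), 4: (0.87, 0.22)}
print(speedup_table(multiprocessing_s))  # {1: (1.0, 1.0), 2: (1.46, 0.73), 4: (2.2, 0.55)}
```

With threading, every extra job makes the run slower (speedup below 1), whereas the multiprocessing backend reaches roughly 2.2x with 4 jobs.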

dolevamir commented 2 years ago

@c-yeh This looks like a really important fix to the package. I'm also seeing no improvement when using the n_jobs parameter.

Is there a chance you will open a PR with your suggested fixes? I'm sure @jmschrei will approve it, since this is really important functionality that's currently missing.

c-yeh commented 2 years ago

@dolevamir @jmschrei

The faq.rst claims that the GIL is released:

Does pomegranate support parallelization? Yes! pomegranate supports parallelized model fitting and model predictions, both in a data-parallel manner. Since the backend is written in cython the global interpreter lock (GIL) can be released and multi-threaded training can be supported via joblib. This means that parallelization is utilized time isn't spent piping data from one process to another nor are multiple copies of the model made.

but my results show that parallelism is not working well. Perhaps the joblib.Parallel calls descend into functions that rely on objects requiring the GIL, and the nogil markers are failing silently? Should this be fixed by (1) switching every threading backend to multiprocessing, or (2) inspecting GIL activity and fixing all the nogil markers?
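For contrast with the pure-Python case, threads do pay off when the heavy lifting happens in C code that releases the GIL, which is the situation the FAQ describes for Cython code marked nogil. As a stdlib stand-in (an assumption for illustration, not pomegranate code), CPython's hashlib releases the GIL while hashing large buffers, so the threaded run below can overlap work across cores; actual timings are machine-dependent.

```python
import hashlib
import time
from concurrent.futures import ThreadPoolExecutor

# Eight 4 MiB buffers; hashlib drops the GIL while hashing buffers this large.
buffers = [bytes([i]) * (4 * 1024 * 1024) for i in range(8)]


def digest(buf):
    h = hashlib.sha256()
    h.update(buf)  # the C implementation releases the GIL here for large inputs
    return h.hexdigest()


def timed(fn):
    # Return fn()'s result together with its wall-clock duration.
    start = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - start


serial_out, serial_s = timed(lambda: [digest(b) for b in buffers])
with ThreadPoolExecutor(max_workers=4) as pool:
    threaded_out, threaded_s = timed(lambda: list(pool.map(digest, buffers)))

assert serial_out == threaded_out
print(f"serial: {serial_s:.3f}s, 4 threads: {threaded_s:.3f}s")
```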

jmschrei commented 2 years ago

Threading absolutely works with most functions. However, the predict_proba function for Bayesian networks does not use any C-level functions that can be assigned a nogil tag. This specific function should probably be changed to not use a threading backend. I am in the process of rewriting most of the core functionality of pomegranate using PyTorch, which should solve this natively, so I won't have time to hunt for these types of issues in the current repository. If you want to submit a PR fixing this specific issue, I'd accept it, though.

jmschrei commented 2 years ago

Basically, this is because I never wrote an efficient implementation of loopy belief propagation for factor graphs. The current implementation is essentially pure Python/numpy, and it isn't speed-optimized even by numpy standards.
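A rough illustration of why such a code path holds the GIL: loopy belief propagation is an iterative message-passing loop, and when each message update is a Python-level loop around small numpy operations, the interpreter is involved at every step. The sketch below is a simplified sum-product version for a pairwise Markov random field, swapped in for illustration; it is not pomegranate's factor-graph implementation, and the function name loopy_bp and the example potentials are hypothetical.

```python
import numpy as np


def loopy_bp(unary, pairwise, max_iterations=100, tol=1e-8):
    """Sum-product loopy BP on a pairwise MRF (illustrative sketch).

    unary:    dict node -> 1-D array of non-negative potentials phi_i(x_i)
    pairwise: dict (i, j) -> 2-D array psi_ij(x_i, x_j), one entry per undirected edge
    Returns a dict node -> normalized marginal belief.
    """
    # Store each edge potential in both directions, indexed [x_sender, x_receiver].
    psi = {}
    for (i, j), table in pairwise.items():
        psi[(i, j)] = np.asarray(table, dtype=float)
        psi[(j, i)] = psi[(i, j)].T
    neighbors = {n: [] for n in unary}
    for i, j in psi:
        neighbors[i].append(j)
    # Initialize every directed message m_{i->j} to uniform over x_j.
    msgs = {(i, j): np.full(len(unary[j]), 1.0 / len(unary[j])) for (i, j) in psi}

    for _ in range(max_iterations):
        new_msgs = {}
        for (i, j) in psi:  # Python-level loop over directed edges
            # phi_i times all incoming messages to i except the one from j.
            prod = np.asarray(unary[i], dtype=float).copy()
            for k in neighbors[i]:
                if k != j:
                    prod *= msgs[(k, i)]
            m = prod @ psi[(i, j)]  # sum over x_i
            new_msgs[(i, j)] = m / m.sum()
        delta = max(np.abs(new_msgs[e] - msgs[e]).max() for e in psi)
        msgs = new_msgs
        if delta < tol:
            break

    beliefs = {}
    for i in unary:
        b = np.asarray(unary[i], dtype=float).copy()
        for k in neighbors[i]:
            b *= msgs[(k, i)]
        beliefs[i] = b / b.sum()
    return beliefs


# Hypothetical two-node example.
unary = {0: [0.6, 0.4], 1: [0.3, 0.7]}
pairwise = {(0, 1): [[0.9, 0.1], [0.1, 0.9]]}
beliefs = loopy_bp(unary, pairwise, max_iterations=10)
print(beliefs)
```

On a tree such as this two-node chain, BP converges to the exact marginals; on graphs with cycles the beliefs are only approximations.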

dolevamir commented 2 years ago

Thanks for the explanation, looking forward to the update 🙏

c-yeh commented 2 years ago

Thanks for the explanation. I will open a PR soon from a different account.

joy13975 commented 2 years ago

@jmschrei Could you please review the PR?

jmschrei commented 1 year ago

Thank you for opening an issue. pomegranate has recently been rewritten from the ground up to use PyTorch instead of Cython (v1.0.0), and so all issues are being closed as they are likely out of date. Please re-open or start a new issue if a related issue is still present in the new codebase.