optimagic-dev / optimagic

optimagic is a Python package for numerical optimization. It is a unified interface to optimizers from SciPy, NlOpt and other packages. optimagic's minimize function works just like SciPy's, so you don't have to adjust your code. You simply get more optimizers for free. On top you get diagnostic tools, parallel numerical derivatives and more.
https://optimagic.readthedocs.io/
MIT License

Multistart Parallelization with Jit Compatible Code #493

Open nikithiel opened 8 months ago

nikithiel commented 8 months ago

Hey there,

I am trying to run a gradient-based algorithm with multistart on my JIT-compatible code in parallel. Can I use estimagic's parallelisation via joblib or pathos (using 'nprocs'), or do I need to create the sample for the exploration phase manually and distribute it using JAX's parallelisation?

When running multistart=True with n_procs=2, I'm encountering the following warning:

RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.

If helpful, I can post some code snippets from my implementation.

timmens commented 8 months ago

Hey @nikithiel,

Thanks for opening the issue!

Could you show me a minimal reproducible example of the behavior mentioned? And could you also post the versions of estimagic, joblib, and jax/jaxlib you use?

That would be very helpful in solving your issue!

nikithiel commented 8 months ago

Hey @timmens,

I'm using the following versions:

estimagic=0.4.6 joblib=1.3.2 jax/jaxlib=0.4.26

This warning occurs when I run my code on a Linux HPC, so I think it's related to that. I found two very interesting posts:

https://discuss.python.org/t/switching-default-multiprocessing-context-to-spawn-on-posix-as-well/21868/22 https://github.com/google/jax/issues/18852 --> https://github.com/google/jax/pull/18989

It seems that in multiprocessing, spawn is the default start method on Windows and macOS, while on Linux it is fork. The latter is incompatible with multithreading (which JAX uses all the time).
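A minimal sketch of how the start method can be inspected and how a spawn context can be requested explicitly, using only the standard library. The helper names `square` and `map_with_spawn` are made up for illustration and are not part of estimagic or joblib.

```python
import multiprocessing as mp


def square(x):
    return x * x


def map_with_spawn(func, values, n_workers=2):
    """Map func over values in worker processes started with 'spawn'."""
    # An explicit context leaves the global default start method untouched.
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=n_workers) as pool:
        return pool.map(func, values)


# Inspect which start methods exist and which one is the default here.
available_methods = mp.get_all_start_methods()
default_method = mp.get_start_method()
print("available:", available_methods)
print("default:", default_method)
```

In a script, `map_with_spawn(square, [1, 2, 3])` should be called from inside an `if __name__ == "__main__":` guard, because spawn re-imports the main module in every worker.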

I'm not sure how to force joblib to use spawn tbh. Maybe by changing the backend argument in the joblib.Parallel call? Maybe this helps:

https://github.com/google/jax/issues/18852

I could also try to create an MWE. However, this is not so straightforward, as the problem occurs in a large code project.

Hope this helps, Niklas

nikithiel commented 8 months ago

I accidentally closed this issue. Sorry!

hmgaudecker commented 8 months ago

It's probably just a check in Jax whether fork has been called.

It happened to me recently in a project with pytask-parallel and Jax, too.

timmens commented 7 months ago

I've found a minimal reproducible example using JAX and joblib and a way to fix it (in the MRE).

As you correctly anticipated, @nikithiel, choosing a different parallelization backend fixes the problem in the MRE. If you want to validate that this also fixes your problem, you could use a local estimagic installation and add the backend="threading" argument to the joblib batch evaluator in the batch_evaluators.py module.

Additionally, you can always run the multistart in serial using multistart_options = {"n_cores": 1}, which could already be fast enough since your objective function is multi-threaded.

[!NOTE] The following is tested on my Linux ThinkPad and might not work on your HPC machine.

@janosg, I propose we add an option to the batch_evaluator for custom kwargs and allow these to be passed through the multistart_options. What are your thoughts?
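A rough sketch of what such a pass-through could look like. The function `batch_evaluate` and its signature are hypothetical and only illustrate the idea of forwarding extra keyword arguments to joblib.Parallel. estimagic's actual batch evaluator interface may differ.

```python
from joblib import Parallel, delayed


def batch_evaluate(func, arguments, n_cores=1, **parallel_kwargs):
    """Hypothetical batch evaluator: apply func to each argument.

    Extra keyword arguments (e.g. backend="threading") are forwarded
    unchanged to joblib.Parallel.
    """
    if n_cores == 1:
        # Serial fallback: no worker processes or threads are started.
        return [func(arg) for arg in arguments]
    return Parallel(n_jobs=n_cores, **parallel_kwargs)(
        delayed(func)(arg) for arg in arguments
    )


def sum_of_squares(x):
    # Stand-in for an objective function evaluated at one start point.
    return sum(v * v for v in x)


start_points = [[1.0, 2.0], [3.0, 4.0], [0.5, 0.5]]
values = batch_evaluate(
    sum_of_squares, start_points, n_cores=2, backend="threading"
)
print(values)
```

With this design the caller decides on the backend, so a JAX user could request threads while other users keep joblib's default.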

Minimal Reproducible Example

import jax.numpy as jnp
from joblib import Parallel, delayed

x_list = [jnp.ones(2) for _ in range(2)]

# Backend: loky (results in a warning)
Parallel(n_jobs=2, backend="loky")(delayed(jnp.mean)(x) for x in x_list)

# Backend: threading (does *not* result in a warning)
Parallel(n_jobs=2, backend="threading")(delayed(jnp.mean)(x) for x in x_list)

Backend: loky

Results in the warning

RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is
multithreaded, so this will likely lead to a deadlock.

Backend: threading

Results in no warning.

nikithiel commented 7 months ago

Hey @timmens,

thanks for the MRE and the suggested solution. I have changed the argument as suggested and I don't get an error on the HPC machine either. I also compared the performance of my bigger code project for a serial run, a run with backend=loky and a run with backend=threading. Interestingly, loky is 3 times faster than serial and threading is 5 times faster.

janosg commented 7 months ago

This is a very important use case for us and we should offer a batch evaluator that supports JAX functions. It's not just for multistart but also for bootstrap or parallelizing optimizers. Instead of making the batch evaluator configurable with more arguments, I would probably just add a new batch evaluator.

In the meantime I see two workarounds:

  1. Downgrade JAX. I think a year ago or so we did not have these problems.
  2. Disable parallelization in JAX with something like this.

Disabling JAX's default parallelism is probably a good idea anyway when you do multistart. Running multiple optimizations in parallel is a very simple and efficient form of parallelization. So as long as you have enough optimizations to keep your computer busy, you probably don't want to parallelize the objective function.
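One way this is sometimes done is through the XLA_FLAGS environment variable. The sketch below is an assumption, not a documented or stable JAX API: the two flags are the ones commonly quoted in JAX issue discussions for limiting CPU threading, and they should be checked against your jax/jaxlib version. The helper name `limit_xla_cpu_threads` is made up for illustration.

```python
import os


def limit_xla_cpu_threads():
    """Ask XLA's CPU backend to run single-threaded.

    Assumption: these flags come from JAX issue discussions and are not a
    documented, stable API. Verify them for your jax/jaxlib version.
    """
    flags = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
    existing = os.environ.get("XLA_FLAGS", "")
    if flags not in existing:
        # Keep any flags that were already set and append ours.
        os.environ["XLA_FLAGS"] = f"{existing} {flags}".strip()
    return os.environ["XLA_FLAGS"]


# Call this before the first `import jax` in the process, so the flags are
# in place when JAX sets up its backend.
xla_flags = limit_xla_cpu_threads()
print(xla_flags)
```

Whether this also silences the os.fork() warning is not established by this thread. It only reduces the threading inside each objective evaluation, so that the parallel multistart runs do not compete for cores.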