nikithiel opened this issue 8 months ago
Hey @nikithiel,
Thanks for opening the issue!
Could you show me a minimal reproducible example of the behavior mentioned? And could you also post the versions of estimagic, joblib, and jax/jaxlib you use?
That would be very helpful in solving your issue!
Hey @timmens,
I'm using the following versions:
estimagic=0.4.6 joblib=1.3.2 jax/jaxlib=0.4.26
This warning occurs when I run my code on a Linux HPC, so I think it's related to that. I found two very interesting posts:
https://discuss.python.org/t/switching-default-multiprocessing-context-to-spawn-on-posix-as-well/21868/22 https://github.com/google/jax/issues/18852 --> https://github.com/google/jax/pull/18989
It seems like in `multiprocessing`, `spawn` is the default start method on Windows and macOS, while on Linux it is `fork`. The latter is incompatible with multithreading (which JAX does all the time).
I'm not sure how to force joblib to use `spawn`, tbh. Maybe by changing the `backend` argument in the `joblib.Parallel` call? Maybe this helps:
https://github.com/google/jax/issues/18852
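For comparison, the standard library lets you request `spawn` explicitly instead of relying on the platform default. The sketch below uses `multiprocessing.get_context` together with `concurrent.futures.ProcessPoolExecutor(mp_context=...)`; the worker `square` and the helper `run_with_spawn` are hypothetical placeholders, and this does not by itself change how joblib starts its workers.

```python
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor


def square(x):
    # Placeholder work function. It must live at module level so that
    # spawned workers can import it by name.
    return x * x


def run_with_spawn(values, n_workers=2):
    # Request "spawn" explicitly: each worker starts as a fresh interpreter,
    # so no threads (such as JAX's) are inherited from the parent via fork().
    ctx = mp.get_context("spawn")
    with ProcessPoolExecutor(max_workers=n_workers, mp_context=ctx) as pool:
        return list(pool.map(square, values))
```

Because spawned workers re-import the main module, the pool should be started from code guarded by `if __name__ == "__main__":`, e.g. `if __name__ == "__main__": print(run_with_spawn([1, 2, 3]))`.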
I could also try to create an MWE. However, this is not so straightforward, as the problem occurs in a large code project.
Hope this helps, Niklas
I accidentally closed this issue. Sorry!
It's probably just a check in JAX for whether fork has been called.
This happened to me recently, too, in a project with pytask-parallel and JAX.
I've found a minimal reproducible example using JAX and joblib and a way to fix it (in the MRE).
As you correctly anticipated, @nikithiel, choosing a different parallelization backend fixes the problem in the MRE. If you want to validate that this fixes your problem, you could use a local estimagic installation and add the `backend="threading"` argument to the joblib batch evaluator in the `batch_evaluators.py` module.
Additionally, you can always run the multistart serially using `multistart_options = {"n_cores": 1}`, which could already be fast enough since your objective function is multithreaded.
> [!NOTE]
> The following has been tested on my Linux ThinkPad and might not work on your HPC machine.
@janosg, I propose we add an option to the `batch_evaluator` for custom kwargs and allow these to be passed through the `multistart_options`. What are your thoughts?
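A minimal sketch of what such a pass-through could look like, assuming a hypothetical `batch_evaluator_kwargs` entry in `multistart_options` (this is not an existing estimagic option). To stay self-contained it uses `concurrent.futures` instead of joblib, with a `backend` keyword standing in for joblib's `backend` argument; `batch_evaluator` and `run_multistart_batch` are illustrative names only.

```python
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor

# Hypothetical mapping from backend names to executor classes.
_EXECUTORS = {"threading": ThreadPoolExecutor, "processes": ProcessPoolExecutor}


def batch_evaluator(func, arguments, n_cores=1, backend="processes"):
    # Hypothetical evaluator: `backend` plays the role of joblib's argument.
    if n_cores == 1:
        return [func(arg) for arg in arguments]
    executor_cls = _EXECUTORS[backend]
    with executor_cls(max_workers=n_cores) as executor:
        return list(executor.map(func, arguments))


def run_multistart_batch(func, arguments, multistart_options):
    # Forward any extra keyword arguments untouched to the batch evaluator.
    extra_kwargs = multistart_options.get("batch_evaluator_kwargs", {})
    n_cores = multistart_options.get("n_cores", 1)
    return batch_evaluator(func, arguments, n_cores=n_cores, **extra_kwargs)
```

For example, `run_multistart_batch(abs, [-1, -2, 3], {"n_cores": 2, "batch_evaluator_kwargs": {"backend": "threading"}})` evaluates in a thread pool and returns `[1, 2, 3]`. The design keeps the evaluator's own signature small while letting users reach backend-specific options.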
```python
import jax.numpy as jnp
from joblib import Parallel, delayed

x_list = [jnp.ones(2) for _ in range(2)]

# Backend: loky (results in a warning)
Parallel(n_jobs=2, backend="loky")(delayed(jnp.mean)(x) for x in x_list)

# Backend: threading (does *not* result in a warning)
Parallel(n_jobs=2, backend="threading")(delayed(jnp.mean)(x) for x in x_list)
```
The `loky` backend results in the warning

```
RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
```

The `threading` backend results in no warning.
Hey @timmens,
thanks for the MRE and the suggested solution. I have changed the argument as suggested, and I don't get an error on the HPC machine either. I also compared the performance of my bigger code project for a serial run, a run with `backend="loky"`, and a run with `backend="threading"`. Interestingly, `loky` is 3 times faster than serial and `threading` is 5 times faster.
This is a very important use case for us, and we should offer a batch evaluator that supports JAX functions. It's not just relevant for multistart but also for bootstrap or for parallelizing optimizers. Instead of making the batch evaluator configurable with more arguments, I would probably just add a new batch evaluator.
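A rough sketch of what a dedicated thread-based evaluator could look like; the name `threading_batch_evaluator`, its signature, and the `error_handling` semantics are hypothetical and not estimagic's API. Threads share one process, so nothing is forked. Speed-ups from threads require the evaluated function to release the GIL, which is presumably the case for JAX's compiled computations and would be consistent with the `threading` speed-up reported above.

```python
from concurrent.futures import ThreadPoolExecutor


def threading_batch_evaluator(func, arguments, n_cores=1, error_handling="continue"):
    """Hypothetical thread-based batch evaluator.

    With error_handling="continue", a failing evaluation returns its
    exception object instead of aborting the whole batch; with "raise",
    the first exception propagates to the caller.
    """
    if error_handling not in ("continue", "raise"):
        raise ValueError("error_handling must be 'continue' or 'raise'")

    def safe_call(arg):
        try:
            return func(arg)
        except Exception as exc:
            if error_handling == "raise":
                raise
            return exc

    # All workers run in this process, so no os.fork() is involved.
    with ThreadPoolExecutor(max_workers=max(1, n_cores)) as executor:
        return list(executor.map(safe_call, arguments))
```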
In the meantime I see two workarounds:
Disabling JAX's default parallelism is probably a good idea anyway when you do multistart. Running multiple optimizations in parallel is a very simple and efficient form of parallelization. So as long as you have enough optimizations to keep your computer busy, you probably don't want to parallelize the objective function.
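The "parallelize over optimizations, not inside the objective" pattern can be sketched as follows. This is a toy illustration under stated assumptions: the quadratic objective, the plain gradient descent in `local_optimize`, and the `multistart` helper are all hypothetical stand-ins for a real objective and optimizer. Since this pure-Python toy holds the GIL, it shows the structure rather than a speed-up.

```python
from concurrent.futures import ThreadPoolExecutor


def gradient(x):
    # Gradient of the toy objective f(x) = sum((x_i - 1)^2).
    return [2.0 * (xi - 1.0) for xi in x]


def local_optimize(start, step=0.1, n_iter=200):
    # Plain gradient descent; each run is fully independent of the others.
    x = list(start)
    for _ in range(n_iter):
        x = [xi - step * gi for xi, gi in zip(x, gradient(x))]
    return x


def multistart(starts, n_workers=4):
    # One task per start point: the parallelism lives at the level of whole
    # optimizations, so the objective itself can stay single-threaded.
    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        return list(executor.map(local_optimize, starts))
```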
Hey there,
I am trying to run a gradient-based algorithm with multistart on my jit-compatible code in parallel. Can I use estimagic's parallelisation using `nprocs` via `joblib` or `pathos`, or do I need to create a `sample` for the exploration phase manually and distribute it using JAX parallelisation? When running with `multistart=True` and `n_procs=2`, I'm encountering the following warning:

```
RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
```
If helpful, I can post some code snippets from my implementation.
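For the manual route mentioned in the question, a stdlib-only sketch of an exploration phase could look like this: draw a uniform sample inside the bounds, evaluate the criterion at every point, and keep the best few points as start values. The names `criterion` and `explore` are hypothetical, and this is not how estimagic implements its exploration phase. With a JAX criterion, the evaluation loop could presumably be replaced by a single vectorized `jax.vmap` call.

```python
import random


def criterion(x):
    # Toy objective with its minimum at (1, 1).
    return sum((xi - 1.0) ** 2 for xi in x)


def explore(lower, upper, n_samples=100, n_best=5, seed=0):
    rng = random.Random(seed)
    # Uniform sample inside the box [lower, upper].
    sample = [
        [rng.uniform(lo, hi) for lo, hi in zip(lower, upper)]
        for _ in range(n_samples)
    ]
    # Evaluate every sample point; this is the embarrassingly parallel step.
    values = [criterion(x) for x in sample]
    # Sort indices by criterion value and keep the best points as start values.
    order = sorted(range(n_samples), key=values.__getitem__)
    return [sample[i] for i in order[:n_best]]
```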