google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Clarification on which objects are multiprocess/fork safe? #3691

Open proteneer opened 4 years ago

proteneer commented 4 years ago

This is likely a dumb/obvious question, but is there any guideline on the multiprocess/fork safety of various JAX objects?

In particular, I've noticed code hanging when I accidentally serialized a JAX XLA DeviceArray, instead of a NumPy ndarray, via Python's multiprocessing module.

I'd also be okay with a firm "please don't serialize JAX objects; always convert to np.ndarray".

hawkinsp commented 4 years ago

In general, no JAX objects are safe in the presence of fork(). fork() is fundamentally incompatible with threading, and JAX uses threading internally. If you want to fork() and do anything other than something like exec(), do the fork() before calling into JAX.
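One way to act on this advice, sketched below under the assumption that the parent process already uses JAX, is to avoid fork() altogether by using Python's "spawn" start method, so each worker starts as a fresh interpreter. This is a swapped-in technique rather than something prescribed in this thread, and the names `get_worker_context`, `worker`, and `run_in_workers` are hypothetical.

```python
import multiprocessing as mp
from typing import List, Sequence

import numpy as np


def get_worker_context() -> mp.context.BaseContext:
    # "spawn" launches each worker as a new Python interpreter instead of
    # fork()-ing the parent, so no JAX runtime threads are inherited.
    return mp.get_context("spawn")


def worker(x: np.ndarray) -> float:
    # Workers receive plain NumPy arrays. If a worker needs JAX, it should
    # import and initialize JAX itself inside the child process.
    return float(np.sum(x))


def run_in_workers(arrays: Sequence[np.ndarray], processes: int = 2) -> List[float]:
    # With the spawn start method, call this from under
    # `if __name__ == "__main__":` in a script, because each child
    # re-imports the main module.
    ctx = get_worker_context()
    with ctx.Pool(processes=processes) as pool:
        return pool.map(worker, list(arrays))
```

Any JAX results produced in the parent would still need converting with np.asarray before being passed to `run_in_workers`, since the pool pickles its arguments.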

I think the current state is "most bets are off if you serialize JAX objects". That's not because we're opposed, but more because it's not something we test, so there are no guarantees. By contrast, np.ndarray will certainly work. Does that answer the question?
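Following the "convert to np.ndarray" route, here is a minimal sketch of a helper that converts every array-like leaf of a nested dict/list/tuple to a NumPy array before the data is pickled or handed to multiprocessing. It assumes the JAX arrays involved expose `__array__`, so `np.asarray` works on them; the names `to_numpy` and `dumps_host` are hypothetical, and `jax.tree_util.tree_map` would give a more general traversal if JAX is importable at that point.

```python
import pickle
from typing import Any

import numpy as np


def to_numpy(obj: Any) -> Any:
    """Return `obj` with array-like leaves converted to NumPy arrays.

    dicts, lists and tuples are rebuilt recursively (as plain dict, list and
    tuple); anything exposing `__array__`, such as a JAX array, is converted
    with np.asarray; every other value is returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: to_numpy(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [to_numpy(value) for value in obj]
    if isinstance(obj, tuple):
        return tuple(to_numpy(value) for value in obj)
    if hasattr(obj, "__array__"):
        return np.asarray(obj)
    return obj


def dumps_host(obj: Any) -> bytes:
    # Pickle only host-side NumPy data, never device-backed JAX objects.
    return pickle.dumps(to_numpy(obj))
```

Converting in the sending process also means the receiving process only needs NumPy, not JAX, to unpickle the data.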

proteneer commented 4 years ago

Thanks for the clarification; that's in line with my expectation. Would it be reasonable to throw an exception when serialization methods are called on internal JAX objects (e.g. raise in __getstate__ and __setstate__), rather than having them silently serialized?
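The suggestion could look roughly like the sketch below. `DeviceBuffer` is a hypothetical stand-in for an internal JAX object, not JAX's actual class or implementation. Raising from `__getstate__` makes `pickle.dumps`, and therefore anything sent through `multiprocessing` queues or pools, fail immediately with a clear error instead of silently serializing the object.

```python
import pickle


class DeviceBuffer:
    """Hypothetical object that must not cross process boundaries."""

    def __init__(self, handle: int) -> None:
        # Stands in for an opaque handle to device memory.
        self.handle = handle

    def __getstate__(self) -> object:
        # pickle calls this while serializing; refuse loudly.
        raise TypeError(
            f"{type(self).__name__} cannot be pickled; convert it to a "
            "numpy.ndarray with np.asarray() first."
        )

    def __setstate__(self, state: object) -> None:
        # Defensive: refuse to rebuild the object from a state blob.
        raise TypeError(f"{type(self).__name__} cannot be unpickled.")
```

A design note: raising a TypeError mirrors what pickle already does for objects it cannot handle, so callers that catch pickling failures would treat this case the same way.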