jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

`jit(pmap(f))` causes inefficient behavior #2926

Closed danieldjohnson closed 3 years ago

danieldjohnson commented 4 years ago

Combining jit with pmap produces some undesirable and surprising behaviors.

As one example, any lazy intermediate constants used by the function get instantiated and copied to every device. For instance:

import jax
import jax.numpy as jnp

@jax.jit
@jax.pmap
def foo(x):
  z = jnp.zeros((500_000_000,))  # ~2 GB of float32 zeros
  return jax.lax.tie_in(z, x)    # ties the output to z so z becomes part of the computation

foo(jnp.arange(16).reshape((8, 2)))

This causes 2 GB of data to be allocated on each device. (Right now, if this is the only computation you run, this can be verified by looking at list(list(jax.pxla.parallel_callable.__closure__[1].cell_contents.items())[0][1].values())[-1][0].__closure__[0].cell_contents, but that might break.)

Relatedly, the jit causes the return value to be copied back to a single host instead of staying as a ShardedDeviceArray.

Ideally, adding jit would not make behavior worse. But having a warning when such a situation occurs would also be useful here, since pmap on its own does the right thing.

Joshuaalbert commented 3 years ago

Can the performance issues here be resolved? There are many use cases where a pmap inside of a jit is desirable, i.e. some component of a larger algorithm should be distributed, but it is embedded in such a way that it is difficult to jit-compile the independent components separately.

f -----> pmap(h) ----> k
  \                   /
   \------> g1 ------/
     \---->  g2 ----/
            ...

In this case, one would want to jit f, but would be forced to jit g1, ..., gn each separately.
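
A hypothetical sketch of that structure (h, g1, g2, the shapes, and the combination step are placeholders, not code from an actual package):

import jax
import jax.numpy as jnp

n_dev = jax.device_count()

def h(shard):   # the component worth distributing; sees one shard per device
  return shard * 2.0

def g1(x):      # independent single-device components
  return x + 1.0

def g2(x):
  return x - 1.0

@jax.jit
def f(x):       # x has shape (n_dev, ...) so pmap can map over its leading axis
  distributed = jax.pmap(h)(x)
  return distributed + g1(x) + g2(x)  # k: combine the results

f(jnp.ones((n_dev, 4)))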

Additionally, when one package depends on another (e.g. how numpyro depends on jaxns), the upper-level package is restricted in how it can use jit, or else has to convince the lower-level package not to use pmap.

skye commented 3 years ago

I'm not sure we can avoid the overhead in this situation, because the pmap(h) return value needs to be broadcast to all devices anyway to match the single-device execution semantics of g1 .... Do you have an example of this pattern where it's doing unnecessary device transfers?

One thing to note about this pattern in general: all of f in your example will be run on multiple devices, including g1 ..., even though the pmap(h) is the only part of f that requires multiple devices. This is because the JAX runtime system doesn't allow for dynamically changing the number of devices a single computation is run on. So to support this, we run copies of g1... on each device, then throw away all but one result.

Joshuaalbert commented 3 years ago

I didn't know about that last fact you mention. So when you jit a function, you implicitly or explicitly specify the devices for all of its primitives to run on, and thus no subset of primitives can be restricted to run on a subset of those devices. If I understand that correctly, it makes sense. I would be very curious to know the fundamental reasons for disallowing primitives from running on different devices.

Re: your question, I have an example where running pmap inside a jit fails altogether; I wanted to understand pmap inside a jit better before posting about it. Perhaps this error is related to the 'inefficient behavior' mentioned in this post, since sharding is mentioned in the traceback. The traceback I'm getting is:

File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/api.py", line 398, in f_jitted
    return cpp_jitted_f(context, *args, **kwargs)
  File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/api.py", line 289, in cache_miss
    out_flat = xla.xla_call(
  File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/core.py", line 1275, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/core.py", line 1266, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/core.py", line 1278, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/core.py", line 631, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/interpreters/xla.py", line 580, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/linear_util.py", line 260, in memoized_fun
    ans = call(fun, *args)
  File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/interpreters/xla.py", line 708, in _xla_callable
    out_nodes = jaxpr_subcomp(
  File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/interpreters/xla.py", line 452, in jaxpr_subcomp
    ans = rule(c, axis_env, extend_name_stack(name_stack, eqn.primitive.name),
  File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 331, in _while_loop_translation_rule
    new_z = xla.jaxpr_subcomp(body_c, body_jaxpr.jaxpr, backend, axis_env,
  File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/interpreters/xla.py", line 452, in jaxpr_subcomp
    ans = rule(c, axis_env, extend_name_stack(name_stack, eqn.primitive.name),
  File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 746, in _cond_translation_rule
    branch_computations = [
  File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 747, in <listcomp>
    make_computation(f'branch_{i}', jaxpr, op_shape)
  File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 739, in make_computation
    outs = xla.jaxpr_subcomp(c, jaxpr.jaxpr, backend, axis_env,
  File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/interpreters/xla.py", line 460, in jaxpr_subcomp
    ans = rule(c, axis_env, in_nodes,
  File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1242, in _pmap_translation_rule
    outs = [_xla_unshard(c, aval, new_env, out_axis, shard, backend=backend)
  File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1242, in <listcomp>
    outs = [_xla_unshard(c, aval, new_env, out_axis, shard, backend=backend)
  File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1288, in _xla_unshard
    xla.axis_groups(axis_env, axis_env.names[-1]))
  File "/home/albert/miniconda3/envs/jax_py/lib/python3.8/site-packages/jax/interpreters/xla.py", line 512, in axis_groups
    assert not ragged
AssertionError

Process finished with exit code 1

mkunesch commented 3 years ago

Since the original issue has been addressed by #3426 and there is now a warning for jit of pmap, I'll close this issue. Please feel free to reopen or file a new issue!

tavin commented 2 years ago

Why do I get this warning when I'm not using jit at all? I have a fori_loop on a function that calls the pmap'd function -- is that why?

mattjj commented 2 years ago

@tavin Yes, that's why! The control flow combinators like fori_loop basically jit themselves.
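
For completeness, a minimal sketch of that situation (the pmapped function and shapes here are made up): fori_loop traces and compiles its body much like jit would, so a pmap called inside the body is effectively a pmap under jit and triggers the same warning.

import jax
import jax.numpy as jnp

p_step = jax.pmap(lambda x: x + 1.0)   # hypothetical pmapped function

def body(i, x):
  return p_step(x)                     # pmap called from the loop body

# fori_loop compiles `body`, so this is jit-of-pmap even though
# jax.jit never appears explicitly.
x0 = jnp.zeros((jax.device_count(), 4))
out = jax.lax.fori_loop(0, 3, body, x0)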

tavin commented 2 years ago

@mattjj Ok good to know! The pmap man page is quite vocal about automatically jitting, but the fori_loop man page doesn't mention it. Thanks for the confirmation.

jakevdp commented 2 years ago

Good idea - I updated some of the documentation in #10757

carlosgmartin commented 1 year ago

@Joshuaalbert I'm having the same issue, where I can't avoid pmapping inside a jit (specifically, inside a scan loop). Did you ever find a solution, by any chance?

Joshuaalbert commented 1 year ago

@carlosgmartin yes, it's possible to write JAX code that is efficient when it comes to jit(pmap). To get it to work smoothly, you should be prepared to change code structure significantly.

A pattern that might be useful is this (note: this is not runnable code; read it as pseudocode illustrating the structure):

# The work that will be done on each device
def inner_loop(state):
  ...

# An iterative algorithm where each iteration has two steps:
# 1) distribute work; 2) gather the per-device products so each device can do
# some shared work locally, e.g. determine a stopping condition
def single_algorithm_thread(state):
  done = False

  # in real code this data-dependent loop would be a lax.while_loop
  while not done:
    local_product = inner_loop(state)
    aggregated_product = all_gather(local_product, 'i')  # collect across the pmapped axis 'i'
    done = is_done(aggregated_product)
    state = make_next_state(state, local_product)

  return state, local_product

# Map the algorithm over devices with pmap
def step_of_larger_algorithm(state):
  parallel_algorithm = pmap(single_algorithm_thread, 'i')
  chunked_state = add_leading_dim(state)
  chunked_output, chunked_product = parallel_algorithm(chunked_state)
  output = remove_leading_dim(chunked_output)
  product = remove_leading_dim(chunked_product)
  return output, product  # you may only need output, not the product of intermediate steps

# JIT-compile a sequence of pmap-ed steps of the larger algorithm
@jit
def big_algorithm(state):
  state, _ = step_of_larger_algorithm(state)
  # do something with state
  ...
  # run more steps using pmap
  state, _ = another_step_of_larger_algorithm(state)
  return state
What is going on is that you're decomposing your big algorithm, the one you'd like to jit-compile, into a sequence of steps in which you use pmap to distribute work. Each step can gather data from all the other devices locally so that it can do something requiring knowledge of the products on all devices, e.g. determine a stopping condition. This sequence can be efficiently jit-compiled. The most important thing to keep in mind is to make the pmap'ed components stateless: make sure all inputs to pmap'd functions are passed in as arguments and not captured by closure. Also, try to reduce the size of the objects being collected with all_gather; sometimes you can replace an op with a more efficient version, e.g. jnp.mean(all_gather(x, 'i'), axis=0) with pmean(x, 'i').

In summary, try to break your algorithm up into a sequence of pmap'able steps, don't let arrays be captured from external scope, and focus on making inter-device communication as lightweight as possible.
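
For concreteness, a minimal runnable sketch of that pattern under toy assumptions (the local work and the stopping condition are placeholders): the data-dependent loop lives inside the pmapped function as a lax.while_loop, and the stopping condition is computed from an all_gather across the pmapped axis so every device agrees on when to stop.

import jax
import jax.numpy as jnp
from jax import lax

n_dev = jax.device_count()

def single_thread(x):                        # runs on one device's shard
  def cond(carry):
    _, done = carry
    return jnp.logical_not(done)

  def body(carry):
    x, _ = carry
    x = 0.5 * x                              # toy local work
    total = jnp.sum(lax.all_gather(x, 'i'))  # collect across the pmapped axis
    # every device computes the same `done` from the gathered values, so all
    # devices exit the loop together (required for the collective inside it)
    return x, total < 1.0

  x, _ = lax.while_loop(cond, body, (x, jnp.array(False)))
  return x

@jax.jit
def step(x):
  return jax.pmap(single_thread, axis_name='i')(x)

print(step(10.0 * jnp.ones((n_dev, 4))))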

hawkinsp commented 1 year ago

We actually have a new and somewhat experimental solution to composing jit and pmap: try using jit(shmap(...)) instead. shmap isn't that well documented yet, beyond the JEP (https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html)

We are tentatively thinking we may be able to replace pmap itself with jit(shmap(...)) in the future.
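
For reference, a minimal sketch of that (assuming the jax.experimental.shard_map import path described in the JEP; the mesh, body, and shapes are illustrative):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=('i',))

def body(block):
  # `block` is this device's shard of x; collectives still use the mesh axis name
  total = jax.lax.psum(jnp.sum(block), 'i')
  return block / total

@jax.jit
def f(x):
  return shard_map(body, mesh, in_specs=P('i'), out_specs=P('i'))(x)

x = jnp.arange(float(jax.device_count() * 4)).reshape(jax.device_count(), 4)
print(f(x))

As in pmap, the function sees its device-local block and collectives refer to the mesh axis name, but shard_map is designed to compose with jit.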

carlosgmartin commented 1 year ago

@Joshuaalbert @hawkinsp Thank you for your comments. shmap looks interesting.

The structure of the program I'm dealing with is described here: https://github.com/google/jax/discussions/15693. It performs OpenAI ES, which requires only that each device send a single scalar to all other devices, on each step. If you have any specific advice for that pattern/situation, feel free to comment there. I'd really appreciate it!

As an aside, I dream of a compiler that is powerful enough to let users focus solely on the semantics of a program (what is to be computed), while the compiler figures out how to distribute the computation efficiently over a set of available resources (how it is to be computed). So no more pmap, xmap, shmap, pjit, pmean, etc. Just write your program as if it ran on a single device (i.e., specify its semantics), and let the compiler figure out the rest. As the shmap page says: Compiler, take the wheel!

Edit: Came across the following:

Alpa: Automating Inter- and Intra-Operator Parallelism for Distributed Deep Learning (repo):

Alpa is built on top of a tensor computation framework Jax. Alpa can automatically parallelize jax functions and runs them on a distributed cluster. Alpa analyses the computational graph and generates a distributed execution plan tailored for the computational graph and target cluster. The generated execution plan can combine state-of-the-art distributed training techniques including data parallelism, operator parallelism, and pipeline parallelism.

Alpa provides a simple API alpa.parallelize and automatically generates the best execution plan by solving optimization problems. Therefore, you can efficiently scale your jax computation on a distributed cluster, without any expertise in distributed computing.

Alpa provides a transformation alpa.parallelize to parallelize a jax function. alpa.parallelize is similar to jax.jit. jax.jit compiles a jax function for a single device, while alpa.parallelize compiles a jax function for a distributed device cluster. You may know that jax has some built-in transformations for parallelization, such as pmap, pjit, and xmap. However, these transformations are not fully automatic, because they require users to manually specify the parallelization strategies such as parallelization axes and device mapping schemes. You also need to manually call communication primitives such as lax.pmean and lax.all_gather, which is nontrivial if you want to do advanced model parallelization. Unlike these transformations, alpa.parallelize can do all things automatically for you. alpa.parallelize finds the best parallelization strategy for the given jax function and does the code transformation. You only need to write the code as if you are writing for a single device.

Joshuaalbert commented 1 year ago

@hawkinsp and @carlosgmartin one thing to keep in mind is that shmap doesn't seem to have good support for non-static loops. It's fine with scan, but not with while_loop, i.e. use pmap if you have to distribute while_loops. Is that correct @hawkinsp?