jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

API for surfacing active SPMD axis names #2911

Open jekbradbury opened 4 years ago

jekbradbury commented 4 years ago

For example,

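from functools import partial
from jax import lax, pmap

@partial(pmap, axis_name="i")
def foo():
  @partial(pmap, axis_name="j")
  def bar():
    # lax.axis_names() is the API proposed in this issue (hypothetical).
    print(lax.axis_names())  # ("i", "j")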

The semantics would be "the set of axis names currently valid inside a psum or axis_index"; the order would be the pmap nesting order. cc @dmrd

jekbradbury commented 4 years ago

The implementation would be something like list(jax.interpreters.pxla._thread_local_state.dynamic_axis_env.keys()).

mattjj commented 4 years ago

Regarding implementation, I think we should move the axis env to be part of the tracer state in core.py, for at least two reasons:

  1. it centralizes the global state, which is useful when forming thunks, and
  2. it'll let vmap easily reuse the same axis environment.

mattjj commented 4 years ago

If we compare axis names to regular variables, so that pmap binds axis names to axes like lambda binds variable names to values, does this correspond to something sensible? “Give me all lambda-bound variables”?

Can the user just do this plumbing themselves? It might not need a jax feature.
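As a rough illustration of what that manual plumbing could look like, here is a minimal sketch in which the caller passes the axis name explicitly. The function `mean_metric` and its `axis_name` parameter are hypothetical names chosen for this example, not part of JAX; `lax.pmean` and the `axis_name` argument of `jax.vmap` are existing APIs. `vmap` stands in for `pmap` so the sketch runs on a single device.

```python
import jax
import jax.numpy as jnp
from jax import lax


def mean_metric(values, axis_name=None):
    """Average `values`, optionally also averaging across a named mapped axis.

    `axis_name` is a hypothetical parameter threaded through by the caller;
    None means "not running under a named axis".
    """
    result = jnp.mean(values)
    if axis_name is not None:
        # Only valid while tracing inside a pmap/vmap that binds `axis_name`.
        result = lax.pmean(result, axis_name)
    return result


# Outside any mapped axis: a plain mean.
unmapped = mean_metric(jnp.array([1.0, 2.0, 3.0]))

# Inside a named axis: per-row means are averaged across the "batch" axis.
batched = jax.vmap(lambda v: mean_metric(v, "batch"), axis_name="batch")(
    jnp.array([[1.0, 2.0], [3.0, 4.0]])
)
```

The cost of this approach is that every function between the mapped entry point and the metric has to forward `axis_name`.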

mattjj commented 4 years ago

Another weird bit: this would be like functions being able to ask “how deep in the call stack am I?”

jekbradbury commented 4 years ago

It does sound a little weird! One use case (the one from David’s collaborator that prompted this) is to ask if e.g. there’s a “batch” axis in the environment, so that metrics-computing code can be polymorphic over whether it needs to reduce (pmean) across the batch.

It would need to be “a JAX feature” simply because I don’t think we want the axis environment itself to be seen as a public API.

mattjj commented 4 years ago

I see, though the alternative to "JAX feature" I had in mind would be just for the user to keep track of all the axis names they'd passed to pmap in their own stack. Then they'd see the ones from their own pmaps, but not necessarily from others' pmaps. Also it'd keep our API smaller :P
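A user-side stack of this kind might look roughly like the sketch below. Everything here (`tracked`, `active_axis_names`, the thread-local `_state`) is a hypothetical helper written for illustration, not a JAX API. The sketch only assumes that the wrapped transform accepts an `axis_name` keyword, as `pmap` does.

```python
import functools
import threading

_state = threading.local()  # hypothetical user-side state, not JAX's axis env


def _names():
    # Each thread lazily gets its own stack of axis names.
    if not hasattr(_state, "names"):
        _state.names = []
    return _state.names


def active_axis_names():
    """Axis names bound by enclosing `tracked` maps, outermost first."""
    return tuple(_names())


def tracked(transform, axis_name):
    """Decorator: apply `transform(fn, axis_name=...)` and record the name
    while the body of `fn` is running (for pmap, while it is being traced)."""
    def decorator(fn):
        @functools.wraps(fn)
        def body(*args, **kwargs):
            _names().append(axis_name)
            try:
                return fn(*args, **kwargs)
            finally:
                # Always pop, even if the body raises.
                _names().pop()
        return transform(body, axis_name=axis_name)
    return decorator
```

With JAX this would be applied as `@tracked(jax.pmap, "i")` in place of `@partial(pmap, axis_name="i")`. Only axes bound through `tracked` are visible, and since the names are read while the body is traced, a cached trace may not repeat the lookup.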

shoyer commented 4 years ago

> Another weird bit: this would be like functions being able to ask “how deep in the call stack am I?”

Yeah, this worries me a bit. This feels like a feature that might be good not to have, to ensure compositionality.

The batch normalization use case is real, though. I wonder if there could be a more hygienic way to support this, e.g., only surfacing active axis names defined by a particular library?