jekbradbury opened 4 years ago
The implementation would be something like `list(jax.interpreters.pxla._thread_local_state.dynamic_axis_env.keys())`.
Regarding implementation, I think we should move the axis env to be part of the tracer state in core.py, for at least two reasons:
If we compare axis names to regular variables, so that pmap binds axis names to axes like lambda binds variable names to values, does this correspond to something sensible? “Give me all lambda-bound variables”?
Can the user just do this plumbing themselves? It might not need a jax feature.
Another weird bit: this would be like functions being able to ask “how deep in the call stack am I?”
It does sound a little weird! One use case (the one from David’s collaborator that prompted this) is to ask whether, e.g., there’s a “batch” axis in the environment, so that metrics-computing code can be polymorphic over whether it needs to reduce (`pmean`) across the batch.
It would need to be “a JAX feature” simply because I don’t think we want the axis environment itself to be seen as a public API.
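Without such a query, the metrics use case has to be handled by threading the axis name through explicitly. Here is a minimal sketch, assuming the caller knows whether a `"batch"` axis is bound and passes it in. `accuracy` and its `batch_axis` parameter are hypothetical names, not part of JAX. It uses `jax.vmap(..., axis_name=...)` so it runs on a single device. `jax.pmap` binds axis names for collectives like `lax.pmean` in the same way.

```python
import jax
import jax.numpy as jnp
from jax import lax


def accuracy(logits, labels, batch_axis=None):
    """Fraction of correct predictions in this shard of the batch.

    Hypothetical helper: if `batch_axis` names a mapped axis bound by
    jax.vmap/jax.pmap(axis_name=...), the result is averaged across it
    with lax.pmean; otherwise it stays local to the shard.
    """
    correct = (jnp.argmax(logits, axis=-1) == labels).astype(jnp.float32)
    acc = jnp.mean(correct)
    if batch_axis is not None:
        acc = lax.pmean(acc, axis_name=batch_axis)
    return acc


# Two shards of two examples each, two classes.
logits = jnp.array([[[2.0, 0.0], [0.0, 2.0]],   # shard 0 predicts [0, 1]
                    [[2.0, 0.0], [2.0, 0.0]]])  # shard 1 predicts [0, 0]
labels = jnp.array([[0, 1], [0, 1]])

# Per-shard metric, no cross-shard reduction: [1.0, 0.5]
local = jax.vmap(accuracy)(logits, labels)

# Same function, averaged across the bound "batch" axis: [0.75, 0.75]
reduced = jax.vmap(lambda l, y: accuracy(l, y, "batch"),
                   axis_name="batch")(logits, labels)
```

The cost of this design is that every caller must know and pass the name. The feature requested in this thread would let `accuracy` discover it instead.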
I see, though the alternative to "JAX feature" I had in mind would be just for the user to keep track of all the axis names they'd passed to pmap in their own stack. Then they'd see the ones from their own pmaps, but not necessarily from others' pmaps. Also it'd keep our API smaller :P
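The user-side plumbing described here can be sketched as a per-thread stack that a wrapper pushes onto while the mapped function is traced. Everything below (`tracked_map`, `active_axis_names`, `mean_metric`) is a hypothetical illustration, not a JAX API. It relies only on `jax.vmap`/`jax.pmap` accepting `axis_name=` and on Python code in the mapped function running during tracing.

```python
import contextlib
import threading

import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical user-side bookkeeping (not a JAX API): a per-thread stack of
# axis names bound by maps that went through `tracked_map`.
_local = threading.local()


def active_axis_names():
    """Axis names bound by enclosing tracked maps, outermost first."""
    return getattr(_local, "axes", ())


@contextlib.contextmanager
def _bound(axis_name):
    saved = active_axis_names()
    _local.axes = saved + (axis_name,)
    try:
        yield
    finally:
        _local.axes = saved  # restore even if tracing raises


def tracked_map(mapper, f, axis_name, **kwargs):
    """Wrap jax.vmap or jax.pmap so `f` can see the name it is mapped over."""
    def body(*args):
        # Executes while `mapper` traces `f`, so push/pop brackets tracing.
        with _bound(axis_name):
            return f(*args)
    return mapper(body, axis_name=axis_name, **kwargs)


def mean_metric(x):
    """Mean of x, averaged across "batch" only if a tracked map bound it."""
    m = jnp.mean(x)
    if "batch" in active_axis_names():
        m = lax.pmean(m, "batch")
    return m


x = jnp.array([[1.0, 3.0], [5.0, 7.0]])
unreduced = mean_metric(x[0])                             # 2.0
reduced = tracked_map(jax.vmap, mean_metric, "batch")(x)  # [4.0, 4.0]
```

As noted above, this only sees names bound through the user's own wrapper, not maps applied by other code. That limitation also makes it a library-scoped view rather than a global one.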
> Another weird bit: this would be like functions being able to ask “how deep in the call stack am I?”
Yeah, this worries me a bit. This feels like a feature that might be good not to have, to ensure compositionality.
The batch normalization use case is real, though. I wonder if there could be a more hygienic way to support this, e.g., only surfacing active axis names defined by a particular library?
For example,
The semantics would be "the set of axis names currently valid inside a `psum` or `axis_index`"; the order would be the `pmap` nesting order. cc @dmrd