Open reinerp opened 10 months ago
Haha, thank you!

I like this idea, and I think using `jax.debug.inspect_array_sharding` is probably the correct approach.

Regarding the syntax, I don't think we can use `/`, as that will conflict with its existing use for division in symbolic dimensions. (Where we evaluate the string as a Python function.) I'm not sure what else would be best; it should be something that isn't normal Python syntax. Maybe some kind of prefix, like we have for the other kinds of special modifiers.

Regarding meshes: morally speaking, I don't think we should actually need a mesh? That is, given a sharding during lowering, can we get a `PartitionSpec`-like object from it, and then compare those?
> Haha, thank you!
>
> I like this idea, and I think using `jax.debug.inspect_array_sharding` is probably the correct approach.
>
> Regarding the syntax, I don't think we can use `/`, as that will conflict with its existing use for division in symbolic dimensions. (Where we evaluate the string as a Python function.) I'm not sure what else would be best; it should be something that isn't normal Python syntax. Maybe some kind of prefix, like we have for the other kinds of special modifiers.
That's a shame! It's not obvious to me how a prefix would work, because we still need a separator between the dimension name and the sharding, so e.g. a prefix of `!` alone isn't enough: `!batch/data_parallel`.

Some ideas:

- Use a Python-invalid separator. Of the ones on my keyboard, these seem to be: `!`, `~`, `$`. So for example: `batch!data_parallel`. The right-hand side of `!` must be an identifier (no more general expression syntax), so e.g. `batch+4!data_parallel` parses into `batch+4` and `data_parallel`. Out of these separators, I visually prefer `!`.
- Use `[expr / ident]` syntax, e.g. `[batch / data_parallel]`, `[batch+4 / data_parallel]`, `[batch/data_parallel]`. Parsing here would be: if there's `[]`, it means this is sharding syntax. Consume the last two tokens inside the brackets, which must be an identifier (`data_parallel`) and a slash (`/`). Spaces are allowed but not required. Everything that remains must be the dimension. Visually I like this syntax the most, but I think it has the disadvantage that e.g. `[batch-1/data_parallel]` looks like it's computing the expression `batch-(1/data_parallel)`, whereas it's actually computing `(batch-1)`.

On balance, I would probably go for `batch!data_parallel`.
> Regarding meshes: morally speaking, I don't think we should actually need a mesh? That is, given a sharding during lowering, can we get a `PartitionSpec`-like object from it, and then compare those?
If `jax.debug.inspect_array_sharding` always returned `NamedSharding`, we could do that. Outside of a `jit`, it seems to always do so (in my experimentation). However, inside a `jit` it seems to sometimes produce `PositionalSharding` or `GSPMDSharding`. I suspect this is a consequence of the sharding being determined by XLA/GSPMD's sharding propagation (which doesn't know about names) rather than by JAX's sharding propagation (which does know about names).

Thus, since we can't reliably recover names, my approach was to go in the other direction: reliably map names to `XLACompatibleSharding`.
I like your syntax suggestions! My only suggestion is to switch them around: `data_parallel!batch`. I know this isn't the convention in the literature, but so far the convention in jaxtyping has been to put all kinds of modifiers into the prefix. That's not that strong a feeling on my part, though, as all the current prefixes are pretty readable, whereas something like `data_parallel!batch` puts the dimension name much later.

As an alternative strategy, we could consider something like explicitly taking `Float[Array, "foo bar", PartitionSpec(...)]`? In particular, it's nice that we can continue to use the existing `PartitionSpec` object, without trying to find a way to cram it into our own string DSL.
For the shardings, I'm curious what @yashk2810 thinks. (Although I don't know if he checks GitHub comments like this :) ) Given a JAX array inside of `jit`, how might you assert that it matches a particular sharding specification? (Whether given by a `PartitionSpec` or something else.)
> If jax.inspect_array_sharding always returned NamedSharding, we could do that

`inspect_array_sharding` can return a `NamedSharding` if you have a `with mesh` context manager surrounding your `jit`. `jit` itself will ignore that, but `inspect_array_sharding` will read its value. If that doesn't exist, `PositionalSharding` is returned. `GSPMDSharding` is never returned.
Given the discussion of `PartitionSpec`, I realized that `PartitionSpec` is a little more flexible than we have discussed so far. Quoting the docs: "a PartitionSpec is a tuple, whose elements can be a None, a mesh axis, or a tuple of mesh axes."

The "tuple of mesh axes" part is not supported in the syntax discussion above. (It's sometimes useful, e.g., to express sharding simultaneously over the x and y axes of a TPU mesh, or to express sharding of a reshaped tensor.)

The natural extension to the syntaxes above would be to support tuples via commas, e.g. `foo!x,y` for the tuple `('x', 'y')`. This syntax doesn't allow you to distinguish 1-tuples from non-tuple values, but these also have no semantic difference in `PartitionSpec`, so I'd say that's desirable.
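As a sketch of this extension (again, hypothetical code; `parse_sharded_dim` is an illustrative name, not a real API), commas on the sharding side could map to a tuple of mesh axes:

```python
# Illustrative sketch: "foo!x,y" -> ("foo", ("x", "y")).
# A single axis stays a bare string, matching the point above that
# PartitionSpec treats 1-tuples and bare values identically.

def parse_sharded_dim(token: str):
    if "!" not in token:
        return token, None
    dim, axes = token.rsplit("!", 1)
    parts = tuple(axes.split(","))
    return dim, parts if len(parts) > 1 else parts[0]
```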
> I like your syntax suggestions! My only suggestion is to switch them around: `data_parallel!batch`. I know this isn't the convention in the literature, but so far the convention in jaxtyping has been to put all kinds of modifiers into the prefix. That's not that strong a feeling on my part, though, as all the current prefixes are pretty readable, whereas something like `data_parallel!batch` puts the dimension name much later.
Take a look at which you think is more readable:

1) Prefix syntax: `Float[Array, "dp!batch seqlen tp!heads d_head"]` or `Float[Array, "x,y!batch seqlen z!heads d_head"]`
2) Suffix syntax: `Float[Array, "batch!dp seqlen heads!tp d_head"]` or `Float[Array, "batch!x,y seqlen heads!z d_head"]`

I mildly prefer (2), because putting the axis name first (before the sharding) seems to put slightly more emphasis on the axis name. Also, I think `batch!x,y` looks more natural than `x,y!batch`. But I'm fine with either.
> As an alternative strategy, we could consider something like explicitly taking `Float[Array, "foo bar", PartitionSpec(...)]`? In particular, it's nice that we can continue to use the existing `PartitionSpec` object, without trying to find a way to cram it into our own string DSL.
I can see the appeal of reusing the existing type!

The main disadvantages I see are:

1) Severe (to me): it's not as readable. Given `Float[Array, "batch seqlen heads d_head", PartitionSpec('dp', None, 'tp', None)]`, I have to mentally perform a zip of the two sequences to figure out that `batch` is partitioned over `dp` and `heads` is partitioned over `tp`. Whereas with `Float[Array, "batch!dp seqlen heads!tp d_head"]`, the zip has already been performed for me in the source code.
2) Medium (to me): it's less terse. I have to explicitly specify `None` for unsharded axes, and I also have various other tokens adding noise: `PartitionSpec`, `(`, `)`, and several `,` and `'` characters.
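To make point (1) concrete, here is a toy snippet (my own illustration, not jaxtyping code) performing exactly that zip, recovering the suffix-syntax form from the two parallel sequences:

```python
# Toy illustration of the mental "zip" required to read the PartitionSpec form.
dims = "batch seqlen heads d_head".split()
spec = ("dp", None, "tp", None)  # stand-in for PartitionSpec('dp', None, 'tp', None)

annotated = " ".join(
    f"{dim}!{axis}" if axis is not None else dim
    for dim, axis in zip(dims, spec)
)
# annotated == "batch!dp seqlen heads!tp d_head"
```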
> For the shardings, I'm curious what @yashk2810 thinks. (Although I don't know if he checks GitHub comments like this :) ) Given a JAX array inside of `jit`, how might you assert that it matches a particular sharding specification? (Whether given by a `PartitionSpec` or something else.)

> > If jax.inspect_array_sharding always returned NamedSharding, we could do that
>
> `inspect_array_sharding` can return a `NamedSharding` if you have a `with mesh` context manager surrounding your `jit`. `jit` itself will ignore that, but `inspect_array_sharding` will read its value. If that doesn't exist, `PositionalSharding` is returned. `GSPMDSharding` is never returned.
Wonderful! Ok, let's rely on this functionality then.
Thanks @yashk2810 !
Okay, so on balance I think I'm inclined to go with the `Float[Array, "foo", PartitionSpec(...)]` syntax. The rationale is:

- The shape and the parallelism strategy are really two different things.
- It leaves the door open to other strategies, e.g. `Float[Array, "foo", SomeOtherStrategyFromAnotherLibrary(...)]`, without having to worry about compatibility with the shape syntax.
- It makes it possible to write e.g. `replicated = PartitionSpec(None); Float[Array, "foo", replicated]` and programmatically re-use parallelism strategies.

On the points you've raised: regarding terseness, `PartitionSpec` can be imported as just `P`, so I think this should still be pretty concise. When it comes to handling meshes, I suppose we should simply do an `if not isinstance(..., NamedSharding): raise ValueError("No mesh")`.

Does all of the above sound reasonable to you? If so, then I'd be happy to take a pull request implementing this. :)
Thanks, Patrick. I think your reasoning is mostly valid, although I value things substantially differently than you (I don't care about non-JAX support, and I suspect I care much more about actually using this feature than you do :)), which makes me land in a different place than you.

One place where I somewhat disagree with your reasoning:

> The shape and the parallelism strategy are really two different things.

There's a particular (perhaps idiosyncratic to me) way of viewing things where this is not true. By way of example, let `dp` be the amount of data parallelism (i.e. `dp` is an integer), and `tp` be the amount of tensor parallelism (also an integer). Then for a sharding `[batch/dp] seqlen [n_heads/tp] d_head`, the global-view shape of the tensor is `(batch, seqlen, n_heads, d_head)`, but the local-view shape (i.e. the shape seen in just one shard, e.g. via `jax.experimental.shard_map`) is in fact `(batch/dp, seqlen, n_heads/tp, d_head)`. So the sharding is inherently part of the per-shard shape. If you're using `shard_map` to go back and forth between the global view and the per-shard view, then the sharding is inherently part of the shape.

I recognize this may be a view that is somewhat idiosyncratic to users of `shard_map`, though...
I understand you've made your decision and I'm not trying to relitigate it. I think the syntax you've proposed is workable, if not (for me) perfect.

If I want a different syntax in my own codebases (where I am free from the constraints of non-JAX support, where I want to use `shard_map` extensively, and where I don't care about expression syntax like `batch-1` in jaxtyping), I can layer a better-suited-for-me syntax on top by using a simple helper function that lowers to your syntax: e.g. a function `sharded` that takes `sharded(Float32, '[batch/dp] seqlen [n_heads/tp] d_head')` to `Float32[Array, 'batch seqlen n_heads d_head', PartitionSpec(...)]`.
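Such a lowering helper could be sketched roughly like this. (My own hypothetical sketch: a plain tuple stands in for `PartitionSpec`, the dtype/jaxtyping plumbing is elided, and spaces inside brackets are not handled.)

```python
# Hypothetical lowering from "[batch/dp] seqlen [n_heads/tp] d_head"
# to (shape_string, partition_spec_args). A plain tuple stands in for
# jax.sharding.PartitionSpec; assumes no spaces inside the brackets.

def lower_sharded(spec: str):
    dims, axes = [], []
    for token in spec.split():
        if token.startswith("[") and token.endswith("]"):
            dim, axis = token[1:-1].rsplit("/", 1)
            dims.append(dim.strip())
            axes.append(axis.strip())
        else:
            dims.append(token)
            axes.append(None)
    return " ".join(dims), tuple(axes)
```

A real `sharded` helper would then feed these two pieces into the `Float32[Array, shape, PartitionSpec(*axes)]` form decided on above.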
> When it comes to handling meshes, I suppose we should simply do an `if not isinstance(..., NamedSharding): raise ValueError("No mesh")`.
Sounds great.
> Does all of the above sound reasonable to you? If so, then I'd be happy to take a pull request implementing this. :)
Good enough! Happy to take a stab when I get some time. Might take some time...
I love jaxtyping! Can I have more of it please?

Specifically, I'd like to make assertions about the sharding of my `jax.Array` objects. Given an array `Float[Array, "batch seqlen channel"]`, I'd like to assert its sharding with syntax like this: `Float[ShardedArray, "batch/data_parallel seqlen channel/tensor_parallel"]`. This syntax is a commonly used plain-text representation for shardings, following e.g. the notation in Figure 5 of Efficiently Scaling Transformer Inference.

The intention is that the sharding part of this syntax would parse to a sharding spec of `jax.sharding.PartitionSpec('data_parallel', None, 'tensor_parallel')`. We could then assert equivalence of this partition spec against the array's actual sharding using a combination of `jax.debug.inspect_array_sharding` and `jax.sharding.XLACompatibleSharding.is_equivalent_to`.
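That parse could look something like the following sketch (illustrative only; `sharding_spec` is a made-up name, and a plain tuple stands in for the `PartitionSpec` arguments):

```python
# Hypothetical sketch: map "batch/data_parallel seqlen channel/tensor_parallel"
# to the per-dimension arguments of jax.sharding.PartitionSpec,
# with None for unsharded dimensions.

def sharding_spec(annotation: str):
    entries = []
    for token in annotation.split():
        _, _, axis = token.partition("/")
        entries.append(axis or None)
    return tuple(entries)

# sharding_spec("batch/data_parallel seqlen channel/tensor_parallel")
# -> ("data_parallel", None, "tensor_parallel")
```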
There's a small hiccup: to convert a `jax.sharding.PartitionSpec` to a `jax.sharding.NamedSharding`, we need a `jax.sharding.Mesh`, which is non-constant data (it contains JAX "device" objects) that is undesirable to put in a type signature. I think the best user experience would be to put this in a thread-local; perhaps even the one that JAX already uses for (now-superseded) pjit: `jax._src.mesh.thread_resources.env.physical_mesh` (unfortunately, this is private). In that case, the sharding assertion could read the ambient mesh from that thread-local.

Complete colab that tries this out on 8 CPUs, and shows that it works under `jit` too: https://colab.research.google.com/drive/1oLy66BjKOWmh7dFu8aZbo_gBypDtlNeQ?usp=sharing