patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Support typechecking of jax.sharding.NamedSharding #164

Open reinerp opened 10 months ago

reinerp commented 10 months ago

I love jaxtyping! Can I have more of it please?

Specifically, I'd like to make assertions about the sharding of my jax.Array objects. Given an array Float[Array, "batch seqlen channel"] I'd like to assert its sharding with syntax like this: Float[ShardedArray, "batch/data_parallel seqlen channel/tensor_parallel"]. This syntax is a commonly used plain-text representation for shardings, following e.g. the notation in Figure 5 of Efficiently Scaling Transformer Inference.

The intention is that the sharding part of this syntax would parse to a sharding spec of jax.sharding.PartitionSpec('data_parallel', None, 'tensor_parallel'). We could then assert equivalence of this partition spec against the array's actual sharding using a combination of jax.debug.inspect_array_sharding and jax.sharding.XLACompatibleSharding.is_equivalent_to.
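
A minimal sketch of that parse (parse_sharded_dims is a hypothetical name, not anything that exists in jaxtyping):

import jax

def parse_sharded_dims(spec: str) -> jax.sharding.PartitionSpec:
    # "batch/data_parallel seqlen channel/tensor_parallel"
    #   -> PartitionSpec('data_parallel', None, 'tensor_parallel')
    axes = []
    for dim in spec.split():
        if "/" in dim:
            _, mesh_axis = dim.split("/", 1)
            axes.append(mesh_axis)
        else:
            axes.append(None)  # unannotated dimensions are unsharded
    return jax.sharding.PartitionSpec(*axes)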

There's a small hiccup: to convert a jax.sharding.PartitionSpec to a jax.sharding.NamedSharding, we need a jax.sharding.Mesh, which is non-constant data (it contains JAX device objects) and is undesirable to put in a type signature. I think the best user experience would be to put this in a thread-local; perhaps even the one that JAX already uses for (now-superseded) pjit: jax._src.mesh.thread_resources.env.physical_mesh (unfortunately, this is private). In that case, the sharding assertion could look like this:

import functools

import jax
import jax._src.mesh as mesh_private  # private API: the thread-local mesh that `with mesh:` populates

def _assert_sharding_cb(ndim: int, expected: jax.sharding.XLACompatibleSharding, actual: jax.sharding.XLACompatibleSharding):
    # Called by inspect_array_sharding with the array's actual sharding.
    if not expected.is_equivalent_to(actual, ndim):
        raise ValueError(f'got sharding {actual}, but expected {expected}')

def assert_sharding(v: jax.Array, expected: jax.sharding.PartitionSpec):
    # Read the ambient mesh from the thread-local, then compare shardings.
    mesh = mesh_private.thread_resources.env.physical_mesh
    expected_sharding = jax.sharding.NamedSharding(mesh, expected)
    jax.debug.inspect_array_sharding(v, callback=functools.partial(_assert_sharding_cb, v.ndim, expected_sharding))
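
For concreteness, usage of the above might look like this (a hypothetical example; assumes 8 devices, e.g. via XLA_FLAGS=--xla_force_host_platform_device_count=8 on CPU, matching the colab below):

import numpy as np
import jax.numpy as jnp

mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(4, 2), ('data_parallel', 'tensor_parallel'))
spec = jax.sharding.PartitionSpec('data_parallel', None, 'tensor_parallel')

with mesh:  # populates the thread-local mesh that assert_sharding reads
    x = jax.device_put(jnp.zeros((8, 128, 512)), jax.sharding.NamedSharding(mesh, spec))
    assert_sharding(x, spec)  # passes
    assert_sharding(x, jax.sharding.PartitionSpec(None, None, 'tensor_parallel'))  # raises ValueError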

Complete colab that tries this out on 8 CPUs, and shows that it works under jit too: https://colab.research.google.com/drive/1oLy66BjKOWmh7dFu8aZbo_gBypDtlNeQ?usp=sharing.

patrick-kidger commented 10 months ago

Haha, thank you!

I like this idea, and I think the approach of using jax.debug.inspect_array_sharding is probably the correct one.

Regarding the syntax: I don't think we can use /, as that will conflict with its existing use for division in symbolic dimensions (where we evaluate the string as a Python expression). I'm not sure what else would be best; something that isn't normal Python syntax. Maybe some kind of prefix, like we have for the other kinds of special modifiers.

Regarding meshes: morally speaking, I don't think we should actually need a mesh? That is, given a sharding during lowering, can we get a PartitionSpec-like object from it, and then compare those?

reinerp commented 10 months ago

> Haha, thank you!
>
> I like this idea, and I think the approach of using jax.debug.inspect_array_sharding is probably the correct one.
>
> Regarding the syntax: I don't think we can use /, as that will conflict with its existing use for division in symbolic dimensions (where we evaluate the string as a Python expression). I'm not sure what else would be best; something that isn't normal Python syntax. Maybe some kind of prefix, like we have for the other kinds of special modifiers.

That's a shame! It's not obvious to me how a prefix would work, because we still need a separator between the dimension name and the sharding; e.g. a prefix of ! alone isn't enough, since !batch/data_parallel still uses /.

Some ideas:

On net, I'd probably go for batch!data_parallel.

> Regarding meshes: morally speaking, I don't think we should actually need a mesh? That is, given a sharding during lowering, can we get a PartitionSpec-like object from it, and then compare those?

If jax.debug.inspect_array_sharding always returned NamedSharding, we could do that. Outside of jit, it seems (in my experimentation) to always do so. However, inside jit it sometimes produces PositionalSharding or GSPMDSharding. I suspect this is a consequence of the sharding being determined by XLA/GSPMD's sharding propagation (which doesn't know about names) rather than by JAX's sharding propagation (which does).

Thus, since we can't reliably recover names, my approach was to go in the other direction: reliably map names to XLACompatibleSharding.

patrick-kidger commented 10 months ago

I like your syntax suggestions! My only suggestion would be to switch them around: data_parallel!batch. I know this isn't the convention in the literature, but so far the convention in jaxtyping has been to put all kinds of modifiers into the prefix. That's not that strong a feeling on my part though, as all the current prefixes are pretty readable, whereas something like data_parallel!batch puts the dimension name much later.

As an alternative strategy, we could consider something like explicitly taking Float[Array, "foo bar", PartitionSpec(...)]? In particular it's nice that we can continue to use the existing PartitionSpec object, without trying to find a way to cram it into our own string DSL.


For the shardings, I'm curious what @yashk2810 thinks. (Although I don't know if he checks GitHub comments like this :) ) Given a JAX array inside of JIT, how might you assert that it matches a particular sharding specification? (Whether given by a PartitionSpec or something else.)

yashk2810 commented 10 months ago

> If jax.debug.inspect_array_sharding always returned NamedSharding, we could do that

inspect_array_sharding can return a NamedSharding if you have a with mesh context manager surrounding your jit. jit itself will ignore it, but inspect_array_sharding will read its value. If that doesn't exist, a PositionalSharding is returned. GSPMDSharding is never returned.
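
A small sketch of the pattern being described (mesh shape and axis names are illustrative; assumes 8 devices):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('dp', 'tp'))

@jax.jit
def f(x):
    jax.debug.inspect_array_sharding(x, callback=print)
    return x * 2

x = jax.device_put(jnp.zeros((8, 16)), NamedSharding(mesh, P('dp', 'tp')))
with mesh:  # the surrounding mesh context is what lets the callback see a NamedSharding
    f(x)
# Without the `with mesh:` block, the callback would be handed a PositionalSharding instead.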

reinerp commented 9 months ago

Given the discussion of PartitionSpec, I realized that PartitionSpec is a little more flexible than we have discussed so far. Quoting the docs: a PartitionSpec is a tuple, whose elements can be a None, a mesh axis, or a tuple of mesh axes.

The "tuple of mesh axes" part is not supported in the syntax discussion above. (It's sometimes useful to, e.g., express sharding simultaneously over the x and y axes of a TPU mesh or to express sharding of a reshaped tensor.)

The natural extension to the syntaxes above would be to support tuples via comma, e.g. foo!x,y for the tuple ('x', 'y'). This syntax doesn't allow you to distinguish 1-tuples from non-tuple values, but these also have no semantic difference in PartitionSpec, so I'd say that's desirable.
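
A sketch of how the suffix syntax might parse under these rules (parse_dim and parse_spec are hypothetical names; 1-tuples deliberately collapse to a bare axis name):

from jax.sharding import PartitionSpec

def parse_dim(dim: str):
    # "foo!x,y" -> ('x', 'y');  "foo!x" -> 'x';  "foo" -> None
    if "!" not in dim:
        return None
    _, sharding = dim.split("!", 1)
    axes = tuple(sharding.split(","))
    return axes if len(axes) > 1 else axes[0]

def parse_spec(spec: str) -> PartitionSpec:
    return PartitionSpec(*(parse_dim(dim) for dim in spec.split()))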

> I like your syntax suggestions! My only suggestion would be to switch them around: data_parallel!batch. I know this isn't the convention in the literature, but so far the convention in jaxtyping has been to put all kinds of modifiers into the prefix. That's not that strong a feeling on my part though, as all the current prefixes are pretty readable, whereas something like data_parallel!batch puts the dimension name much later.

Take a look at which you think is more readable:

1) Prefix syntax: Float[Array, "dp!batch seqlen tp!heads d_head"] or Float[Array, "x,y!batch seqlen z!heads d_head"]
2) Suffix syntax: Float[Array, "batch!dp seqlen heads!tp d_head"] or Float[Array, "batch!x,y seqlen heads!z d_head"]

I mildly prefer (2), because putting the axis name first (before the sharding) seems to put slightly more emphasis on the axis name. Also, I think batch!x,y looks more natural than x,y!batch. But I'm fine with either.

> As an alternative strategy, we could consider something like explicitly taking Float[Array, "foo bar", PartitionSpec(...)]? In particular it's nice that we can continue to use the existing PartitionSpec object, without trying to find a way to cram it into our own string DSL.

I can see the appeal of reusing the existing type!

The main disadvantages I see are:

1) Severe (to me): it's not as readable. Given Float[Array, "batch seqlen heads d_head", PartitionSpec('dp', None, 'tp', None)] I have to mentally perform a zip of the two arrays to figure out that batch is partitioned over dp and heads is partitioned over tp. Whereas with Float[Array, "batch!dp seqlen heads!tp d_head"] the zip has already been performed for me in the source code.
2) Medium (to me): it's less terse: I have to explicitly specify None for unsharded axes, and I also have various other tokens adding noise: PartitionSpec, (, ), and several , and ' characters.

> For the shardings, I'm curious what @yashk2810 thinks. (Although I don't know if he checks GitHub comments like this :) ) Given a JAX array inside of JIT, how might you assert that it matches a particular sharding specification? (Whether given by a PartitionSpec or something else.)

>> If jax.debug.inspect_array_sharding always returned NamedSharding, we could do that

> inspect_array_sharding can return a NamedSharding if you have a with mesh context manager surrounding your jit. jit itself will ignore it, but inspect_array_sharding will read its value. If that doesn't exist, a PositionalSharding is returned. GSPMDSharding is never returned.

Wonderful! Ok, let's rely on this functionality then.

patrick-kidger commented 9 months ago

Thanks @yashk2810 !

Okay so on balance, I think I'm inclined to go with the Float[Array, "foo", PartitionSpec(...)] syntax. The rationale is:

On the points you've raised:

When it comes to handling meshes, I suppose we should simply do an if not isinstance(..., NamedSharding): raise ValueError("No mesh").
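
Something along these lines, presumably (a sketch only; _check_sharding is a hypothetical name, and reading the mesh off actual.mesh sidesteps the thread-local entirely):

import jax

def _check_sharding(expected_spec: jax.sharding.PartitionSpec, ndim: int, actual):
    # `actual` is what inspect_array_sharding hands to the callback.
    if not isinstance(actual, jax.sharding.NamedSharding):
        raise ValueError("No mesh")
    expected = jax.sharding.NamedSharding(actual.mesh, expected_spec)
    if not expected.is_equivalent_to(actual, ndim):
        raise ValueError(f"got sharding {actual.spec}, expected {expected_spec}")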

Does all of the above sound reasonable to you? If so, then I'd be happy to take a pull request implementing this. :)

reinerp commented 9 months ago

Thanks Patrick. I think your reasoning is mostly valid, although I value things substantially differently than you do (I don't care about non-JAX support, and I suspect I care much more about actually using this feature than you do :)), which makes me land in a different place than you.

One place where I somewhat disagree with your reasoning:

> The shape and the parallelism strategy are really two different things.

There's a particular (perhaps idiosyncratic to me) way of viewing things where this is not true. By way of example, let dp be the amount of data parallelism (i.e. dp is an integer), and tp the amount of tensor parallelism (also an integer). Then for a sharding [batch/dp] seqlen [n_heads/tp] d_head, the global-view shape of the tensor is (batch, seqlen, n_heads, d_head), but the local-view shape (i.e. the shape seen in just one shard, e.g. via jax.experimental.shard_map) is in fact (batch/dp, seqlen, n_heads/tp, d_head). So if you're using shard_map to go back and forth between the global view and the per-shard view, the sharding is inherently part of the shape.
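
A small illustration of that per-shard view (assumes 8 devices; shapes and axis names are made up):

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('dp', 'tp'))
spec = P('dp', None, 'tp', None)

def body(x):
    # Global shape is (8, 16, 4, 32); each shard sees (8/dp, 16, 4/tp, 32).
    assert x.shape == (2, 16, 2, 32)
    return x

f = shard_map(body, mesh=mesh, in_specs=spec, out_specs=spec)
f(jnp.zeros((8, 16, 4, 32)))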

I recognize this may be a view that is somewhat idiosyncratic to users of shard_map though...


I understand you've made your decision and I'm not trying to relitigate it. I think the syntax you've proposed is workable if not (for me) perfect.

If I want a different syntax in my own codebases (where I am free from the constraints of non-JAX support, where I want to use shard_map extensively, and where I don't care about expression syntax like batch-1 in jaxtyping), I can layer a better-suited-for-me syntax on top with a small helper function that lowers to your syntax: e.g. a function sharded that maps sharded(Float32, '[batch/dp] seqlen [n_heads/tp] d_head') to Float32[Array, 'batch seqlen n_heads d_head', PartitionSpec(...)].
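
Such a helper might look like this (a sketch only: sharded and the bracket syntax are hypothetical, it assumes the three-argument Float[Array, ..., PartitionSpec(...)] form discussed above were implemented, and tuples of mesh axes are not handled):

import re
from jax.sharding import PartitionSpec
from jaxtyping import Array

def sharded(dtype, spec: str):
    # '[batch/dp] seqlen [n_heads/tp] d_head'
    #   -> dtype[Array, 'batch seqlen n_heads d_head', PartitionSpec('dp', None, 'tp', None)]
    dims, axes = [], []
    for token in spec.split():
        m = re.fullmatch(r"\[(\w+)/(\w+)\]", token)
        if m:
            dims.append(m.group(1))
            axes.append(m.group(2))
        else:
            dims.append(token)
            axes.append(None)
    return dtype[Array, " ".join(dims), PartitionSpec(*axes)]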

> When it comes to handling meshes, I suppose we should simply do an if not isinstance(..., NamedSharding): raise ValueError("No mesh").

Sounds great.

> Does all of the above sound reasonable to you? If so, then I'd be happy to take a pull request implementing this. :)

Good enough! Happy to take a stab when I get some time. Might take some time...