stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

Support for partitioning/sharded data with Pallas kernels? #72

Open G-Levine opened 4 months ago

G-Levine commented 4 months ago

I'm trying to train a model with a custom linear attention kernel I wrote in Pallas, but I get the following error (only when the input data is sharded across multiple TPU devices):

jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Mosaic kernels cannot be automatically partitioned. Please wrap the call in a shard_map or xmap.

Here's the code that I'm trying to run: https://github.com/G-Levine/levanter/blob/9e78ab17e416d5e471f27255d13888d5fb98e632/src/levanter/models/linear_attention.py

Is there a recommended way to achieve this with Haliax? I tried to find examples of people using the Pallas FlashAttention kernel with Haliax/Levanter, but it appears nobody has tried this yet. It seems like an important use case to support for anyone who wants to train transformer models efficiently on multiple TPUs.

dlwh commented 4 months ago

Hey,

I'm not aware of anybody trying yet. Do you have an example of it working in, e.g., Flax? My guess is that we need to use shard_map as they say, but I don't have much experience with that yet either.

dlwh commented 4 months ago

OK, the best example I see is from MaxText: https://github.com/google/maxtext/blob/10a7c473e9feb1107894e7588b283b1bcfcbd679/MaxText/layers/attentions.py#L213

I think the basic idea is to get a PSpec for each input array (and similarly for the expected output shape), call shard_map(kernel), and then double-check that none of the axes the kernel assumes to be on a single device are sharded.

So, I think what you'll want to do is to call hax.partitioning.pspec_for_axis(a.axes, (mapping)) for every named array argument (and figure something out for non-named args I guess). Then, for each axis that needs to be on a single device (e.g. non-batch axes if there's no communication), raise an error if it's sharded in the pspec.
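A minimal sketch of that "raise if a kernel-local axis is sharded" check in plain JAX. It assumes the logical axis names and the PartitionSpec are already in hand (in Haliax they would come from the names in `a.axes` and from `hax.partitioning.pspec_for_axis`); `check_kernel_local_axes` is a hypothetical helper, not a Haliax API.

```python
from jax.sharding import PartitionSpec as P


def check_kernel_local_axes(axis_names, pspec, kernel_local_axes):
    """Raise if any axis the kernel must see whole is sharded in `pspec`.

    axis_names: logical axis names of the array, in order.
    pspec: the PartitionSpec computed for that array.
    kernel_local_axes: names the kernel assumes live entirely on one device.
    """
    # A PartitionSpec may be shorter than the array rank; missing trailing
    # entries mean "not sharded", so pad with None.
    entries = tuple(pspec) + (None,) * (len(axis_names) - len(pspec))
    for name, entry in zip(axis_names, entries):
        # Any non-None entry (a mesh axis name or a tuple of them) means sharded.
        if name in kernel_local_axes and entry is not None:
            raise ValueError(
                f"axis {name!r} is sharded over mesh axis {entry!r}, "
                "but the kernel needs it on a single device"
            )


# Batch sharded over a "data" mesh axis, everything else local: passes.
check_kernel_local_axes(("batch", "position", "embed"), P("data", None, None), {"position", "embed"})
```

Treating only the non-batch axes as kernel-local matches the "no communication" case described above; a kernel with cross-device communication would need a different list.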

G-Levine commented 4 months ago

Thanks, that helps a lot. I'm able to call the kernel without errors now. However, I'm still trying to figure out how to manipulate the kernel output (a plain Jax array) back into a Haliax named array with the correct sharding. Here's my current code:

def linear_attention(
    query: NamedArray,
    key: NamedArray,
    value: NamedArray,
) -> NamedArray:
    @functools.partial(
        shard_map.shard_map,
        mesh=hax.partitioning._get_mesh(),
        in_specs=(
            hax.partitioning.pspec_for_axis(query.axes),
            hax.partitioning.pspec_for_axis(key.axes),
            hax.partitioning.pspec_for_axis(value.axes),
        ),
        out_specs=hax.partitioning.pspec_for_axis(value.axes),
        check_rep=False,
    )
    def attn_sharded(query, key, value):
        q = query.array
        k = key.array
        v = value.array
        kv_carry = jnp.zeros_like(k)
        k_carry = jnp.zeros_like(k)
        y, _, _ = attn(q, k, v, kv_carry, k_carry)
        named_y = hax.named(y, tuple((axis.name for axis in value.axes)))
        return named_y
    return attn_sharded(query, key, value)

When I try to use the output of this function in the model, I get this error: ValueError: Shape of underlying array (256, 1024, 768) does not match shape of axes (Axis(name='batch', size=64), Axis(name='position', size=1024), Axis(name='embed', size=768)). I assume this means the sharding information was dropped somewhere: the data is sharded across 4 devices, so each device is expected to hold 64 of the 256 batch elements. It's not clear to me how the MaxText example handles the output sharding (it looks like it just returns the kernel's output directly?)

dlwh commented 4 months ago

The issue is that inside the shard map the output is the "local" array, so Haliax infers that batch is 64. Outside the shard map, the raw JAX array is the concatenated/global one (batch 256), but JAX doesn't know about Haliax's named arrays, so the recorded axis sizes don't change. (I should change the way Haliax works to make this easier...)

The easiest thing to do is to return a plain JAX array from attn_sharded and then wrap it with hax.named before returning from linear_attention.
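The local-versus-global distinction can be seen in plain JAX without Haliax. The sketch below assumes a 1-D mesh over all available devices with an axis named "data" (an assumed name) and shards only the leading batch axis; it records the shape the function sees during tracing and compares it with the shape that comes out. With 4 devices and a global batch of 256 this is the same 64 vs. 256 split as in the error above; on a single device both shapes coincide.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older JAX, as used in this thread
    from jax.experimental.shard_map import shard_map

n_devices = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ("data",))
spec = P("data", None, None)  # shard batch only

seen_inside = []


def per_device(x):
    # Traced once; the shape seen here is each device's *local* block.
    seen_inside.append(x.shape)
    return x * 2.0


f = jax.jit(shard_map(per_device, mesh=mesh, in_specs=(spec,), out_specs=spec))

x = jnp.ones((64 * n_devices, 1024, 8))
y = f(x)

print("inside shard_map:", seen_inside[0])  # (64, 1024, 8)
print("outside shard_map:", y.shape)  # (64 * n_devices, 1024, 8)
```

Naming an array from inside `per_device` would therefore bake in the local batch size of 64, which is why the wrapping has to happen after the shard_map call returns.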

dlwh commented 4 months ago

(I'm glad this turned out to be relatively straightforward!)

G-Levine commented 4 months ago

Great, it's all working now! Here's the final code I ended up with.

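import functools

import haliax as hax
import jax.numpy as jnp
from haliax import NamedArray
from jax.experimental import shard_map

# attn: the custom Pallas linear attention kernel from the linked linear_attention.py


def linear_attention(
    query: NamedArray,
    key: NamedArray,
    value: NamedArray,
) -> NamedArray:
    # shard_map runs attn_sharded on each device's local block, using the
    # PartitionSpecs Haliax derives from each array's named axes.
    @functools.partial(
        shard_map.shard_map,
        mesh=hax.partitioning._get_mesh(),
        in_specs=(
            hax.partitioning.pspec_for_axis(query.axes),
            hax.partitioning.pspec_for_axis(key.axes),
            hax.partitioning.pspec_for_axis(value.axes),
        ),
        out_specs=hax.partitioning.pspec_for_axis(value.axes),
        check_rep=False,
    )
    def attn_sharded(query, key, value):
        # Inside shard_map, .array is the per-device (local) block.
        q = query.array
        k = key.array
        v = value.array
        kv_carry = jnp.zeros_like(k)
        k_carry = jnp.zeros_like(k)
        y, _, _ = attn(q, k, v, kv_carry, k_carry)
        # Return a plain JAX array: naming it here would record local axis sizes.
        return y

    # Outside shard_map, y has its global shape again, so it matches value.axes.
    y = attn_sharded(query, key, value)
    return hax.named(y, value.axes)
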
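For anyone who wants to try the pattern without TPUs or Haliax, here is a minimal pure-JAX sketch that runs on CPU. The function `reference_linear_attention` is hypothetical: an unnormalized causal linear attention with no feature map, written in plain `jnp`, standing in for the Pallas kernel. shard_map splits only the batch axis across a 1-D mesh whose axis name "data" is an assumption. In Haliax, the returned array would then be wrapped with hax.named(y, value.axes) as in the code above.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older JAX, as used in this thread
    from jax.experimental.shard_map import shard_map


def reference_linear_attention(q, k, v):
    """Hypothetical kernel stand-in: y_t = q_t^T sum_{s<=t} k_s v_s^T.

    q, k, v have shape [batch, position, dim]; no normalizer, no feature map.
    """
    # Running sum over positions of the outer products k_s v_s^T.
    kv = jnp.cumsum(jnp.einsum("btd,bte->btde", k, v), axis=1)
    return jnp.einsum("btd,btde->bte", q, kv)


mesh = Mesh(np.array(jax.devices()), ("data",))
spec = P("data", None, None)  # shard batch only; position/dim stay local


@jax.jit
@functools.partial(shard_map, mesh=mesh, in_specs=(spec, spec, spec), out_specs=spec)
def sharded_attention(q, k, v):
    # Each device sees only its slice of the batch. That is safe because
    # attention never mixes different batch elements.
    return reference_linear_attention(q, k, v)
```

Splitting only the batch axis is what lets this run without any communication: each device computes full-length attention for its own sequences. Sharding position or dim would need collectives inside the kernel, which is the case the earlier suggestion was to raise an error for.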
dlwh commented 4 months ago

Sweet! I'll leave this open just as a "make it easy for people to do this"/make a tutorial issue.

dlwh commented 4 months ago

also, could you let me know what kind of speedup you get? We can try to prioritize getting it into Levanter if it's nontrivial

G-Levine commented 4 months ago

For the Pallas linear attention kernel? So far my testing shows a very significant speedup across all sequence lengths (the chart shows the runtime for the forward + backward pass).
[Image: attention_runtime]
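A forward + backward measurement like the one above can be sketched as follows. This is a rough harness, not the one used for the chart: it assumes attention functions that take `q, k, v` arrays of shape [batch, position, dim], and both `time_forward_backward` and the baseline `naive_causal_attention` (plain causal softmax attention in `jnp`, not the Pallas kernel or flash attention) are hypothetical names.

```python
import time

import jax
import jax.numpy as jnp


def naive_causal_attention(q, k, v):
    """Hypothetical pure-JAX baseline: causal softmax attention."""
    t = q.shape[1]
    scores = jnp.einsum("btd,bsd->bts", q, k) / jnp.sqrt(q.shape[-1])
    mask = jnp.tril(jnp.ones((t, t), dtype=bool))
    # Large negative value so masked (future) positions get ~zero weight.
    scores = jnp.where(mask, scores, -1e9)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bts,bse->bte", weights, v)


def time_forward_backward(attn_fn, q, k, v, n_iters=10):
    """Rough average seconds per forward + backward pass of attn_fn."""

    def loss(q, k, v):
        # Scalar loss so jax.grad can differentiate through the attention.
        return jnp.sum(attn_fn(q, k, v))

    step = jax.jit(jax.grad(loss, argnums=(0, 1, 2)))
    jax.block_until_ready(step(q, k, v))  # compile and warm up outside the timer
    start = time.perf_counter()
    for _ in range(n_iters):
        grads = step(q, k, v)
    jax.block_until_ready(grads)  # dispatch is async; wait for the device
    return (time.perf_counter() - start) / n_iters


rng = jax.random.PRNGKey(0)
for seq_len in (128, 256, 512):
    q, k, v = jax.random.normal(rng, (3, 2, seq_len, 64))
    print(seq_len, time_forward_backward(naive_causal_attention, q, k, v))
```

Excluding the first call keeps compilation time out of the average, and blocking on the final result matters because JAX returns before the device has finished.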

dlwh commented 4 months ago

That's nice! I actually just meant Pallas flash attention vs pure JAX attention on TPU