aws-neuron / aws-neuron-sdk

Powering AWS purpose-built machine learning chips. Blazing fast and cost effective, natively integrated into PyTorch and TensorFlow and integrated with your favorite AWS services
https://aws.amazon.com/machine-learning/neuron/

Error in attention code with various sharding configurations #1031

Open Corendos opened 1 week ago

Corendos commented 1 week ago

Summary

Hi again! I ran into a bug while experimenting with attention and sharding in JAX. The script fails for some combinations of sharding setup and core count.

Steps to Reproduce

The following code snippet can reproduce the issue:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

device_count = len(jax.devices())

# Define meshes for sharding
mesh = jax.sharding.Mesh(jax.devices(), ('x',))  # ('x',) is a one-element tuple; ('x') is just the string 'x'
mesh2 = jax.sharding.Mesh(np.array(jax.devices()).reshape((device_count // 2, 2)), ('x', 'y'))

# Model and input parameters
BATCH_SIZE = 1
SEQ_LEN = 256
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 256
DIM = HEAD_DIM * NUM_HEADS
KV_DIM = HEAD_DIM * NUM_KV_HEADS
SIZE = 5000
DTYPE = jnp.bfloat16

# Sharding configuration
SHARDING_TYPE = "double"

if SHARDING_TYPE == "simple":
    input_sharding = jax.sharding.NamedSharding(mesh, P(None, 'x'))
    output_sharding = jax.sharding.NamedSharding(mesh, P('x', None))
elif SHARDING_TYPE == "double":
    input_sharding = jax.sharding.NamedSharding(mesh2, P('y', 'x'))
    output_sharding = jax.sharding.NamedSharding(mesh2, P('x', 'y'))
else:
    input_sharding = None
    output_sharding = None

# Attention function
def attention(x: jax.Array, q_proj: jax.Array, k_proj: jax.Array, v_proj: jax.Array, o_proj: jax.Array) -> jax.Array:
    q = jax.lax.dot_general(x, q_proj, (([2], [0]), ([], []))).reshape(BATCH_SIZE, SEQ_LEN, NUM_HEADS, HEAD_DIM)
    k = jax.lax.dot_general(x, k_proj, (([2], [0]), ([], []))).reshape(BATCH_SIZE, SEQ_LEN, NUM_KV_HEADS, HEAD_DIM)
    v = jax.lax.dot_general(x, v_proj, (([2], [0]), ([], []))).reshape(BATCH_SIZE, SEQ_LEN, NUM_KV_HEADS, HEAD_DIM)

    k = jnp.repeat(k, (NUM_HEADS // NUM_KV_HEADS), 2)
    v = jnp.repeat(v, (NUM_HEADS // NUM_KV_HEADS), 2)

    attn = jax.nn.dot_product_attention(q, k, v, is_causal=False).reshape(BATCH_SIZE, SEQ_LEN, DIM)
    out = jax.lax.dot_general(attn, o_proj, (([2], [0]), ([], [])))
    return out

# Test inputs
x = jnp.ones((BATCH_SIZE, SEQ_LEN, SIZE), DTYPE)
q_proj = jnp.ones((SIZE, DIM), DTYPE, device=input_sharding)
k_proj = jnp.ones((SIZE, KV_DIM), DTYPE, device=input_sharding)
v_proj = jnp.ones((SIZE, KV_DIM), DTYPE, device=input_sharding)
o_proj = jnp.ones((DIM, SIZE), DTYPE, device=output_sharding)

lowered = jax.jit(attention).lower(x, q_proj, k_proj, v_proj, o_proj)
compiled = lowered.compile()

print(f"lowered: {lowered.as_text()}")
print(f"compiled: {compiled.as_text()}")
print(compiled(x, q_proj, k_proj, v_proj, o_proj))

Observed Behavior

  1. When SHARDING_TYPE is set to "simple"

    • Run with NEURON_RT_NUM_CORES=16 python xxx.py
    • Fails with this error message:
  2. When SHARDING_TYPE is set to "double"

    • Run with NEURON_RT_NUM_CORES=16 python xxx.py
    • Also fails with this error message:
  3. When SHARDING_TYPE is "none" (sharding disabled)

    • Run with NEURON_RT_NUM_CORES=16 python xxx.py
    • The code works without errors.

Additional Observations

When I increased the number of heads to:

NUM_HEADS = 48
NUM_KV_HEADS = 12

and used NEURON_RT_NUM_CORES=24, the script worked even with sharding enabled. Interestingly, with these higher head counts, setting NEURON_RT_NUM_CORES=16 made the "simple" case work (but not the "double").
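To make the configurations above easier to compare, here is a minimal sketch that computes the per-device shard shapes implied by each PartitionSpec in the script. The helpers `shard_shape` and `report` are hypothetical names introduced for illustration, and the sketch only does the shape arithmetic; it does not touch JAX or the Neuron runtime.

```python
HEAD_DIM, SIZE = 256, 5000  # values from the reproduction script


def shard_shape(shape, spec, mesh_axes):
    """Per-device shape of an array partitioned by `spec` over `mesh_axes`."""
    local = []
    for dim, axis in zip(shape, spec):
        parts = 1 if axis is None else mesh_axes[axis]
        if dim % parts:
            raise ValueError(f"dimension {dim} is not divisible by {parts}")
        local.append(dim // parts)
    return tuple(local)


def report(num_heads, num_kv_heads, cores):
    """Shard shapes of the projection weights for both sharding types."""
    dim, kv_dim = HEAD_DIM * num_heads, HEAD_DIM * num_kv_heads
    simple = {'x': cores}                  # 1D mesh, as in `mesh`
    double = {'x': cores // 2, 'y': 2}     # 2D mesh, as in `mesh2`
    return {
        'simple': {
            'q_proj': shard_shape((SIZE, dim), (None, 'x'), simple),
            'k_proj': shard_shape((SIZE, kv_dim), (None, 'x'), simple),
            'o_proj': shard_shape((dim, SIZE), ('x', None), simple),
        },
        'double': {
            'q_proj': shard_shape((SIZE, dim), ('y', 'x'), double),
            'k_proj': shard_shape((SIZE, kv_dim), ('y', 'x'), double),
            'o_proj': shard_shape((dim, SIZE), ('x', 'y'), double),
        },
    }


for heads, kv_heads, cores in [(32, 8, 16), (48, 12, 24), (48, 12, 16)]:
    print(heads, kv_heads, cores, report(heads, kv_heads, cores))
```

In the "simple" case, both the failing setup (8 KV heads, 16 cores) and the working one (12 KV heads, 24 cores) give each device 128 columns of k_proj, which is half of one 256-wide KV head. The shard shapes alone therefore do not seem to separate the working runs from the failing ones.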

Expected Behavior

The code should work consistently across different sharding configurations and core settings.

Environment

Additional Information

Please let me know if further details are needed for reproducing or debugging the issue. Thank you!

devesr-amzn commented 1 week ago

Thank you for reporting the issue. We are able to reproduce it with the provided sample. We currently do not support all sharding configurations: meshes that use non-connected devices might result in runtime failures during execution. The topology for Inferentia instances can be seen here, or by running neuron-ls --topology.
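To see which devices end up communicating along each axis of the 2D mesh from the report, the following sketch lists the device groups per mesh axis. It assumes that jax.devices() returns the 16 devices in id order 0..15, which is not confirmed in this thread, and the helper `axis_groups` is a hypothetical name. Whether the devices in a group are actually connected has to be checked against the instance topology.

```python
import numpy as np


def axis_groups(device_ids, axis):
    """Device ids that talk to each other in collectives along one mesh axis."""
    ids = np.asarray(device_ids)
    # Move the chosen axis last so that each row is one communication group.
    moved = np.moveaxis(ids, axis, -1)
    return [[int(i) for i in row] for row in moved.reshape(-1, ids.shape[axis])]


n = 16  # NEURON_RT_NUM_CORES in the report
# Same layout as mesh2 in the report: shape (n // 2, 2), axes ('x', 'y').
mesh2_ids = np.arange(n).reshape(n // 2, 2)

x_groups = axis_groups(mesh2_ids, 0)  # collectives over 'x'
y_groups = axis_groups(mesh2_ids, 1)  # collectives over 'y'

print("x:", x_groups)
print("y:", y_groups)
```

Under that ordering assumption, the 'y' groups pair neighbouring ids (0 with 1, 2 with 3, and so on), while each 'x' group spans every second device. The groups can then be compared with the output of neuron-ls --topology.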

Corendos commented 1 week ago

Thank you for your answer!

So, now that we know for sure that it's not supported, I have some follow-up questions:

devesr-amzn commented 6 days ago

> Is there a known workaround for that? Because I fail to see what can be done. To me, sharding is quite opaque from the StableHLO perspective and the collective operations that are produced in HLO are generic.

It can be handled by using shard_map together with the jax.lax collective APIs.

> Is this something that will be supported in the future?

We will look at fixing collectives for GSPMD support, for supported mesh configurations.
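The shard_map suggestion above is not spelled out in the thread, so here is a minimal sketch of what it could look like for the output projection in the "simple" case. The contraction dimension is split over 'x' and the partial results are combined with an explicit jax.lax.psum. The shapes are small illustrative values rather than the ones from the report, `local_out_proj` is a hypothetical name, and whether this avoids the failure on Neuron devices has not been verified here.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases export shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it under jax.experimental
    from jax.experimental.shard_map import shard_map

devices = np.array(jax.devices())
mesh = Mesh(devices, ('x',))
n = devices.size

# Illustrative shapes (assumption); DIM is a multiple of the device count.
B, S, DIM, SIZE = 1, 4, 8 * n, 16


def local_out_proj(attn_shard, o_proj_shard):
    # Each device multiplies its own slice of the contraction dimension.
    partial = jnp.einsum('bsd,df->bsf', attn_shard, o_proj_shard)
    # The collective is written out explicitly instead of being inferred.
    return jax.lax.psum(partial, axis_name='x')


out_proj = shard_map(
    local_out_proj,
    mesh=mesh,
    in_specs=(P(None, None, 'x'), P('x', None)),  # attn, o_proj
    out_specs=P(),  # result is replicated after the psum
)

attn = jnp.ones((B, S, DIM), jnp.float32)
o_proj = jnp.ones((DIM, SIZE), jnp.float32)
out = jax.jit(out_proj)(attn, o_proj)
print(out.shape)
```

With this style, the function body sees per-device shards and every cross-device step is a named jax.lax collective, so the communication pattern is chosen in the program rather than derived from sharding annotations.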