aws-neuron / aws-neuron-sdk

Powering AWS purpose-built machine learning chips. Blazing fast and cost effective, natively integrated into PyTorch and TensorFlow and integrated with your favorite AWS services
https://aws.amazon.com/machine-learning/neuron/

Compiler crash when sharding model weights #1030

Open · Corendos opened this issue 1 week ago

Corendos commented 1 week ago

Hi!

I was recently experimenting with JAX on Neuron and ran into a rather annoying bug.

When sharding a very simple MLP layer, compilation succeeds or fails depending on which axis the weights are sharded along.

Here is a snippet that reproduces the issue:

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

mesh = jax.sharding.Mesh(jax.devices(), ('x',))

BATCH_SIZE = 1
SIZE = 4096
HIDDEN_SIZE = 8192
DTYPE = jnp.bfloat16
SHARD_ON_HIDDEN = True

def mlp(x: jax.Array, gate_up_proj: jax.Array, down_proj: jax.Array) -> jax.Array:
    # x: (B, D)
    # gate_up_proj: (D, 2 * H)
    # down_proj: (D, H), contracted with hidden over its H axis
    hidden = jax.lax.dot_general(x, gate_up_proj, (([1], [0]), ([],[])))
    # hidden: (B, 2 * H)
    x1, x2 = jnp.split(hidden, 2, 1)
    hidden = jax.nn.gelu(x1) * x2
    # hidden: (B, H)
    hidden = jax.lax.dot_general(hidden, down_proj, (([1], [1]), ([],[])))
    # output: (B, D)
    return hidden

if SHARD_ON_HIDDEN:
    weight_sharding = jax.sharding.NamedSharding(mesh, P(None, 'x'))
else:
    weight_sharding = jax.sharding.NamedSharding(mesh, P('x', None))

x = jax.ShapeDtypeStruct(shape=(BATCH_SIZE, SIZE), dtype=DTYPE, sharding=jax.sharding.NamedSharding(mesh, P(None, None)))
gate_up_proj = jax.ShapeDtypeStruct((SIZE, 2 * HIDDEN_SIZE), dtype=DTYPE, sharding=weight_sharding)
down_proj = jax.ShapeDtypeStruct((SIZE, HIDDEN_SIZE), dtype=DTYPE, sharding=weight_sharding)

lowered = jax.jit(mlp).lower(x, gate_up_proj, down_proj)
print(lowered.as_text())

compiled = lowered.compile()
print(compiled.as_text())
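For checking results outside Neuron, here is a small NumPy sketch of the same layer, plus a simulation of what hidden-axis sharding computes per core. It is an illustration under stated assumptions, not the JAX or Neuron implementation: the helper names (gelu_tanh, mlp_reference, mlp_hidden_sharded) and the reduced sizes are hypothetical, and the simulation assumes each core holds the same H-slice of the gate half, the up half and down_proj. Under that assumption gelu(x1) * x2 stays local to each core and the only cross-core step is one final sum, which is the usual reason sharding on the hidden axis is preferred. A plain P(None, 'x') on the concatenated (D, 2 * H) matrix, as in the snippet above, does not give that per-core alignment.

```python
import numpy as np


def gelu_tanh(x: np.ndarray) -> np.ndarray:
    # tanh approximation of GELU, the form jax.nn.gelu uses by default (approximate=True)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def mlp_reference(x, gate_up_proj, down_proj):
    # x: (B, D), gate_up_proj: (D, 2 * H), down_proj: (D, H)
    hidden = x @ gate_up_proj             # (B, 2 * H)
    x1, x2 = np.split(hidden, 2, axis=1)  # two (B, H) halves
    hidden = gelu_tanh(x1) * x2           # (B, H)
    return hidden @ down_proj.T           # contract over H -> (B, D)


def mlp_hidden_sharded(x, gate_up_proj, down_proj, num_cores):
    # Hypothetical per-core layout: core c owns the same H-slice of the gate half,
    # the up half and down_proj, so gelu(x1) * x2 never needs another core's data.
    h_size = down_proj.shape[1]
    gate, up = np.split(gate_up_proj, 2, axis=1)
    partials = []
    for c in range(num_cores):
        s = slice(c * h_size // num_cores, (c + 1) * h_size // num_cores)
        h = gelu_tanh(x @ gate[:, s]) * (x @ up[:, s])  # (B, H / num_cores), core-local
        partials.append(h @ down_proj[:, s].T)          # partial (B, D) result
    return sum(partials)                                 # the single cross-core reduction


B, D, H = 1, 64, 128  # reduced sizes for a quick CPU check
rng = np.random.default_rng(0)
x = rng.standard_normal((B, D))
gate_up = rng.standard_normal((D, 2 * H))
down = rng.standard_normal((D, H))

out_full = mlp_reference(x, gate_up, down)
out_sharded = mlp_hidden_sharded(x, gate_up, down, num_cores=4)
print(out_full.shape, np.allclose(out_full, out_sharded))  # (1, 64) True
```

The two results agree up to floating-point rounding because the contraction over H in the last matmul splits cleanly into per-slice partial sums.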

With SHARD_ON_HIDDEN set to False, everything works fine. With it set to True (the layout you want for optimal performance), it crashes as soon as more than 2 NeuronCores are used, probably because the program then needs transfers between cores over the interconnect.
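To make the transfer argument concrete, this pure-Python sketch computes which core holds each operand of gelu(x1) * x2 when the concatenated (D, 2 * H) gate_up_proj is sharded with P(None, 'x'). It assumes the 2 * H axis is split into equal, contiguous shards in device order, which is, as far as we understand, how a NamedSharding over a 1-D mesh divides an evenly divisible axis; the helper names are hypothetical. Once more than one core is used, column j of x1 and column j + H of x2 always land on different cores, so the elementwise product needs data to move between them. Note that the 2-core case also needs a transfer, so this alone does not explain why only the 4-core run fails.

```python
def column_owner(col: int, total_cols: int, num_cores: int) -> int:
    # Core holding `col` when total_cols columns are split into equal, contiguous shards.
    return col // (total_cols // num_cores)


def gelu_mul_core_pairs(hidden_size: int, num_cores: int) -> list[tuple[int, int]]:
    # gelu(x1) * x2 combines column j (from x1) with column j + H (from x2).
    total = 2 * hidden_size
    pairs = {
        (column_owner(j, total, num_cores), column_owner(j + hidden_size, total, num_cores))
        for j in range(hidden_size)
    }
    return sorted(pairs)


HIDDEN_SIZE = 8192
for n in (1, 2, 4):
    print(n, gelu_mul_core_pairs(HIDDEN_SIZE, n))
# 1 [(0, 0)]
# 2 [(0, 1)]
# 4 [(0, 2), (1, 3)]
```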

Here is the error:

2024-Nov-09 16:14:06.605467 66954:67403 ERROR   ENC:enc_parse_replica_groups                [nec_dev 2] replica groups (0/1) does not have myself 2
2024-Nov-09 16:14:06.605522 66954:67403 ERROR  TDRV:build_enc_source_target_pairs           Failed to parse src_target_pairs on ENC
2024-Nov-09 16:14:06.605546 66954:67403 ERROR  TDRV:instr_col_translate_ptc2                Failed to setup pseudo trigger collectives2 instruction
2024-Nov-09 16:14:06.605670 66954:67403 ERROR  TDRV:ib_create_one_block                     failed to translate instructions
2024-Nov-09 16:14:06.605695 66954:67403 ERROR  TDRV:ib_create_eib_v2                        Failed to create instruction blocks
2024-Nov-09 16:14:06.605715 66954:67403 ERROR  TDRV:sequencer_v2_setup_instr_one_eng        Failed to allocate eib
2024-Nov-09 16:14:06.605735 66954:67403 ERROR  TDRV:kbl_model_add                           create_engine_refill_rings_v1() error
2024-Nov-09 16:14:06.605981 66954:67403 ERROR  NMGR:dlr_kelf_stage                          Failed to load subgraph
2024-Nov-09 16:14:06.606000 66954:67403 ERROR  NMGR:kmgr_load_nn_internal_v2                Failed to stage graph: kelf-0.json to NeuronCore
2024-Nov-09 16:14:06.606012 66954:67403 ERROR  NMGR:kmgr_load_nn_post_metrics               Failed to load NN: /tmp/tmpsyndo_96/file.neff, err: 4
2024-Nov-09 16:14:06.606025 66954:67403 ERROR   NRT:nrt_infodump                            Neuron runtime information - please include in any support request:
2024-Nov-09 16:14:06.606040 66954:67403 ERROR   NRT:nrt_infodump                            ------------->8------------[ cut here ]------------>8-------------
2024-Nov-09 16:14:06.606062 66954:67403 ERROR   NRT:nrt_infodump                            NRT version: 2.22.14.0 (6e27b8d5b22dea0e0b8375517f4d8a009b6de5a8)
2024-Nov-09 16:14:06.606084 66954:67403 ERROR   NRT:nrt_infodump                            Embedded FW version: 1.12.2.0 (f152b70c827a52701d6b9ee74ec7ff7a15971f7d)
2024-Nov-09 16:14:06.606112 66954:67403 ERROR   NRT:nrt_infodump                            CCOM version: 2.22.26.0- (compat 48)
2024-Nov-09 16:14:06.606134 66954:67403 ERROR   NRT:nrt_infodump                            Instance ID: i-0b19e4a1cf3fd70d9
2024-Nov-09 16:14:06.606156 66954:67403 ERROR   NRT:nrt_infodump                            Cluster ID: N/A
2024-Nov-09 16:14:06.606178 66954:67403 ERROR   NRT:nrt_infodump                            Kernel: Linux 6.8.0-1015-aws #16~22.04.1-Ubuntu SMP Mon Aug 19 19:38:17 UTC 2024
2024-Nov-09 16:14:06.606200 66954:67403 ERROR   NRT:nrt_infodump                            Nodename: ip-172-31-42-39
2024-Nov-09 16:14:06.606254 66954:67403 ERROR   NRT:nrt_infodump                            Driver version: 2.18.12.0
2024-Nov-09 16:14:06.606276 66954:67403 ERROR   NRT:nrt_infodump                            Failure: NRT_RESOURCE in nrt_load()
2024-Nov-09 16:14:06.606298 66954:67403 ERROR   NRT:nrt_infodump                            Visible cores: 0, 1, 2, 3
2024-Nov-09 16:14:06.606318 66954:67403 ERROR   NRT:nrt_infodump                            Environment:
2024-Nov-09 16:14:06.606341 66954:67403 ERROR   NRT:nrt_infodump                                NEURON_CC_FLAGS=--model-type=transformer --auto-cast=none
2024-Nov-09 16:14:06.606362 66954:67403 ERROR   NRT:nrt_infodump                                NEURON_RT_NUM_CORES=4
2024-Nov-09 16:14:06.606382 66954:67403 ERROR   NRT:nrt_infodump                                NEURON_RT_ROOT_COMM_ID=localhost:49255
2024-Nov-09 16:14:06.606401 66954:67403 ERROR   NRT:nrt_infodump                            -------------8<-----------[ cut to here ]-----------8<------------
2024-Nov-09 16:14:06.602248 66954:67406 ERROR   ENC:enc_parse_replica_groups                [nec_dev 1] replica groups (0/1) does not have myself 1
2024-Nov-09 16:14:06.603586 66954:67404 ERROR   ENC:enc_parse_replica_groups                [nec_dev 3] replica groups (0/1) does not have myself 3
2024-Nov-09 16:14:06.614656 66954:67406 ERROR  TDRV:build_enc_source_target_pairs           Failed to parse src_target_pairs on ENC
2024-Nov-09 16:14:06.626936 66954:67404 ERROR  TDRV:build_enc_source_target_pairs           Failed to parse src_target_pairs on ENC
2024-Nov-09 16:14:06.639079 66954:67406 ERROR  TDRV:instr_col_translate_ptc2                Failed to setup pseudo trigger collectives2 instruction
2024-Nov-09 16:14:06.649859 66954:67404 ERROR  TDRV:instr_col_translate_ptc2                Failed to setup pseudo trigger collectives2 instruction
2024-Nov-09 16:14:06.662122 66954:67406 ERROR  TDRV:ib_create_one_block                     failed to translate instructions
2024-Nov-09 16:14:06.672477 66954:67404 ERROR  TDRV:ib_create_one_block                     failed to translate instructions
2024-Nov-09 16:14:06.682738 66954:67406 ERROR  TDRV:ib_create_eib_v2                        Failed to create instruction blocks
2024-Nov-09 16:14:06.692175 66954:67404 ERROR  TDRV:ib_create_eib_v2                        Failed to create instruction blocks
2024-Nov-09 16:14:06.702800 66954:67406 ERROR  TDRV:sequencer_v2_setup_instr_one_eng        Failed to allocate eib
2024-Nov-09 16:14:06.712256 66954:67404 ERROR  TDRV:sequencer_v2_setup_instr_one_eng        Failed to allocate eib
2024-Nov-09 16:14:06.723884 66954:67406 ERROR  TDRV:kbl_model_add                           create_engine_refill_rings_v1() error
2024-Nov-09 16:14:06.751248 66954:67405 ERROR   ENC:enc_parse_replica_groups                [nec_dev 0] replica groups (0/1) does not have myself 0
2024-Nov-09 16:14:06.751330 66954:67405 ERROR  TDRV:build_enc_source_target_pairs           Failed to parse src_target_pairs on ENC
2024-Nov-09 16:14:06.751345 66954:67405 ERROR  TDRV:instr_col_translate_ptc2                Failed to setup pseudo trigger collectives2 instruction
2024-Nov-09 16:14:06.751929 66954:67405 ERROR  TDRV:ib_create_one_block                     failed to translate instructions
2024-Nov-09 16:14:06.751957 66954:67405 ERROR  TDRV:ib_create_eib_v2                        Failed to create instruction blocks
2024-Nov-09 16:14:06.751974 66954:67405 ERROR  TDRV:sequencer_v2_setup_instr_one_eng        Failed to allocate eib
2024-Nov-09 16:14:06.751990 66954:67405 ERROR  TDRV:kbl_model_add                           create_engine_refill_rings_v1() error
2024-Nov-09 16:14:06.753195 66954:67405 ERROR  NMGR:dlr_kelf_stage                          Failed to load subgraph
2024-Nov-09 16:14:06.753230 66954:67405 ERROR  NMGR:kmgr_load_nn_internal_v2                Failed to stage graph: kelf-0.json to NeuronCore
2024-Nov-09 16:14:06.753248 66954:67405 ERROR  NMGR:kmgr_load_nn_post_metrics               Failed to load NN: /tmp/tmpsyndo_96/file.neff, err: 4
2024-Nov-09 16:14:06.735901 66954:67404 ERROR  TDRV:kbl_model_add                           create_engine_refill_rings_v1() error
2024-Nov-09 16:14:06.749448 66954:67406 ERROR  NMGR:dlr_kelf_stage                          Failed to load subgraph
2024-Nov-09 16:14:06.762705 66954:67404 ERROR  NMGR:dlr_kelf_stage                          Failed to load subgraph
2024-Nov-09 16:14:06.775586 66954:67406 ERROR  NMGR:kmgr_load_nn_internal_v2                Failed to stage graph: kelf-0.json to NeuronCore
2024-Nov-09 16:14:06.789108 66954:67404 ERROR  NMGR:kmgr_load_nn_internal_v2                Failed to stage graph: kelf-0.json to NeuronCore
2024-Nov-09 16:14:06.799686 66954:67406 ERROR  NMGR:kmgr_load_nn_post_metrics               Failed to load NN: /tmp/tmpsyndo_96/file.neff, err: 4
2024-Nov-09 16:14:06.810033 66954:67404 ERROR  NMGR:kmgr_load_nn_post_metrics               Failed to load NN: /tmp/tmpsyndo_96/file.neff, err: 4

Environment

Thanks for the help!

AWSNB commented 1 week ago

Good catch, we’ll get back to you soon

From: Corentin Godeau · Date: Saturday, November 9, 2024 at 8:22 AM · Subject: [aws-neuron/aws-neuron-sdk] Compiler crash when sharding model weights (Issue #1030)
