Open Corendos opened 1 week ago
Good catch, we’ll get back to you soon
From: Corentin Godeau
Date: Saturday, November 9, 2024 at 8:22 AM
Subject: [aws-neuron/aws-neuron-sdk] Compiler crash when sharding model weights (Issue #1030)
Hi!
I was playing with JAX on Neuron recently and came across a bug that is quite annoying.
When sharding a very simple MLP layer, compilation fails depending on which axis you shard on.
Here is a snippet of code that demonstrates the issue:
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

# 1-D device mesh over all visible Neuron cores
mesh = jax.sharding.Mesh(jax.devices(), ('x',))

BATCH_SIZE = 1
SIZE = 4096
HIDDEN_SIZE = 8192
DTYPE = jnp.bfloat16
SHARD_ON_HIDDEN = True

def mlp(x: jax.Array, gate_up_proj: jax.Array, down_proj: jax.Array) -> jax.Array:
    # x: (B, D)
    # gate_up_proj: (D, 2 * H)
    # down_proj: (D, H), contracted on its H axis below
    hidden = jax.lax.dot_general(x, gate_up_proj, (([1], [0]), ([], [])))
    # hidden: (B, 2 * H)
    x1, x2 = jnp.split(hidden, 2, 1)
    hidden = jax.nn.gelu(x1) * x2
    # hidden: (B, H)
    hidden = jax.lax.dot_general(hidden, down_proj, (([1], [1]), ([], [])))
    # hidden: (B, D)
    return hidden

if SHARD_ON_HIDDEN:
    # shard the second axis of both weights (2 * H and H)
    weight_sharding = jax.sharding.NamedSharding(mesh, P(None, 'x'))
else:
    # shard the first axis of both weights (D)
    weight_sharding = jax.sharding.NamedSharding(mesh, P('x', None))

x = jax.ShapeDtypeStruct(shape=(BATCH_SIZE, SIZE), dtype=DTYPE, sharding=jax.sharding.NamedSharding(mesh, P(None, None)))
gate_up_proj = jax.ShapeDtypeStruct((SIZE, 2 * HIDDEN_SIZE), dtype=DTYPE, sharding=weight_sharding)
down_proj = jax.ShapeDtypeStruct((SIZE, HIDDEN_SIZE), dtype=DTYPE, sharding=weight_sharding)

lowered = jax.jit(mlp).lower(x, gate_up_proj, down_proj)
print(lowered.as_text())
compiled = lowered.compile()
print(compiled.as_text())
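The shapes in the snippet can be sanity-checked without any Neuron hardware. The following is a minimal sketch, not part of the original report, that assumes a plain JAX install and uses `jax.eval_shape` to trace `mlp` abstractly (no weights are allocated). Note that with the contraction `(([1], [1]), ([], []))`, `down_proj` is laid out as (D, H), which matches the `ShapeDtypeStruct((SIZE, HIDDEN_SIZE), ...)` passed in the snippet.

```python
import jax
import jax.numpy as jnp

def mlp(x, gate_up_proj, down_proj):
    # (B, D) @ (D, 2H) -> (B, 2H)
    hidden = jax.lax.dot_general(x, gate_up_proj, (((1,), (0,)), ((), ())))
    x1, x2 = jnp.split(hidden, 2, axis=1)
    hidden = jax.nn.gelu(x1) * x2  # (B, H)
    # Contract hidden axis 1 (H) with down_proj axis 1 (H): (B, H) x (D, H) -> (B, D)
    return jax.lax.dot_general(hidden, down_proj, (((1,), (1,)), ((), ())))

B, D, H = 1, 4096, 8192  # same sizes as in the report
x = jax.ShapeDtypeStruct((B, D), jnp.bfloat16)
gate_up = jax.ShapeDtypeStruct((D, 2 * H), jnp.bfloat16)
down = jax.ShapeDtypeStruct((D, H), jnp.bfloat16)

# Abstract evaluation only: returns a ShapeDtypeStruct describing the output.
out = jax.eval_shape(mlp, x, gate_up, down)
print(out.shape, out.dtype)  # (1, 4096) bfloat16
```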
If SHARD_ON_HIDDEN is set to False, everything works fine. When it's set to True (which is what you want for optimal performance), it crashes when using more than 2 Neuron cores (probably because sharding on that axis introduces interconnect transfers).
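The guess about interconnect transfers can be explored off-device. The following is a minimal sketch, not part of the original report: it assumes a CPU-only JAX install with four emulated host devices, smaller hypothetical sizes, and float32, and it counts collective ops by pattern-matching XLA's optimized HLO text. The Neuron compiler may partition the graph differently, so the counts are only indicative.

```python
import os
# Assumption: emulate 4 devices on CPU; this must be set before importing jax.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")
os.environ.setdefault("JAX_PLATFORMS", "cpu")

import re
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Use 4, 2, or 1 devices so the (small, hypothetical) sizes divide evenly.
n_avail = len(jax.devices())
n_dev = 4 if n_avail >= 4 else (2 if n_avail >= 2 else 1)
mesh = Mesh(np.array(jax.devices()[:n_dev]), ("x",))

B, D, H = 1, 64, 128  # scaled down from the report's 1 / 4096 / 8192
DTYPE = jnp.float32   # assumption: float32 instead of bfloat16 for the CPU run
COLLECTIVES = ("all-reduce", "all-gather", "all-to-all",
               "collective-permute", "reduce-scatter")

def mlp(x, gate_up_proj, down_proj):
    hidden = jax.lax.dot_general(x, gate_up_proj, (((1,), (0,)), ((), ())))
    x1, x2 = jnp.split(hidden, 2, axis=1)
    hidden = jax.nn.gelu(x1) * x2
    return jax.lax.dot_general(hidden, down_proj, (((1,), (1,)), ((), ())))

def collective_counts(shard_on_hidden: bool) -> dict:
    """Count collective ops in the optimized HLO for one weight sharding."""
    spec = P(None, "x") if shard_on_hidden else P("x", None)
    w_sharding = NamedSharding(mesh, spec)
    replicated = NamedSharding(mesh, P(None, None))
    x = jax.ShapeDtypeStruct((B, D), DTYPE, sharding=replicated)
    gate_up = jax.ShapeDtypeStruct((D, 2 * H), DTYPE, sharding=w_sharding)
    down = jax.ShapeDtypeStruct((D, H), DTYPE, sharding=w_sharding)
    hlo = jax.jit(mlp).lower(x, gate_up, down).compile().as_text() or ""
    # Approximate: match the opcode (or its async "-start" form) followed by "(".
    return {
        name: len(re.findall(rf"\b{re.escape(name)}(?:-start)?\(", hlo))
        for name in COLLECTIVES
    }

counts = {flag: collective_counts(flag) for flag in (False, True)}
for flag, c in counts.items():
    print(f"SHARD_ON_HIDDEN={flag}: {c}")
```

Comparing the two printed dictionaries shows which collectives each sharding adds on the CPU backend. Any `collective-permute` entries are the ops that carry source/target pairs, which is the part of the runtime that fails in the log below.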
Here is the error:
2024-Nov-09 16:14:06.605467 66954:67403 ERROR ENC:enc_parse_replica_groups [nec_dev 2] replica groups (0/1) does not have myself 2
2024-Nov-09 16:14:06.605522 66954:67403 ERROR TDRV:build_enc_source_target_pairs Failed to parse src_target_pairs on ENC
2024-Nov-09 16:14:06.605546 66954:67403 ERROR TDRV:instr_col_translate_ptc2 Failed to setup pseudo trigger collectives2 instruction
2024-Nov-09 16:14:06.605670 66954:67403 ERROR TDRV:ib_create_one_block failed to translate instructions
2024-Nov-09 16:14:06.605695 66954:67403 ERROR TDRV:ib_create_eib_v2 Failed to create instruction blocks
2024-Nov-09 16:14:06.605715 66954:67403 ERROR TDRV:sequencer_v2_setup_instr_one_eng Failed to allocate eib
2024-Nov-09 16:14:06.605735 66954:67403 ERROR TDRV:kbl_model_add create_engine_refill_rings_v1() error
2024-Nov-09 16:14:06.605981 66954:67403 ERROR NMGR:dlr_kelf_stage Failed to load subgraph
2024-Nov-09 16:14:06.606000 66954:67403 ERROR NMGR:kmgr_load_nn_internal_v2 Failed to stage graph: kelf-0.json to NeuronCore
2024-Nov-09 16:14:06.606012 66954:67403 ERROR NMGR:kmgr_load_nn_post_metrics Failed to load NN: /tmp/tmpsyndo_96/file.neff, err: 4
2024-Nov-09 16:14:06.606025 66954:67403 ERROR NRT:nrt_infodump Neuron runtime information - please include in any support request:
2024-Nov-09 16:14:06.606040 66954:67403 ERROR NRT:nrt_infodump ------------->8------------[ cut here ]------------>8-------------
2024-Nov-09 16:14:06.606062 66954:67403 ERROR NRT:nrt_infodump NRT version: 2.22.14.0 (6e27b8d5b22dea0e0b8375517f4d8a009b6de5a8)
2024-Nov-09 16:14:06.606084 66954:67403 ERROR NRT:nrt_infodump Embedded FW version: 1.12.2.0 (f152b70c827a52701d6b9ee74ec7ff7a15971f7d)
2024-Nov-09 16:14:06.606112 66954:67403 ERROR NRT:nrt_infodump CCOM version: 2.22.26.0- (compat 48)
2024-Nov-09 16:14:06.606134 66954:67403 ERROR NRT:nrt_infodump Instance ID: i-0b19e4a1cf3fd70d9
2024-Nov-09 16:14:06.606156 66954:67403 ERROR NRT:nrt_infodump Cluster ID: N/A
2024-Nov-09 16:14:06.606178 66954:67403 ERROR NRT:nrt_infodump Kernel: Linux 6.8.0-1015-aws #16~22.04.1-Ubuntu SMP Mon Aug 19 19:38:17 UTC 2024
2024-Nov-09 16:14:06.606200 66954:67403 ERROR NRT:nrt_infodump Nodename: ip-172-31-42-39
2024-Nov-09 16:14:06.606254 66954:67403 ERROR NRT:nrt_infodump Driver version: 2.18.12.0
2024-Nov-09 16:14:06.606276 66954:67403 ERROR NRT:nrt_infodump Failure: NRT_RESOURCE in nrt_load()
2024-Nov-09 16:14:06.606298 66954:67403 ERROR NRT:nrt_infodump Visible cores: 0, 1, 2, 3
2024-Nov-09 16:14:06.606318 66954:67403 ERROR NRT:nrt_infodump Environment:
2024-Nov-09 16:14:06.606341 66954:67403 ERROR NRT:nrt_infodump NEURON_CC_FLAGS=--model-type=transformer --auto-cast=none
2024-Nov-09 16:14:06.606362 66954:67403 ERROR NRT:nrt_infodump NEURON_RT_NUM_CORES=4
2024-Nov-09 16:14:06.606382 66954:67403 ERROR NRT:nrt_infodump NEURON_RT_ROOT_COMM_ID=localhost:49255
2024-Nov-09 16:14:06.606401 66954:67403 ERROR NRT:nrt_infodump -------------8<-----------[ cut to here ]-----------8<------------
2024-Nov-09 16:14:06.602248 66954:67406 ERROR ENC:enc_parse_replica_groups [nec_dev 1] replica groups (0/1) does not have myself 1
2024-Nov-09 16:14:06.603586 66954:67404 ERROR ENC:enc_parse_replica_groups [nec_dev 3] replica groups (0/1) does not have myself 3
2024-Nov-09 16:14:06.614656 66954:67406 ERROR TDRV:build_enc_source_target_pairs Failed to parse src_target_pairs on ENC
2024-Nov-09 16:14:06.626936 66954:67404 ERROR TDRV:build_enc_source_target_pairs Failed to parse src_target_pairs on ENC
2024-Nov-09 16:14:06.639079 66954:67406 ERROR TDRV:instr_col_translate_ptc2 Failed to setup pseudo trigger collectives2 instruction
2024-Nov-09 16:14:06.649859 66954:67404 ERROR TDRV:instr_col_translate_ptc2 Failed to setup pseudo trigger collectives2 instruction
2024-Nov-09 16:14:06.662122 66954:67406 ERROR TDRV:ib_create_one_block failed to translate instructions
2024-Nov-09 16:14:06.672477 66954:67404 ERROR TDRV:ib_create_one_block failed to translate instructions
2024-Nov-09 16:14:06.682738 66954:67406 ERROR TDRV:ib_create_eib_v2 Failed to create instruction blocks
2024-Nov-09 16:14:06.692175 66954:67404 ERROR TDRV:ib_create_eib_v2 Failed to create instruction blocks
2024-Nov-09 16:14:06.702800 66954:67406 ERROR TDRV:sequencer_v2_setup_instr_one_eng Failed to allocate eib
2024-Nov-09 16:14:06.712256 66954:67404 ERROR TDRV:sequencer_v2_setup_instr_one_eng Failed to allocate eib
2024-Nov-09 16:14:06.723884 66954:67406 ERROR TDRV:kbl_model_add create_engine_refill_rings_v1() error
2024-Nov-09 16:14:06.751248 66954:67405 ERROR ENC:enc_parse_replica_groups [nec_dev 0] replica groups (0/1) does not have myself 0
2024-Nov-09 16:14:06.751330 66954:67405 ERROR TDRV:build_enc_source_target_pairs Failed to parse src_target_pairs on ENC
2024-Nov-09 16:14:06.751345 66954:67405 ERROR TDRV:instr_col_translate_ptc2 Failed to setup pseudo trigger collectives2 instruction
2024-Nov-09 16:14:06.751929 66954:67405 ERROR TDRV:ib_create_one_block failed to translate instructions
2024-Nov-09 16:14:06.751957 66954:67405 ERROR TDRV:ib_create_eib_v2 Failed to create instruction blocks
2024-Nov-09 16:14:06.751974 66954:67405 ERROR TDRV:sequencer_v2_setup_instr_one_eng Failed to allocate eib
2024-Nov-09 16:14:06.751990 66954:67405 ERROR TDRV:kbl_model_add create_engine_refill_rings_v1() error
2024-Nov-09 16:14:06.753195 66954:67405 ERROR NMGR:dlr_kelf_stage Failed to load subgraph
2024-Nov-09 16:14:06.753230 66954:67405 ERROR NMGR:kmgr_load_nn_internal_v2 Failed to stage graph: kelf-0.json to NeuronCore
2024-Nov-09 16:14:06.753248 66954:67405 ERROR NMGR:kmgr_load_nn_post_metrics Failed to load NN: /tmp/tmpsyndo_96/file.neff, err: 4
2024-Nov-09 16:14:06.735901 66954:67404 ERROR TDRV:kbl_model_add create_engine_refill_rings_v1() error
2024-Nov-09 16:14:06.749448 66954:67406 ERROR NMGR:dlr_kelf_stage Failed to load subgraph
2024-Nov-09 16:14:06.762705 66954:67404 ERROR NMGR:dlr_kelf_stage Failed to load subgraph
2024-Nov-09 16:14:06.775586 66954:67406 ERROR NMGR:kmgr_load_nn_internal_v2 Failed to stage graph: kelf-0.json to NeuronCore
2024-Nov-09 16:14:06.789108 66954:67404 ERROR NMGR:kmgr_load_nn_internal_v2 Failed to stage graph: kelf-0.json to NeuronCore
2024-Nov-09 16:14:06.799686 66954:67406 ERROR NMGR:kmgr_load_nn_post_metrics Failed to load NN: /tmp/tmpsyndo_96/file.neff, err: 4
2024-Nov-09 16:14:06.810033 66954:67404 ERROR NMGR:kmgr_load_nn_post_metrics Failed to load NN: /tmp/tmpsyndo_96/file.neff, err: 4
Environment
neuronx-cc==2.15.141.0+d3cfc8ca
libneuronxla==2.0.4986.0
jaxlib==0.4.31
jax-neuronx==0.1.1
jax==0.4.31
inf2.48xlarge instance
Thanks for the help!