Open Corendos opened 3 days ago
Thank you for reporting the issue. We were able to reproduce it with the provided sample. We currently do not support all sharding configurations: meshes that use non-connected devices might result in runtime failures during execution. The topology for Inferentia instances can be seen here, or by running `neuron-ls --topology`.
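To make the "non-connected devices" constraint concrete, here is a minimal sketch of a pre-flight check that tests whether the devices chosen for a mesh form a single connected group. It rests on assumptions: `is_connected` and `HYPOTHETICAL_RING` are illustrative names, and the 6-device ring is made up for the example. It is not the real inf2 topology, which should be taken from `neuron-ls --topology`. It also assumes that "non-connected" can be read as plain graph connectivity, which may be narrower or broader than what the runtime actually requires.

```python
from collections import deque


def is_connected(mesh_devices, adjacency):
    """Return True if mesh_devices form one connected group under adjacency.

    adjacency maps a device id to the ids it is directly linked to.
    Only links between devices that are both in the mesh are followed.
    """
    devices = set(mesh_devices)
    if not devices:
        return True
    start = next(iter(devices))
    seen = {start}
    queue = deque([start])
    # Breadth-first search restricted to the devices selected for the mesh.
    while queue:
        current = queue.popleft()
        for neighbor in adjacency.get(current, ()):
            if neighbor in devices and neighbor not in seen:
                seen.add(neighbor)
                queue.append(neighbor)
    return seen == devices


# Hypothetical ring of 6 devices, for illustration only (NOT the inf2 layout).
HYPOTHETICAL_RING = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
```

With this made-up ring, devices `[0, 1, 2]` pass the check while `[0, 2, 4]` do not, because no two of them are direct neighbors.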
Thank you for your answer!
So, now that we know for sure that it's not supported, I have some follow-up questions:
Summary
Hi again! I encountered a bug while playing with attention and sharding in JAX. The issue occurs with specific sharding setups and fails under certain core configurations.
Steps to Reproduce
The following code snippet can reproduce the issue:
Observed Behavior
When `SHARDING_TYPE` is set to `"simple"`:
`NEURON_RT_NUM_CORES=16 python xxx.py`
When `SHARDING_TYPE` is set to `"double"`:
`NEURON_RT_NUM_CORES=16 python xxx.py`
When `SHARDING_TYPE` is `"none"` (sharding disabled):
`NEURON_RT_NUM_CORES=16 python xxx.py`
Additional Observations
When I increased the number of heads to:
and used `NEURON_RT_NUM_CORES=24`, the script worked even with sharding enabled. Interestingly, with these higher head counts, setting `NEURON_RT_NUM_CORES=16` made the `"simple"` case work (but not the `"double"` one).
Expected Behavior
The code should work consistently across different sharding configurations and core settings.
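One way to check that consistency is to run the same script under several core counts and compare exit codes. The sketch below does only that: `sweep_core_counts` is a hypothetical helper, not part of any Neuron or JAX tooling, and it varies just the `NEURON_RT_NUM_CORES` environment variable. It assumes the sharding mode is still selected inside the script itself, as in the report.

```python
import os
import subprocess
import sys


def sweep_core_counts(command, core_counts):
    """Run command once per core count and return {core_count: exit_code}.

    Each run gets a copy of the current environment with
    NEURON_RT_NUM_CORES overridden; nothing else is changed.
    """
    results = {}
    for cores in core_counts:
        env = dict(os.environ, NEURON_RT_NUM_CORES=str(cores))
        completed = subprocess.run(
            command, env=env, capture_output=True, text=True
        )
        results[cores] = completed.returncode
    return results


# Example usage (xxx.py stands for the reproduction script):
#   sweep_core_counts([sys.executable, "xxx.py"], [16, 24])
```

A non-zero exit code for one core count and zero for another would reproduce the inconsistency described above without having to rerun the command by hand.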
Environment
neuronx-cc==2.15.141.0+d3cfc8ca
libneuronxla==2.0.4986.0
jaxlib==0.4.31
jax-neuronx==0.1.1
jax==0.4.31
inf2.48xlarge instance
Additional Information
Please let me know if further details are needed for reproducing or debugging the issue. Thank you!