aws-neuron / aws-neuron-sdk

Powering AWS purpose-built machine learning chips. Blazing fast and cost effective, natively integrated into PyTorch and TensorFlow and integrated with your favorite AWS services
https://aws.amazon.com/machine-learning/neuron/

JAX ERROR -- (jit(_multiply)/jit(main)/mul_multiply.0) Internal tensorizer error: BirCodeGenLoop:BIRCodegen does not support broadcast patterns, but found one in {0,+,0}[128] #1044

Open felarof99 opened 21 hours ago

felarof99 commented 21 hours ago

Hi,

I am trying to fine-tune Llama 3.2 1B on AWS Trn1 and I'm running into the following error.

Error in eager mode (without jax.jit):

2024-11-21 04:44:13.000699:  3926  ERROR ||NEURON_CC_WRAPPER||: Failed compilation with ['neuronx-cc', 'compile', '--target=trn1', '--framework=XLA', '/tmp/ubuntu/neuroncc_compile_workdir/19af4718-112f-4fb8-92fb-ec725d3f5334/model.MODULE_5516390676483383119+d7517139.hlo_module.pb', '--output', '/tmp/ubuntu/neuroncc_compile_workdir/19af4718-112f-4fb8-92fb-ec725d3f5334/model.MODULE_5516390676483383119+d7517139.neff', '--verbose=35']: 2024-11-21T04:44:13Z [TEN404] (jit(_multiply)/jit(main)/mul_multiply.0) Internal tensorizer error: BirCodeGenLoop:BIRCodegen does not support broadcast patterns, but found one in {0,+,0}[128] - Please open a support ticket at https://github.com/aws-neuron/aws-neuron-sdk/issues/new. You may also be able to obtain more information using the 'XLA_IR_DEBUG' and 'XLA_HLO_DEBUG' environment variables.

My code is open source; here's the model definition file.

I tried running the entire training step both with JIT and in completely eager mode. Both options failed. I've attached the stack traces for each: jitted_run.txt and eager_run_no_jit.txt.

Let me know if you need any other info. Thanks a lot for looking into this!

awsrjh commented 15 hours ago

thanks for sending -- we will take a look.

fayyadd commented 14 hours ago

Thanks for reaching out! To debug this further, some additional artifacts would be helpful. Can you run the script with JAX_DUMP_IR_TO=/path and JAX_TRACEBACK_FILTERING=off to dump the JAX-generated IR? We can use that to determine whether the failure is caused by something like an unsupported op or pattern.
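As a rough sketch of one way to set those two variables, the helper below builds an environment with both set and launches the training script in a child process. The names `debug_env` and `run_with_debug`, the dump directory `/tmp/jax_ir_dump`, and the script name `train.py` are placeholders made up for illustration and are not taken from the reporter's repository.

```python
import os
import subprocess
import sys


def debug_env(dump_dir, base_env=None):
    """Return a copy of the environment with the two JAX debug variables set.

    dump_dir is a hypothetical directory where the JAX IR dump should land.
    """
    env = dict(os.environ if base_env is None else base_env)
    env["JAX_DUMP_IR_TO"] = dump_dir          # where JAX writes the generated IR
    env["JAX_TRACEBACK_FILTERING"] = "off"    # show full, unfiltered tracebacks
    return env


def run_with_debug(script, dump_dir):
    """Run a training script (placeholder name) in a child process with the debug env."""
    os.makedirs(dump_dir, exist_ok=True)
    return subprocess.run([sys.executable, script], env=debug_env(dump_dir))


# Example (assumed paths, adjust to your setup):
# run_with_debug("train.py", "/tmp/jax_ir_dump")
```

Setting the variables in the child process environment, rather than inside the script, is meant to ensure they are in place before `jax` is imported. Exporting them in the shell before launching the script should work equally well.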

We also have a list of JAX Neuron known issues that might be helpful: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/jax/setup/jax-neuronx-known-issues.html#jax-neuron-known-issues