Open felarof99 opened 21 hours ago
thanks for sending -- we will take a look.
Thanks for reaching out! To debug this further, some additional artifacts would be helpful. Can you run the script with JAX_DUMP_IR_TO=/path and JAX_TRACEBACK_FILTERING=off to dump the JAX-generated IR? This can be used to determine whether the failure is due to an issue such as unsupported ops or patterns.
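For reference, here is a minimal sketch of one way to set both variables for a single run without touching the shell profile. It only uses the Python standard library; the helper names (`build_debug_env`, `run_with_ir_dump`), the script name `train.py`, and the dump directory are hypothetical placeholders, not part of JAX or of the reporter's repository.

```python
import os
import subprocess
import sys


def build_debug_env(dump_dir, base_env=None):
    """Return a copy of the environment with the two debug variables set.

    dump_dir is a hypothetical output directory chosen by the caller.
    """
    env = dict(os.environ if base_env is None else base_env)
    env["JAX_DUMP_IR_TO"] = str(dump_dir)      # where the IR dump is requested
    env["JAX_TRACEBACK_FILTERING"] = "off"     # show unfiltered tracebacks
    return env


def run_with_ir_dump(script, dump_dir):
    """Run `script` in a child Python process with the debug variables set."""
    os.makedirs(dump_dir, exist_ok=True)
    return subprocess.run(
        [sys.executable, script],
        env=build_debug_env(dump_dir),
        check=False,  # the run is expected to fail; we only want the artifacts
    )


# Example call (assumed names, adjust to your setup):
# run_with_ir_dump("train.py", "/tmp/jax_ir_dump")
```

Setting the variables in the child process environment means they are in place before `jax` is imported there, and the parent shell stays unchanged. The same effect can be had by prefixing the usual launch command with the two assignments.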
We also have a list of JAX Neuron known issues that might be helpful: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/jax/setup/jax-neuronx-known-issues.html#jax-neuron-known-issues
Hi,
I am trying to fine-tune llama3.2 1B on AWS Trn1 and I'm running into the following error.
Error in eager mode (without jax.jit):
My code is open source; here's the model definition file.
I tried running the entire training step both with JIT and in completely eager mode. Both options failed. I've attached the stack traces below: jitted_run.txt, eager_run_no_jit.txt
Let me know if you need any other info. Thanks a lot for looking into this!