Per train step:
Total TFLOPs: 377.53
split as 86.02% learnable weight flops and 13.98% attention flops
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/opt/maxtext/MaxText/train.py", line 776, in <module>
app.run(main)
File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/opt/maxtext/MaxText/train.py", line 772, in main
train_loop(config)
File "/opt/maxtext/MaxText/train.py", line 666, in train_loop
state, metrics = p_train_step(state, example_batch, nextrng)
RuntimeError: Invalid opaque object size
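As the first line of the log notes, JAX hides its internal frames by default; re-running with `JAX_TRACEBACK_FILTERING=off` exposes the frames behind the `RuntimeError`, which helps confirm whether the failure surfaces inside TransformerEngine's custom-call handling or in XLA itself. A minimal sketch (the variable must be set before JAX is imported, e.g. at the top of the launch script):

```python
import os

# Disable JAX's traceback filtering so the internal frames behind
# "Invalid opaque object size" are included in the traceback.
# This must happen before `import jax` anywhere in the process.
os.environ["JAX_TRACEBACK_FILTERING"] = "off"
```

Equivalently, the variable can be exported in the shell before launching `train.py`.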
Minimal steps to reproduce on one node with at least 2 A100 (or H100) GPUs:
I note that "Invalid opaque object size" is an error that comes from NVIDIA's TransformerEngine. Can you say a bit more why you think this is an XLA bug?
The regression appears to have been introduced by https://github.com/openxla/xla/pull/18052
Error log:
For the above command, we used a single node with 8 GPUs (fsdp parallelism equal to the number of GPUs).