openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

PR #18052 causes runtime crashes in all multi-GPU MaxText training runs #18214

Open gpupuck opened 8 hours ago

gpupuck commented 8 hours ago

The issue started with https://github.com/openxla/xla/pull/18052

Error log:

Per train step:
 Total TFLOPs: 377.53 
 split as 86.02% learnable weight flops and 13.98% attention flops
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/maxtext/MaxText/train.py", line 776, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/opt/maxtext/MaxText/train.py", line 772, in main
    train_loop(config)
  File "/opt/maxtext/MaxText/train.py", line 666, in train_loop
    state, metrics = p_train_step(state, example_batch, nextrng)
RuntimeError: Invalid opaque object size

Minimal steps to reproduce on one node with at least 2 A100 (or H100) GPUs:

docker run -it --rm --gpus=all --shm-size=2g ghcr.io/nvidia/jax:maxtext-2024-10-10
test-maxtext.sh -b 4 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --fsdp=8 --output train_output -a "scan_layers=true max_target_length=4096 use_iota_embed=true logits_dot_in_fp32=false"

For the command above, we used a single node with 8 GPUs (--fsdp is set to the number of GPUs on the node).
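
For anyone rerunning this on a node with a different GPU count, here is a minimal sketch of how the command could be parameterized. It assumes that only --fsdp needs to change with the number of GPUs, which has not been verified beyond the 8-GPU run above. The helper names build_repro_command and build_repro_env are hypothetical and not part of MaxText. The environment helper sets JAX_TRACEBACK_FILTERING=off, as the error log suggests, so that the full traceback is printed.

```python
import os
import shlex

# Extra MaxText arguments copied verbatim from the -a flag in the report.
EXTRA_ARGS = (
    "scan_layers=true max_target_length=4096 "
    "use_iota_embed=true logits_dot_in_fp32=false"
)


def build_repro_command(num_gpus: int) -> list[str]:
    """Return the reproduction command with --fsdp matched to the GPU count.

    Hypothetical helper: assumes --fsdp is the only flag that depends on
    the number of GPUs on the node.
    """
    if num_gpus < 2:
        raise ValueError("the report calls for at least 2 GPUs")
    return [
        "test-maxtext.sh", "-b", "4",
        "--model-name=llama2-7b",
        "--attn-type=cudnn_flash_te",
        "--remat-policy=minimal_flash",
        "--steps=10",
        f"--fsdp={num_gpus}",
        "--output", "train_output",
        "-a", EXTRA_ARGS,
    ]


def build_repro_env(base_env=None) -> dict:
    """Copy the environment and disable JAX traceback filtering."""
    env = dict(os.environ if base_env is None else base_env)
    env["JAX_TRACEBACK_FILTERING"] = "off"
    return env


if __name__ == "__main__":
    # Print a shell-quoted command line for an 8-GPU node.
    print(shlex.join(build_repro_command(8)))
```

Inside the container, the command list and environment could be passed to subprocess.run(cmd, env=env). The GPU count has to be supplied by the caller, for example from the number of lines printed by nvidia-smi -L.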

hawkinsp commented 7 hours ago

I note that "Invalid opaque object size" is an error that comes from NVIDIA's TransformerEngine. Can you say a bit more about why you think this is an XLA bug?