openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

Dynamic slice fusion causes failure with maxtext #17938

Closed: shraiysh closed this 7 hours ago

shraiysh commented 1 week ago

Llama2-7b fails on maxtext (JAX-Toolbox) when run with the following parameters:

test-maxtext.sh -b 4 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --fsdp=8 --output train_output -a "scan_layers=true max_target_length=4096 use_iota_embed=true logits_dot_in_fp32=false"

Tested on the latest container: https://github.com/nvidia/JAX-Toolbox/pkgs/container/jax/284233276?tag=maxtext-2024-10-04

This is a temporary fix to unblock the release.

dimitar-asenov commented 14 hours ago

I think this PR is obsolete as of this commit: https://github.com/tensorflow/tensorflow/commit/1cc871a25a10de22f5c292a513052be173d4e619

@jreiffers Is the test in this PR still valuable? If so, please approve and we can merge it. Otherwise, we can close the PR right away.

shraiysh commented 7 hours ago

Closing this PR, as it should no longer be required. The command buffering test fixes were related to the flag.