pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

XLA2 does not work with jax 0.4.34 (but did work on jax 0.4.33) #8240

Open · Chaosruler972 opened 1 month ago

Chaosruler972 commented 1 month ago

🐛 Bug

A toy MNIST training example using torch_xla2 fails on the latest version of JAX (0.4.34) on a 64-core Trillium machine (v6e-64), but downgrading to JAX 0.4.33 fixes the issue.
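For context, a rough sketch of what such a loop can look like (the actual toy script is the one linked in step 1 below; this sketch assumes torch_xla2's default_env() context manager and substitutes random tensors for real MNIST data):

    # Hypothetical sketch of a torch_xla2 MNIST-style loop; the linked toy
    # script may differ. Inside default_env(), plain torch ops are dispatched
    # to JAX/XLA. Random data stands in for real MNIST batches.
    import torch
    import torch.nn as nn
    import torch_xla2

    env = torch_xla2.default_env()

    with env:
        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        opt = torch.optim.SGD(model.parameters(), lr=0.01)
        for step in range(10):
            x = torch.randn(64, 1, 28, 28)   # stand-in for an MNIST batch
            y = torch.randint(0, 10, (64,))
            opt.zero_grad()
            loss = nn.functional.cross_entropy(model(x), y)
            loss.backward()
            opt.step()
            # the reported failure point: the process exits right after
            # the loss value is read back from the device
            print(step, loss.item())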

To Reproduce

  1. Download the toy training example from here

  2. Allocate a v6e-64 Trillium TPU on GCP

  3. copy that file to every TPU VM worker using gcloud scp (see the sketch after this list)

  4. prepare an environment containing torch_xla2 (refer to the README here)

  5. install jax/jaxlib 0.4.33 from pip

    pip install jax==0.4.33 jaxlib==0.4.33 libtpu-nightly==0.1.dev20241008+nightly -f https://storage.googleapis.com/libtpu-releases/index.html
  6. run your training and verify it works

  7. upgrade to jax 0.4.34

    pip install jax==0.4.34 jaxlib==0.4.34 libtpu-nightly==0.1.dev20241008+nightly -f https://storage.googleapis.com/libtpu-releases/index.html
  8. run your training again; note how the training loop exits without any warning or message right after the loss is extracted
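For step 3, a sketch of the copy command, assuming a TPU VM named mnist-v6e-64 in zone us-east5-b (both placeholders) and that the toy script is saved as train_mnist.py:

    # Copy the training script to every worker of the TPU VM slice.
    # TPU name, zone, and file name are assumptions; substitute your own.
    gcloud compute tpus tpu-vm scp train_mnist.py mnist-v6e-64: \
        --zone=us-east5-b --worker=all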

Expected behavior

The training should behave the same on both JAX versions, producing only small numerical differences between the two runs.

Environment
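This section was left blank; a snippet along these lines would capture the versions the report hinges on:

    # Collect the relevant versions; the exact package set is an assumption.
    python -c "import jax, jaxlib; print(jax.__version__, jaxlib.__version__)"
    pip list | grep -E "jax|libtpu|torch"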

qihqi commented 1 month ago

pip install jax==0.4.34 jaxlib==0.4.34 libtpu-nightly==0.1.dev20241008+nightly -f https://storage.googleapis.com/libtpu-releases/index.html

Chaosruler972 commented 1 month ago

I made a mistake on the issue page; I installed using

pip install jax==0.4.34 jaxlib==0.4.34 libtpu-nightly==0.1.dev20241008+nightly -f https://storage.googleapis.com/libtpu-releases/index.html

on the second experiment, which led to the issue.
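As a general debugging aid for the silent exit described in step 8 (not something suggested in this thread), Python's standard faulthandler module can dump a traceback when the process dies on a fatal signal:

    # Generic debugging sketch, not from this thread: enable faulthandler at
    # the top of the training script so a fatal signal in native code prints
    # a Python traceback instead of exiting silently.
    import faulthandler
    faulthandler.enable()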