Chaosruler972 opened this issue 1 month ago
I made a mistake on the issue page; I installed using

pip install jax==0.4.34 jaxlib==0.4.34 libtpu-nightly==0.1.dev20241008+nightly -f https://storage.googleapis.com/libtpu-releases/index.html

on the second experiment, which led to the issue.
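A quick way to confirm which versions actually got installed on each worker, and that all chips are visible (a minimal sketch; the expectation of 64 devices is based on the v6e-64 slice size mentioned below):

# Sanity check after installing; run on each TPU VM worker.
import jax
import jaxlib

print("jax:", jax.__version__)               # 0.4.34, or 0.4.33 after downgrading
print("jaxlib:", jaxlib.__version__)
print("device count:", jax.device_count())   # expect 64 on a v6e-64 slice
print("first devices:", jax.devices()[:4])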
🐛 Bug
A toy MNIST example using torch_xla2 does not work on the latest version of jax (0.4.34) on a 64-core Trillium machine (v6e-64), but downgrading to 0.4.33 fixes the issue.
To Reproduce
1. Download the toy training example from here (a stand-in sketch of a comparable loop is shown after these steps).
2. Allocate a v6e-64 Trillium TPU on GCP.
3. Copy that file to all the TPU VM workers using gcloud scp.
4. Prepare an environment containing torch_xla2 (refer to the README here).
5. Install jax/jaxlib 0.4.33 from pip.
6. Run the training and verify that it works.
7. Upgrade to jax 0.4.34.
8. Run the training again; note how the training loop exits with no warning or error message right after the loss is extracted.
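For context, here is a minimal stand-in for the failing loop, written in pure jax on synthetic MNIST-shaped data. The actual toy script linked above goes through torch_xla2, so treat this as an assumption-laden sketch of the loop's shape (layer sizes, batch size, and learning rate are made up), not the real repro; the point is where the crash is observed, right after the loss is pulled to host.

import jax
import jax.numpy as jnp

def init_params(key):
    k1, _ = jax.random.split(key)
    # One linear layer, 784 -> 10: enough to exercise the TPU compile/run path.
    return {"w": jax.random.normal(k1, (784, 10)) * 0.01,
            "b": jnp.zeros((10,))}

def loss_fn(params, x, y):
    # Cross-entropy loss over 10 classes.
    logits = x @ params["w"] + params["b"]
    logp = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.take_along_axis(logp, y[:, None], axis=1))

@jax.jit
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Plain SGD update with a made-up learning rate.
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return params, loss

key = jax.random.PRNGKey(0)
params = init_params(key)
x = jax.random.normal(key, (128, 784))           # fake MNIST batch
y = jax.random.randint(key, (128,), 0, 10)

for step in range(10):
    params, loss = train_step(params, x, y)
    # On 0.4.34 the process reportedly dies right after the loss value
    # is extracted to the host, i.e. around a call like this:
    print(step, float(loss))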
Expected behavior
Only small variations in results between runs on different versions of jax; the training loop should complete on both 0.4.33 and 0.4.34.
Environment