jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Jax determinism on GPU #13672

Open dgrangier opened 1 year ago

dgrangier commented 1 year ago

Description

I have trouble with determinism for JAX code on GPU, but not on CPU. This is easy to reproduce from the Flax examples: https://github.com/google/flax/tree/main/examples.

From https://github.com/google/jax/issues/565, I had the impression that setting TF_DETERMINISTIC_OPS=1 and TF_CUDNN_DETERMINISTIC=1 would make the GPU run deterministic.
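For context, a minimal sketch of how such variables can be set (an assumption about how the runs were launched, not something stated in this report) is to export them before JAX initializes:

```python
# Sketch (assumption, not from the original report): set the determinism-related
# environment variables before JAX/XLA is imported, since they are read at startup.
import os

os.environ["TF_DETERMINISTIC_OPS"] = "1"
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"

import jax  # import only after the environment is configured
```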

After 2 steps, the MNIST example gives the same results across repeated runs, on both GPU and CPU:

4bda0ea4f536031c4619c394b64ede43 /tmp/mnist_cpu_1/checkpoint_1
4bda0ea4f536031c4619c394b64ede43 /tmp/mnist_cpu_2/checkpoint_1
1daae8e25344c91aea35c4836b46c798 /tmp/mnist_gpu_1/checkpoint_1
1daae8e25344c91aea35c4836b46c798 /tmp/mnist_gpu_2/checkpoint_1

This is not the case for the WMT transformer example, where only the CPU runs are deterministic.

cbcf1e88ed7b2bd61376e3c90c5dcdfc /tmp/wmt_cpu_1/checkpoint_1
cbcf1e88ed7b2bd61376e3c90c5dcdfc /tmp/wmt_cpu_2/checkpoint_1
3dd11fb219e9f37dd90542267cadc86e /tmp/wmt_gpu_1/checkpoint_1
3067f6d66fdf7e8b52baef7056422291 /tmp/wmt_gpu_2/checkpoint_1
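For reference, checksums like the ones above can be produced with something like the following sketch (the md5 hashing and the checkpoint paths simply mirror the WMT output above; this is not part of the Flax examples):

```python
# Sketch: hash each run's checkpoint file and compare digests across runs
# to check for bitwise-identical training results.
import hashlib

def file_md5(path):
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()

for path in [
    "/tmp/wmt_cpu_1/checkpoint_1",
    "/tmp/wmt_cpu_2/checkpoint_1",
    "/tmp/wmt_gpu_1/checkpoint_1",
    "/tmp/wmt_gpu_2/checkpoint_1",
]:
    print(file_md5(path), path)
```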

Is this expected? Am I missing some flags to make my GPU run deterministic?

BTW, I have the following setup: cuDNN 8.4, tf = 2.11.0, jax = 0.3.13, flax = 0.5.3.

Note: I filed the same issue against Flax as well, since they might want to be aware of it or look into it too; see https://github.com/google/flax/issues/2700

What jax/jaxlib version are you using?

0.3.13

Which accelerator(s) are you using?

GPU

Additional system info

cuDNN 8.4, tf = 2.11.0, jax = 0.3.13, flax = 0.5.3

NVIDIA GPU info

NVIDIA-SMI 470.57.02 Driver Version: 470.57.02 CUDA Version: 11.4

yashk2810 commented 1 year ago

Can you upgrade your jax and jaxlib versions to the latest release, 0.4.1? (And also upgrade your flax version?)

dgrangier commented 1 year ago

I updated to jax 0.4.1 and flax 0.6.3. I still observe the same behavior: the CPU runs of flax/examples/wmt are deterministic, but the GPU runs are not:

3acaaa3299a06957ce5d51b7c95bf13f /tmp/wmt_cpu_1/checkpoint_1
3acaaa3299a06957ce5d51b7c95bf13f /tmp/wmt_cpu_2/checkpoint_1
4f07b922eea4a92d6a2a1634d13191e7 /tmp/wmt_gpu_1/checkpoint_1
fa9011b54a845df5a2e54a5545625ec9 /tmp/wmt_gpu_2/checkpoint_1

dgrangier commented 1 year ago

Ping. Any idea how to get deterministic behavior on GPU?

jakevdp commented 1 year ago

From #565:

XLA:GPU reductions are nondeterministic, though. Changing this would be a lot of work. If you all wanted us to prioritize it, we should talk to understand the costs/benefits.

I'm not aware of that having changed; maybe @yashk2810 can weigh in?
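To illustrate the kind of nondeterminism referred to above, here is a small sketch (not from the thread) that probes run-to-run reproducibility of a scatter-add style reduction, which XLA:GPU may implement with floating-point atomics:

```python
# Sketch (illustrative only): repeated scatter-add reductions on GPU may
# accumulate in a different order each run, giving slightly different floats,
# unless --xla_gpu_deterministic_ops=true is set.
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
data = jax.random.normal(key, (1_000_000,), dtype=jnp.float32)
segment_ids = jnp.zeros(data.shape[0], dtype=jnp.int32)  # all values in one segment

segment_sum = jax.jit(lambda d, s: jax.ops.segment_sum(d, s, num_segments=1))

results = {float(segment_sum(data, segment_ids)[0]) for _ in range(20)}
print(results)  # more than one distinct value indicates nondeterministic accumulation
```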

yashk2810 commented 1 year ago

Did you update jaxlib to 0.4.1?

dgrangier commented 1 year ago

Yes. See my message of Dec 16, 2022.

Are XLA:GPU reductions nondeterministic? Even with TF_DETERMINISTIC_OPS=1 and TF_CUDNN_DETERMINISTIC=1? Where could I read about this?

mattjj commented 1 year ago

Hey @dgrangier, sorry for the delay.

I think we need XLA_FLAGS='--xla_gpu_deterministic_ops=true'. With that flag, all XLA:GPU operations are deterministic or they will loudly error, and any such loud error is an XLA:GPU bug. (We actually don't need TF_DETERMINISTIC_OPS=1 for XLA:GPU itself anymore, given the XLA_FLAGS option just mentioned, though TF_DETERMINISTIC_OPS=1 may still be useful for tf.data if you're using that, which it looks like the example is indeed using.)

Thanks to @reedwm for explaining this.

Can you try out XLA_FLAGS='--xla_gpu_deterministic_ops=true' and verify that it works for you?
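For reference, a minimal sketch of applying the flag (the key point being that XLA_FLAGS must be in the environment before JAX initializes; the training call is a placeholder, not part of the Flax example):

```python
# Sketch: enable deterministic XLA:GPU ops; the flag must be set before JAX
# starts up. Appending preserves any XLA flags already set in the environment.
import os

os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_deterministic_ops=true"
)

import jax  # import after XLA_FLAGS is finalized
# ... run training here (placeholder) ...
```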

long21wt commented 1 year ago

> Can you try out XLA_FLAGS='--xla_gpu_deterministic_ops=true' and verify that it works for you?

I'm having the same issue with different code, and the flag indeed works for me; however, it slows down training. Do you have any other ideas for balancing the trade-off?