jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.61k stars 2.82k forks

Flaky test tests/fft_test.py::FftTest::testFftfreq5 #24798

Open apivovarov opened 3 weeks ago

apivovarov commented 3 weeks ago

Description

I tried to run tests/fft_test.py::FftTest::testFftfreq5 on a g5.24xlarge instance, which has 4 GPU devices.

The test passes when executed individually:

pytest -s -v tests/fft_test.py::FftTest::testFftfreq5

But the test consistently fails when executed as part of the full tests/fft_test.py run (even in single-worker pytest mode):

The problem is that the x1 and x2 arguments end up on different devices, dev0 and dev3:

pytest -n 1 -s -v tests/fft_test.py

fun = <function true_divide at 0x796a1ebc3d00>
jit_info = PjitInfo(fun_sourceinfo='true_divide at /home/ubuntu/workspace/jax/jax/_src/numpy/ufuncs.py:2292', fun_signature=<Sign...e, backend=None, keep_unused=False, inline=True, abstracted_axes=None, use_resource_env=False, compiler_options_kvs=())
args = (Array([ 0.,  1.,  2.,  3.,  4., -5., -4., -3., -2., -1.], dtype=float32), Array(1., dtype=float32)), kwargs = {}
p = PjitParams(consts=[], params={'jaxpr': { lambda ; a:f32[10] b:f32[]. let c:f32[10] = div a b in (c,) }, 'in_shardings'...*), {})), out_tree=PyTreeDef(*), donated_invars=(False, False), arg_names=('x1', 'x2'), num_consts=0, attrs_tracked=[])
args_flat = [Array([ 0.,  1.,  2.,  3.,  4., -5., -4., -3., -2., -1.], dtype=float32), Array(1., dtype=float32)], arg = Array(1., dtype=float32)
fails = [DeviceAssignmentMismatch(da=(CudaDevice(id=0),), m_type=<MismatchType.ARG_SHARDING: 0>, source_info=None), DeviceAssignmentMismatch(da=(CudaDevice(id=3),), m_type=<MismatchType.ARG_SHARDING: 0>, source_info=None)]
api_name = 'jit', fun_name = 'true_divide'
msg = 'Received incompatible devices for jitted computation. Got argument x1 of true_divide with shape float32[10] and device ids [0] on platform GPU and argument x2 of true_divide with shape float32[] and device ids [3] on platform GPU'

    def _python_pjit_helper(fun, jit_info, *args, **kwargs):
      p, args_flat = _infer_params(fun, jit_info, args, kwargs)

      for arg in args_flat:
        dispatch.check_arg(arg)

      if p.attrs_tracked:
        init_states = _get_states(p.attrs_tracked)
        args_flat = [*init_states, *args_flat]

      try:
        out_flat = pjit_p.bind(*args_flat, **p.params)
      except pxla.DeviceAssignmentMismatchError as e:
        fails, = e.args
        api_name = 'jit' if p.params['resource_env'] is None else 'pjit'
        fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
        msg = _device_assignment_mismatch_error(
            fun_name, fails, args_flat, api_name, p.arg_names)
>       raise ValueError(msg) from None
E       ValueError: Received incompatible devices for jitted computation. Got argument x1 of true_divide with shape float32[10] and device ids [0] on platform GPU and argument x2 of true_divide with shape float32[] and device ids [3] on platform GPU

jax/_src/pjit.py:195: ValueError
================================================================================== short test summary info ==================================================================================
FAILED tests/fft_test.py::FftTest::testFftfreq5 - ValueError: Received incompatible devices for jitted computation. Got argument x1 of true_divide with shape float32[10] and device ids [0] on platform GPU and argument x2 of true_divid...
========================================================================= 1 failed, 96 passed, 2 skipped in 13.86s ==========================================================================
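The traceback shows that the jitted `true_divide` rejects operands committed to different devices. Below is a hypothetical standalone sketch of the same failure mode, not the test itself: it assumes no multi-GPU machine is available and instead forces two host (CPU) devices via the `--xla_force_host_platform_device_count` XLA flag, then commits the operands to different devices with `jax.device_put`.

```python
# Sketch of the device-assignment mismatch, assuming a fresh process
# (the XLA flag must be set before jax is first imported).
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import jax
import jax.numpy as jnp

d_first, d_last = jax.devices()[0], jax.devices()[-1]

# Commit the operands to different devices, as in the failing test.
x1 = jax.device_put(jnp.arange(4.0), d_first)
x2 = jax.device_put(jnp.float32(2.0), d_last)

err = None
try:
    jnp.true_divide(x1, x2)  # jitted ufunc -> device assignment check
except ValueError as e:
    err = e
print("mismatch raised:", err is not None)

# Workaround sketch: move one operand onto the other's device first.
x2_fixed = jax.device_put(x2, list(x1.devices())[0])
result = jnp.true_divide(x1, x2_fixed)
print(result)
```

With more than one device, the first division raises the same "Received incompatible devices for jitted computation" ValueError seen in the log, while the `device_put` variant succeeds.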

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.36.dev20241007+86038f84e
jaxlib: 0.4.35
numpy:  2.1.3
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
device info: NVIDIA A10G-4, 4 local devices
process_count: 1
platform: uname_result(system='Linux', node='ip-172-31-15-167', release='6.8.0-1018-aws', version='#19~22.04.1-Ubuntu SMP Wed Oct  9 16:48:22 UTC 2024', machine='x86_64')

$ nvidia-smi
Fri Nov  8 18:29:02 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A10G                    Off |   00000000:00:1B.0 Off |                    0 |
|  0%   19C    P0             29W /  300W |     259MiB /  23028MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A10G                    Off |   00000000:00:1C.0 Off |                    0 |
|  0%   19C    P0             27W /  300W |     259MiB /  23028MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A10G                    Off |   00000000:00:1D.0 Off |                    0 |
|  0%   20C    P0             29W /  300W |     259MiB /  23028MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A10G                    Off |   00000000:00:1E.0 Off |                    0 |
|  0%   19C    P0             27W /  300W |     259MiB /  23028MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A    537624      C   python3                                       250MiB |
|    1   N/A  N/A    537624      C   python3                                       250MiB |
|    2   N/A  N/A    537624      C   python3                                       250MiB |
|    3   N/A  N/A    537624      C   python3                                       250MiB |
+-----------------------------------------------------------------------------------------+
hawkinsp commented 3 weeks ago

We've been debugging this after seeing it in our own CI. Fix coming soon, hopefully.

apivovarov commented 3 weeks ago

Thank you! Another test that fails on a 4-GPU setup is described here: https://github.com/jax-ml/jax/pull/24796 @hawkinsp