Open apivovarov opened 3 weeks ago
I tried to run tests/fft_test.py::FftTest::testFftfreq5 on a g5.24xlarge instance, which has 4 GPU devices.
The test passes when executed individually:
pytest -s -v tests/fft_test.py::FftTest::testFftfreq5
But the test consistently fails when executed as part of the full tests/fft_test.py run (even with a single pytest worker).
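One way to narrow down this kind of order-dependent failure is to run every other test in the file paired with the failing one and note which pairs fail. The following is only a minimal sketch: the helper names (`collect_test_ids`, `pair_command`, `find_polluters`) are hypothetical, and it assumes pytest runs node IDs in the order they are given on the command line.

```python
import subprocess
import sys

TARGET = "tests/fft_test.py::FftTest::testFftfreq5"


def collect_test_ids(path):
    """List the pytest node IDs found in `path` using `pytest --collect-only -q`."""
    out = subprocess.run(
        [sys.executable, "-m", "pytest", "--collect-only", "-q", path],
        capture_output=True, text=True, check=False,
    ).stdout
    # Node IDs contain "::"; the trailing summary line does not.
    return [line.strip() for line in out.splitlines() if "::" in line]


def pair_command(candidate, target=TARGET):
    """Build a pytest command that names `candidate` first and `target` second."""
    return [sys.executable, "-m", "pytest", "-q", candidate, target]


def find_polluters(test_ids, target=TARGET):
    """Return the candidates whose pair run with `target` exits with a failure."""
    polluters = []
    for candidate in test_ids:
        if candidate == target:
            continue
        # A nonzero exit code means some test in the pair failed. That can also
        # be the candidate itself, so each hit still needs a manual look.
        if subprocess.run(pair_command(candidate, target)).returncode != 0:
            polluters.append(candidate)
    return polluters
```

Calling `find_polluters(collect_test_ids("tests/fft_test.py"))` from the repository root would then list the tests worth inspecting. If no single pair reproduces the failure, the state that leaks between tests may need several earlier tests to build up.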
The problem is that the x1 and x2 arguments are on different devices: dev0 and dev3.
pytest -n 1 -s -v tests/fft_test.py

fun = <function true_divide at 0x796a1ebc3d00>
jit_info = PjitInfo(fun_sourceinfo='true_divide at /home/ubuntu/workspace/jax/jax/_src/numpy/ufuncs.py:2292', fun_signature=<Sign...e, backend=None, keep_unused=False, inline=True, abstracted_axes=None, use_resource_env=False, compiler_options_kvs=())
args = (Array([ 0., 1., 2., 3., 4., -5., -4., -3., -2., -1.], dtype=float32), Array(1., dtype=float32)), kwargs = {}
p = PjitParams(consts=[], params={'jaxpr': { lambda ; a:f32[10] b:f32[]. let c:f32[10] = div a b in (c,) }, 'in_shardings'...*), {})), out_tree=PyTreeDef(*), donated_invars=(False, False), arg_names=('x1', 'x2'), num_consts=0, attrs_tracked=[])
args_flat = [Array([ 0., 1., 2., 3., 4., -5., -4., -3., -2., -1.], dtype=float32), Array(1., dtype=float32)], arg = Array(1., dtype=float32)
fails = [DeviceAssignmentMismatch(da=(CudaDevice(id=0),), m_type=<MismatchType.ARG_SHARDING: 0>, source_info=None), DeviceAssignmentMismatch(da=(CudaDevice(id=3),), m_type=<MismatchType.ARG_SHARDING: 0>, source_info=None)]
api_name = 'jit', fun_name = 'true_divide'
msg = 'Received incompatible devices for jitted computation. Got argument x1 of true_divide with shape float32[10] and device ids [0] on platform GPU and argument x2 of true_divide with shape float32[] and device ids [3] on platform GPU'

    def _python_pjit_helper(fun, jit_info, *args, **kwargs):
      p, args_flat = _infer_params(fun, jit_info, args, kwargs)

      for arg in args_flat:
        dispatch.check_arg(arg)

      if p.attrs_tracked:
        init_states = _get_states(p.attrs_tracked)
        args_flat = [*init_states, *args_flat]

      try:
        out_flat = pjit_p.bind(*args_flat, **p.params)
      except pxla.DeviceAssignmentMismatchError as e:
        fails, = e.args
        api_name = 'jit' if p.params['resource_env'] is None else 'pjit'
        fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
        msg = _device_assignment_mismatch_error(
            fun_name, fails, args_flat, api_name, p.arg_names)
>       raise ValueError(msg) from None
E       ValueError: Received incompatible devices for jitted computation. Got argument x1 of true_divide with shape float32[10] and device ids [0] on platform GPU and argument x2 of true_divide with shape float32[] and device ids [3] on platform GPU

jax/_src/pjit.py:195: ValueError
=========================== short test summary info ===========================
FAILED tests/fft_test.py::FftTest::testFftfreq5 - ValueError: Received incompatible devices for jitted computation. Got argument x1 of true_divide with shape float32[10] and device ids [0] on platform GPU and argument x2 of true_divid...
================== 1 failed, 96 passed, 2 skipped in 13.86s ===================
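The error above is what JAX reports when a jitted function receives arrays that are committed to different devices. The sketch below illustrates that behavior in isolation and shows one way to inspect and align placement. It is not the fix for this issue and says nothing about how x2 ended up on device 3 in the test run. The helper names `devices_of` and `divide_on_common_device` are hypothetical, and on a machine with a single device the mismatch branch is simply skipped.

```python
import jax
import jax.numpy as jnp


def devices_of(x):
    """Return the sorted ids of the devices that hold the array `x`."""
    return sorted(d.id for d in x.devices())


def divide_on_common_device(x1, x2):
    """Move `x2` to the device of `x1` before dividing (single-device arrays only)."""
    (target,) = x1.devices()
    return jnp.true_divide(x1, jax.device_put(x2, target))


devs = jax.devices()
first, last = devs[0], devs[-1]  # the same device when only one is available

# jax.device_put with an explicit device commits the array to that device.
x1 = jax.device_put(jnp.arange(10, dtype=jnp.float32), first)
x2 = jax.device_put(jnp.asarray(1.0, dtype=jnp.float32), last)
print("x1 on", devices_of(x1), "x2 on", devices_of(x2))

if first != last:
    try:
        jnp.true_divide(x1, x2)  # committed to two different devices
    except ValueError as e:
        print("mismatch:", str(e).splitlines()[0])

result = divide_on_common_device(x1, x2)
print("result on", devices_of(result))
```

Printing `devices_of(...)` for each argument just before the failing call is a cheap way to confirm which operand is on the unexpected device.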
>>> import jax; jax.print_environment_info()
jax:    0.4.36.dev20241007+86038f84e
jaxlib: 0.4.35
numpy:  2.1.3
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
device info: NVIDIA A10G-4, 4 local devices"
process_count: 1
platform: uname_result(system='Linux', node='ip-172-31-15-167', release='6.8.0-1018-aws', version='#19~22.04.1-Ubuntu SMP Wed Oct 9 16:48:22 UTC 2024', machine='x86_64')

$ nvidia-smi
Fri Nov  8 18:29:02 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A10G                    Off |   00000000:00:1B.0 Off |                    0 |
|  0%   19C    P0             29W /  300W |     259MiB /  23028MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A10G                    Off |   00000000:00:1C.0 Off |                    0 |
|  0%   19C    P0             27W /  300W |     259MiB /  23028MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A10G                    Off |   00000000:00:1D.0 Off |                    0 |
|  0%   20C    P0             29W /  300W |     259MiB /  23028MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A10G                    Off |   00000000:00:1E.0 Off |                    0 |
|  0%   19C    P0             27W /  300W |     259MiB /  23028MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A    537624      C   python3                                       250MiB |
|    1   N/A  N/A    537624      C   python3                                       250MiB |
|    2   N/A  N/A    537624      C   python3                                       250MiB |
|    3   N/A  N/A    537624      C   python3                                       250MiB |
+-----------------------------------------------------------------------------------------+
We've been debugging this after seeing it in our own CI. Fix coming soon, hopefully.
Thank you! Another test that fails on a 4-GPU setup is described here: https://github.com/jax-ml/jax/pull/24796 @hawkinsp