jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Check mem_stats key in test_lax_full_like_efficient #24796

Status: Open (apivovarov opened this pull request 4 days ago)

apivovarov commented 4 days ago

Problem

I tried to run multi_device_test.py on an AWS g5.24xlarge GPU instance with 4 GPU devices (NVIDIA A10G, GA102GL).

The test test_lax_full_like_efficient fails with KeyError: 'bytes_reservable_limit':

pytest -n 1 -s -v tests/multi_device_test.py

=================================== FAILURES ===================================
_________________ MultiDeviceTest.test_lax_full_like_efficient _________________
[gw0] linux -- Python 3.10.12 /usr/bin/python3

self = <multi_device_test.MultiDeviceTest testMethod=test_lax_full_like_efficient>

    def test_lax_full_like_efficient(self):
      devices = self.get_devices()
      if len(devices) < 4:
        self.skipTest("test requires 4 devices")
      mem_stats = devices[0].memory_stats()
      if mem_stats is None:
        self.skipTest('Only can run test on device with mem_stats')
      mesh = Mesh(devices, axis_names=("i"))
      sharding = NamedSharding(mesh, P('i'))
>     available_memory = mem_stats['bytes_reservable_limit']
E     KeyError: 'bytes_reservable_limit'

tests/multi_device_test.py:315: KeyError
=========================== short test summary info ============================
FAILED tests/multi_device_test.py::MultiDeviceTest::test_lax_full_like_efficient
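The set of keys in the dict returned by Device.memory_stats() depends on the backend (and possibly the jaxlib version); on this machine the dict evidently did not include bytes_reservable_limit. A small diagnostic such as the following can show which keys each local device actually reports before a test indexes into the dict. This is a sketch, and memory_stat_keys is a hypothetical helper name, not part of JAX or of this PR.

```python
def memory_stat_keys(devices=None):
    """Hypothetical diagnostic: pair each device with its sorted memory_stats() keys.

    Devices whose backend reports no memory statistics return None from
    memory_stats(); those devices are paired with None here as well.
    """
    if devices is None:
        import jax  # only needed when querying the real local devices
        devices = jax.local_devices()
    report = []
    for device in devices:
        stats = device.memory_stats()
        # sorted() over a dict yields its keys in sorted order
        report.append((device, None if stats is None else sorted(stats)))
    return report
```

Running for device, keys in memory_stat_keys(): print(device, keys) on the failing instance would show whether bytes_reservable_limit is among the reported keys.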

Solution:

This PR adds a check for the presence of the bytes_reservable_limit key in mem_stats, so that on devices whose memory_stats() does not report that key the test is skipped instead of failing.
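The PR diff itself is not reproduced here. As a minimal sketch of an equivalent guard, assuming only that mem_stats is the value returned by Device.memory_stats(), the check can be written as a hypothetical standalone helper (reservable_limit_or_skip is an illustrative name, not the PR's code). It raises unittest.SkipTest, which is exactly what TestCase.skipTest() raises internally, so both unittest and pytest report the test as skipped.

```python
import unittest


def reservable_limit_or_skip(mem_stats):
    """Hypothetical helper: return mem_stats['bytes_reservable_limit'] or skip the test.

    mem_stats is the value of Device.memory_stats(): None when the backend
    reports no statistics, otherwise a dict whose keys vary by backend.
    """
    if mem_stats is None:
        # Same condition and message as the existing check in the test.
        raise unittest.SkipTest('Only can run test on device with mem_stats')
    if 'bytes_reservable_limit' not in mem_stats:
        # The new condition: statistics exist but lack the key the test needs.
        raise unittest.SkipTest('device mem_stats has no bytes_reservable_limit')
    return mem_stats['bytes_reservable_limit']
```

In the test body, available_memory = reservable_limit_or_skip(devices[0].memory_stats()) would replace both the existing None check and the direct lookup at tests/multi_device_test.py:315. An inline if-check that calls self.skipTest(...) is equally valid; the helper form only makes the guard easy to test in isolation.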

Result:

tests/multi_device_test.py::MultiDeviceTest::test_lax_full_like_efficient 
[gw0] SKIPPED tests/multi_device_test.py::MultiDeviceTest::test_lax_full_like_efficient 

Environment Info:

>>> import jax; jax.print_environment_info()
jax:    0.4.36.dev20241007+86038f84e
jaxlib: 0.4.35
numpy:  2.1.3
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
device info: NVIDIA A10G-4, 4 local devices"
process_count: 1
platform: uname_result(system='Linux', node='ip-172-31-15-167', release='6.8.0-1018-aws', version='#19~22.04.1-Ubuntu SMP Wed Oct  9 16:48:22 UTC 2024', machine='x86_64')

$ nvidia-smi
Fri Nov  8 18:29:02 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A10G                    Off |   00000000:00:1B.0 Off |                    0 |
|  0%   19C    P0             29W /  300W |     259MiB /  23028MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A10G                    Off |   00000000:00:1C.0 Off |                    0 |
|  0%   19C    P0             27W /  300W |     259MiB /  23028MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A10G                    Off |   00000000:00:1D.0 Off |                    0 |
|  0%   20C    P0             29W /  300W |     259MiB /  23028MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A10G                    Off |   00000000:00:1E.0 Off |                    0 |
|  0%   19C    P0             27W /  300W |     259MiB /  23028MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A    537624      C   python3                                       250MiB |
|    1   N/A  N/A    537624      C   python3                                       250MiB |
|    2   N/A  N/A    537624      C   python3                                       250MiB |
|    3   N/A  N/A    537624      C   python3                                       250MiB |
+-----------------------------------------------------------------------------------------+