google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Slow initialisation on H100 #21474

Open ZacCranko opened 1 month ago

ZacCranko commented 1 month ago

Description

import time

tic = time.time()
import jax
print(jax.devices())
print("Time to import jax and get devices:", time.time() - tic)

Time to import jax and get devices: 3.9980173110961914
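The single timer above lumps the Python import together with backend initialization. The following is a minimal sketch that reports the two stages separately. It assumes that, as in recent JAX releases, backends are initialized lazily on the first call that needs them (here `jax.devices()`). The helper names `time_stages` and `time_jax_startup` are hypothetical.

```python
import importlib
import time
from typing import Any, Callable, Dict, List, Tuple


def time_stages(stages: List[Tuple[str, Callable[[], object]]]) -> Dict[str, float]:
    """Run each (label, fn) stage in order and record its wall-clock duration in seconds."""
    durations: Dict[str, float] = {}
    for label, fn in stages:
        start = time.perf_counter()
        fn()
        durations[label] = time.perf_counter() - start
    return durations


def time_jax_startup() -> Dict[str, float]:
    """Split JAX startup into the module import and the first device query (hypothetical helper)."""
    holder: Dict[str, Any] = {}

    def do_import() -> None:
        # Pure Python import cost: loading jax, jaxlib and their dependencies.
        holder["jax"] = importlib.import_module("jax")

    def do_devices() -> None:
        # Assumption: backends are created lazily, so driver/plugin setup for
        # every visible GPU is paid on this first jax.devices() call.
        holder["devices"] = holder["jax"].devices()

    return time_stages([("import jax", do_import), ("jax.devices()", do_devices)])
```

Calling `print(time_jax_startup())` in a fresh interpreter (the import is cached after the first run in a process) would show whether the roughly 4 s goes mostly to importing modules or mostly to initializing the GPUs.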

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='h100-instance-4', release='6.5.0-1020-gcp', version='#20~22.04.1-Ubuntu SMP Wed May  1 02:03:24 UTC 2024', machine='x86_64')

$ nvidia-smi
Wed May 29 01:31:57 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.29.06              Driver Version: 545.29.06    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          Off | 00000000:04:00.0 Off |                    0 |
| N/A   38C    P0             114W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          Off | 00000000:05:00.0 Off |                    0 |
| N/A   37C    P0             113W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          Off | 00000000:0A:00.0 Off |                    0 |
| N/A   39C    P0             120W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          Off | 00000000:0B:00.0 Off |                    0 |
| N/A   36C    P0             112W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          Off | 00000000:84:00.0 Off |                    0 |
| N/A   38C    P0             116W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          Off | 00000000:85:00.0 Off |                    0 |
| N/A   36C    P0             119W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          Off | 00000000:8A:00.0 Off |                    0 |
| N/A   36C    P0             114W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          Off | 00000000:8B:00.0 Off |                    0 |
| N/A   36C    P0             111W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A   1537304      C   python                                      526MiB |
|    1   N/A  N/A   1537304      C   python                                      526MiB |
|    2   N/A  N/A   1537304      C   python                                      526MiB |
|    3   N/A  N/A   1537304      C   python                                      526MiB |
|    4   N/A  N/A   1537304      C   python                                      526MiB |
|    5   N/A  N/A   1537304      C   python                                      526MiB |
|    6   N/A  N/A   1537304      C   python                                      526MiB |
|    7   N/A  N/A   1537304      C   python                                      526MiB |
+---------------------------------------------------------------------------------------+
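The process listing shows a single Python process holding about 526MiB on each of the eight GPUs, which is consistent with a context having been set up on every device during startup. One way to check whether startup time grows with the number of GPUs is to repeat the measurement in fresh processes while restricting the standard CUDA `CUDA_VISIBLE_DEVICES` variable. This is a sketch under that assumption. `startup_seconds` and `device_lists` are hypothetical helpers, and the wall-clock figure also includes interpreter startup.

```python
import os
import subprocess
import sys
import time

# Workload timed in each child process: import JAX and enumerate devices.
SNIPPET = "import jax; jax.devices()"


def device_lists(n: int) -> list[str]:
    """Return CUDA_VISIBLE_DEVICES values exposing the first 1, 2, ..., n GPUs."""
    return [",".join(str(i) for i in range(k)) for k in range(1, n + 1)]


def startup_seconds(visible_devices: str, snippet: str = SNIPPET) -> float:
    """Wall-clock seconds for a fresh interpreter to run `snippet` with the given GPUs visible."""
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=visible_devices)
    start = time.perf_counter()
    # A new process per measurement: JAX initializes its backends once per process.
    subprocess.run([sys.executable, "-c", snippet], env=env, check=True)
    return time.perf_counter() - start
```

For example, `for devs in device_lists(8): print(devs, startup_seconds(devs))` on this machine. A time that rises steadily with the device count would point at per-GPU initialization rather than import overhead.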
hawkinsp commented 1 month ago

Just to follow up: 4s is better than the figure you mentioned to me offline! Is 4s slow enough to be a problem? Not saying we can't make it faster, but how fast is fast enough?