ZacCranko opened 1 month ago
```python
import time

tic = time.time()
import jax

print(jax.devices())
print("Time to import jax and get devices:", time.time() - tic)
```
Time to import jax and get devices: 3.9980173110961914
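The measurement above combines `import jax` and backend initialization (triggered by the first `jax.devices()` call) into one number. It can also only be taken once per interpreter, because a second `import jax` is served from `sys.modules`. Below is a minimal sketch that uses only the standard library plus an installed `jax`. It times the two stages separately in several fresh interpreters and reports the median of each, which smooths run-to-run noise such as a cold filesystem cache. `CHILD`, `measure`, and `summarize` are hypothetical helper names, not part of JAX.

```python
import statistics
import subprocess
import sys

# Child script run in a fresh interpreter: times `import jax` and the first
# `jax.devices()` call separately. perf_counter is monotonic and high-resolution.
CHILD = """
import time
t0 = time.perf_counter()
import jax
t1 = time.perf_counter()
jax.devices()  # first call initializes the backend(s)
t2 = time.perf_counter()
print(f"{t1 - t0:.6f} {t2 - t1:.6f}")
"""


def measure(child_src: str, runs: int = 5) -> list[tuple[float, ...]]:
    """Run child_src in `runs` fresh interpreters and parse the timings it prints."""
    results = []
    for _ in range(runs):
        out = subprocess.run(
            [sys.executable, "-c", child_src],
            capture_output=True,
            text=True,
            check=True,
        ).stdout
        # The last stdout line holds the space-separated timings in seconds.
        last_line = out.strip().splitlines()[-1]
        results.append(tuple(float(x) for x in last_line.split()))
    return results


def summarize(results: list[tuple[float, ...]]) -> list[float]:
    """Median of each timing column across runs."""
    return [statistics.median(column) for column in zip(*results)]
```

For example, `print(summarize(measure(CHILD)))` prints two numbers: the median time for `import jax` and the median time for the first `jax.devices()` call, both in seconds.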
```
>>> import jax; jax.print_environment_info()
jax: 0.4.28
jaxlib: 0.4.28
numpy: 1.26.4
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='h100-instance-4', release='6.5.0-1020-gcp', version='#20~22.04.1-Ubuntu SMP Wed May 1 02:03:24 UTC 2024', machine='x86_64')

$ nvidia-smi
Wed May 29 01:31:57 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.29.06              Driver Version: 545.29.06    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          Off | 00000000:04:00.0 Off |                    0 |
| N/A   38C    P0             114W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          Off | 00000000:05:00.0 Off |                    0 |
| N/A   37C    P0             113W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          Off | 00000000:0A:00.0 Off |                    0 |
| N/A   39C    P0             120W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          Off | 00000000:0B:00.0 Off |                    0 |
| N/A   36C    P0             112W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          Off | 00000000:84:00.0 Off |                    0 |
| N/A   38C    P0             116W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          Off | 00000000:85:00.0 Off |                    0 |
| N/A   36C    P0             119W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          Off | 00000000:8A:00.0 Off |                    0 |
| N/A   36C    P0             114W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          Off | 00000000:8B:00.0 Off |                    0 |
| N/A   36C    P0             111W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A   1537304      C   python                                      526MiB |
|    1   N/A  N/A   1537304      C   python                                      526MiB |
|    2   N/A  N/A   1537304      C   python                                      526MiB |
|    3   N/A  N/A   1537304      C   python                                      526MiB |
|    4   N/A  N/A   1537304      C   python                                      526MiB |
|    5   N/A  N/A   1537304      C   python                                      526MiB |
|    6   N/A  N/A   1537304      C   python                                      526MiB |
|    7   N/A  N/A   1537304      C   python                                      526MiB |
+---------------------------------------------------------------------------------------+
```
Just to follow up: 4s is better than you had mentioned to me offline! Is 4s slow enough to be a problem? Not saying we can't make it faster, but how fast is fast enough?
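Before deciding whether to make it faster, it helps to know where the import portion of the time goes. CPython's `-X importtime` option (Python 3.7+) writes per-module self and cumulative import times, in microseconds, to stderr. Below is a minimal sketch that uses only the standard library. It runs `import <module>` in a fresh interpreter with that flag and returns the entries with the largest cumulative time. `slowest_imports` is a hypothetical helper name. The sketch does not cover time spent in backend initialization (`jax.devices()`).

```python
import subprocess
import sys


def slowest_imports(module: str, top: int = 10) -> list[tuple[int, str]]:
    """Return up to `top` (cumulative_us, name) pairs, slowest first, as recorded
    by CPython's `-X importtime` while importing `module` in a fresh interpreter."""
    proc = subprocess.run(
        [sys.executable, "-X", "importtime", "-c", f"import {module}"],
        capture_output=True,
        text=True,
        check=True,
    )
    rows = []
    # Lines look like "import time:   self |  cumulative | name" (name indented by depth).
    for line in proc.stderr.splitlines():
        if not line.startswith("import time:"):
            continue
        parts = line[len("import time:"):].split("|")
        if len(parts) != 3:
            continue
        try:
            cumulative = int(parts[1])  # int() tolerates surrounding whitespace
        except ValueError:  # header row: "self [us] | cumulative | imported package"
            continue
        rows.append((cumulative, parts[2].strip()))
    rows.sort(reverse=True)  # largest cumulative time first
    return rows[:top]
```

For example, `for us, name in slowest_imports("jax"): print(f"{us / 1e6:.3f} s  {name}")` lists the slowest modules pulled in by `import jax`. Cumulative times include nested imports, so `jax` itself appears at or near the top.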