I'm implementing a time-marching simulation across multiple GPUs. The field arrays are sharded along one axis and the operators along another. The actual implementation applies FFTs constrained with `experimental.custom_partitioning`, but I've omitted that here for simplicity. I use `lower` and `compile` for ahead-of-time (AOT) compilation, which makes it easy to benchmark the actual runtime separately from compilation.
What I'm seeing is that, on the first execution, a long time is spent in a series of `cuMemAlloc_v2` calls on each device, in a series of streams of the form `Stream #N(Memset)`. The time this takes appears to grow with the square of the number of GPUs. For the minimal example below, I see the following:
- 2 GPUs: first call takes 2.2 s
- 4 GPUs: first call takes 7.8 s
- 8 GPUs: first call takes 29 s
- 16 GPUs: first call takes 118 s
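One way to quantify "grows with the square of the number of GPUs" is to fit a power law t ≈ c·n^p to the first-call times above on a log-log scale. Below is a minimal sketch in plain Python. It assumes an ordinary least-squares fit on the logarithms, which is a choice made for this sketch and not something taken from the report.

```python
import math

# First-call times reported above, keyed by GPU count (seconds).
first_call_s = {2: 2.2, 4: 7.8, 8: 29.0, 16: 118.0}

# Fit log(t) = p * log(n) + c by ordinary least squares; p is the scaling exponent.
xs = [math.log(n) for n in first_call_s]
ys = [math.log(t) for t in first_call_s.values()]
x_mean = sum(xs) / len(xs)
y_mean = sum(ys) / len(ys)
p = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys)) / sum(
    (x - x_mean) ** 2 for x in xs
)

# Local exponent between successive doublings of the GPU count: log2(t_2n / t_n).
ns = sorted(first_call_s)
local = [math.log2(first_call_s[b] / first_call_s[a]) for a, b in zip(ns, ns[1:])]

print(f"fitted exponent p = {p:.2f}")
print("per-doubling exponents:", [round(e, 2) for e in local])
```

With the four timings above, this gives p ≈ 1.91. The per-doubling exponents are about 1.83, 1.89 and 2.02, so the growth is close to quadratic and steepens slightly at 16 GPUs.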
My questions:
1. Are these calls expected?
2. If so, is it expected that they take this long?
Below is a minimal example. In practice I'm using `donate_argnames` and specifying `in_shardings` and `out_shardings`, but I've omitted these here for brevity.
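The following is a hedged illustration of the setup described above, including the `donate_argnames`, `in_shardings` and `out_shardings` options. It is a hypothetical sketch, not the code from this report. It builds a 1-D mesh over all visible devices, shards a field along its first axis and an operator along its second, and compiles ahead of time with `lower` and `compile`. The `step` function, the array sizes, and the mesh axis name `"d"` are assumptions. Lowering, compilation, the first call and a second call are timed separately, so any first-execution overhead shows up as the gap between the first and second call.

```python
import os
import time

# Assumption for running without GPUs: expose 4 host CPU devices so the
# sharding code paths are exercised. This must happen before importing jax
# and has no effect if XLA_FLAGS is already set.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("d",))
n = 128 * devices.size  # keep both dimensions divisible by the device count

# Field arrays sharded along the first axis, operators along the second.
field_sharding = NamedSharding(mesh, P("d", None))
op_sharding = NamedSharding(mesh, P(None, "d"))


def step(field, op):
    # Hypothetical update standing in for the real FFT-based operator.
    # With more than one device, the differing shardings require
    # cross-device communication to form each row block of the result.
    return jnp.tanh(field @ op)


field = jax.device_put(jnp.ones((n, n), jnp.float32), field_sharding)
op = jax.device_put(jnp.eye(n, dtype=jnp.float32), op_sharding)

step_jit = jax.jit(
    step,
    in_shardings=(field_sharding, op_sharding),
    out_shardings=field_sharding,
    donate_argnames="field",  # the input field buffer may be reused for the output
)

t0 = time.perf_counter()
lowered = step_jit.lower(field, op)
t1 = time.perf_counter()
compiled = lowered.compile()
t2 = time.perf_counter()


def timed_call(fn, *args):
    """Run fn once and return (result, wall-clock seconds until it is ready)."""
    start = time.perf_counter()
    out = fn(*args)
    out.block_until_ready()
    return out, time.perf_counter() - start


# `field` is donated, so always continue from the returned array.
field, t_first = timed_call(compiled, field, op)
field, t_second = timed_call(compiled, field, op)

print(f"{devices.size} device(s): lower {t1 - t0:.3f} s, compile {t2 - t1:.3f} s, "
      f"first call {t_first:.3f} s, second call {t_second:.3f} s")
```

One way to check whether the extra first-call time scales with the GPU count is to run this at increasing device counts and compare `t_first` with `t_second`. You can vary the device count, for example, with `CUDA_VISIBLE_DEVICES`.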
Any help would be much appreciated!
Minimal example:
Example trace:
What jax/jaxlib version are you using?
jax==0.4.20, jaxlib==0.4.20+cuda11.cudnn86
Which accelerator(s) are you using?
GPU (16x NVIDIA A100 40GB, but the issue also reproduces on 2 GPUs)
Additional system info?
1.26.2 3.10.13 | packaged by conda-forge | (main, Oct 26 2023, 18:07:37) [GCC 12.3.0] uname_result(system='Linux', node='a2-mega-v1', release='5.10.0-26-cloud-amd64', version='#1 SMP Debian 5.10.197-1 (2023-09-29)', machine='x86_64')
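The line above appears to be NumPy's version, Python's `sys.version`, and `platform.uname()` printed together. This is an inference from its format, not something stated in the report. Below is a minimal sketch for regenerating the same line on another machine, under that assumption.

```python
import platform
import sys

import numpy

# Assumption: the "Additional system info" line has the layout
# "<numpy version> <sys.version> <platform.uname()>", which this reproduces.
info = (numpy.__version__, sys.version, platform.uname())
print(*info)
```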
NVIDIA GPU info
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0    56W / 400W |      2MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  On   | 00000000:00:05.0 Off |                    0 |
| N/A   32C    P0    55W / 400W |      2MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM...  On   | 00000000:00:06.0 Off |                    0 |
| N/A   32C    P0    54W / 400W |      2MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA A100-SXM...  On   | 00000000:00:07.0 Off |                    0 |
| N/A   34C    P0    58W / 400W |      2MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   4  NVIDIA A100-SXM...  On   | 00000000:00:08.0 Off |                    0 |
| N/A   33C    P0    56W / 400W |      2MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   5  NVIDIA A100-SXM...  On   | 00000000:00:09.0 Off |                    0 |
| N/A   32C    P0    56W / 400W |      2MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   6  NVIDIA A100-SXM...  On   | 00000000:00:0A.0 Off |                    0 |
| N/A   30C    P0    53W / 400W |      2MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   7  NVIDIA A100-SXM...  On   | 00000000:00:0B.0 Off |                    0 |
| N/A   32C    P0    53W / 400W |      2MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   8  NVIDIA A100-SXM...  On   | 00000000:80:00.0 Off |                    0 |
| N/A   32C    P0    57W / 400W |      2MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   9  NVIDIA A100-SXM...  On   | 00000000:80:01.0 Off |                    0 |
| N/A   31C    P0    54W / 400W |      2MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|  10  NVIDIA A100-SXM...  On   | 00000000:80:02.0 Off |                    0 |
| N/A   32C    P0    55W / 400W |      2MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|  11  NVIDIA A100-SXM...  On   | 00000000:80:03.0 Off |                    0 |
| N/A   30C    P0    53W / 400W |      2MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|  12  NVIDIA A100-SXM...  On   | 00000000:80:04.0 Off |                    0 |
| N/A   31C    P0    53W / 400W |      2MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|  13  NVIDIA A100-SXM...  On   | 00000000:80:05.0 Off |                    0 |
| N/A   32C    P0    58W / 400W |      2MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|  14  NVIDIA A100-SXM...  On   | 00000000:80:06.0 Off |                    0 |
| N/A   32C    P0    55W / 400W |      2MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|  15  NVIDIA A100-SXM...  On   | 00000000:80:07.0 Off |                    0 |
| N/A   32C    P0    55W / 400W |      2MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+