google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Resharding across MGPUs results in long series of cuMemAlloc_v2 calls #18666

Open andy-kogsys opened 9 months ago

andy-kogsys commented 9 months ago

Description

I'm implementing a time-marching simulation across multiple GPUs. The calculation has field arrays sharded along one axis and operators sharded along another (the actual implementation involves FFTs constrained using experimental.custom_partitioning, but I've omitted that here for simplicity). I use lower and compile for AOT compilation to make it easy to benchmark the actual runtime.
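(For context, the omitted FFT constraint follows roughly the custom_partitioning pattern from the JAX docs. The skeleton below is a sketch only; the def_partition callback signatures shown are an assumption and have changed between JAX versions.)

from jax.experimental.custom_partitioning import custom_partitioning

# sketch: an FFT declared shardable along the non-transform axis, so the
# partitioner can run a per-shard FFT instead of gathering the operand
@custom_partitioning
def sharded_fft(x):
    return jnp.fft.fft(x, axis=-1)

def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
    # assumption: the result keeps the first operand's sharding
    return arg_shapes[0].sharding

def partition(mesh, arg_shapes, result_shape):
    # per-shard lowering; valid only while the transform axis is unsharded
    def lower_fn(x):
        return jnp.fft.fft(x, axis=-1)
    return mesh, lower_fn, arg_shapes[0].sharding, (arg_shapes[0].sharding,)

sharded_fft.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands,
    partition=partition)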

What I'm seeing is that on first execution, a long time is spent executing a series of cuMemAlloc_v2 calls on each device, across a series of streams of the form Stream #N(Memset). The time this takes appears to grow with the square of the number of GPUs. For the minimal example below, I see the behavior shown in the attached trace screenshots.

My questions:

  1. Are these calls expected?
  2. If so, is it expected that they should take so long?
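For reference, the resharding itself can be confirmed in the optimized HLO; a quick check (a sketch, using the compiled object and arrays from the example below):

# sketch: look for resharding collectives in the optimized module
hlo = compiled.as_text()
print([ln for ln in hlo.splitlines()
       if "all-to-all" in ln or "collective-permute" in ln])

# the two layouts can also be inspected per-array
jax.debug.visualize_array_sharding(flda)  # split along axis 1
jax.debug.visualize_array_sharding(opa)   # split along axis 0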

Below is a minimal example. In practice I'm using donate_argnames and specifying in_shardings and out_shardings, but I have omitted them here for brevity.
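The omitted configuration looks roughly like this (a sketch; in_shardings would follow the same tuple-of-shardings pattern as out_shardings):

# sketch of the omitted settings: donate the carry buffers and pin the
# output shardings so the compiler doesn't have to infer them
run_jit = jit(
    run,
    static_argnames=("nt",),
    donate_argnames=("carry",),
    out_shardings=(shard_y, shard_y, shard_x, shard_x),
)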

Any help would be much appreciated!

Minimal example:

import jax
import jax.numpy as jnp
from jax import jit
from jax.lax import fori_loop
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from time import perf_counter

# inputs
ngpu = 2
dims = (8192, 8192)
nt = 100

# define a single time step
def run_step(i, carry):

    # unpack fields
    flda, fldb, opa, opb = carry

    # calculations: each product mixes an operator sharded along axis 0
    # with a field sharded along axis 1, so the result must be resharded
    fldb = opa * flda
    flda = opb * fldb

    return (flda, fldb, opa, opb)

# define run function
def run(nt, carry):
    return fori_loop(0, nt, run_step, carry)

# create mesh
devices = mesh_utils.create_device_mesh((ngpu,), jax.devices()[0:ngpu])
mesh = Mesh(devices, axis_names=("gpus",))
shard_y = NamedSharding(mesh, P(None, "gpus"))
shard_x = NamedSharding(mesh, P("gpus", None))

# begin trace
with jax.profiler.trace("./tensorboard"):

    # create operators & fields
    tbeg = perf_counter()
    opa = jax.device_put(jnp.ones(dims), shard_x)
    opb = jax.device_put(jnp.ones(dims), shard_x)
    flda = jax.device_put(jnp.ones(dims), shard_y)
    fldb = jax.device_put(jnp.ones(dims), shard_y)
    carry = (flda, fldb, opa, opb)
    trun = perf_counter() - tbeg
    print(f"Array creation time:\t{1e3*trun:8.1f} ms")

    # compile
    tbeg = perf_counter()
    run_jit = jit(run, static_argnames=("nt",))
    lowered = run_jit.lower(nt, carry)
    compiled = lowered.compile()
    trun = perf_counter() - tbeg
    print(f"Compile time:\t\t{1e3*trun:8.1f} ms")

    # single step warmup run
    tbeg = perf_counter()
    warmup = run_jit(1, carry)
    jax.block_until_ready(warmup)
    trun = perf_counter() - tbeg
    print(f"Single step run time:\t{1e3*trun:8.1f} ms")

    # run the full calculation
    tbeg = perf_counter()
    flda, fldb, _, _ = run_jit(nt, carry)
    flda.block_until_ready()
    trun = perf_counter() - tbeg
    print(f"Run time:\t\t{1e3*trun:8.1f} ms")

Example trace: (two trace screenshots attached)

What jax/jaxlib version are you using?

jax==0.4.20, jaxlib==0.4.20+cuda11.cudnn86

Which accelerator(s) are you using?

GPU (16x Nvidia A100 40GB, but can recreate on 2x)

Additional system info?

numpy: 1.26.2
Python: 3.10.13 | packaged by conda-forge | (main, Oct 26 2023, 18:07:37) [GCC 12.3.0]
Platform: uname_result(system='Linux', node='a2-mega-v1', release='5.10.0-26-cloud-amd64', version='#1 SMP Debian 5.10.197-1 (2023-09-29)', machine='x86_64')

NVIDIA GPU info

NVIDIA-SMI 525.105.17    Driver Version: 525.105.17    CUDA Version: 12.0
16x NVIDIA A100-SXM (40960MiB each): persistence mode on, MIG disabled, compute mode Default, all idle (0% utilization, 2MiB memory in use).
No running processes found.

andy-kogsys commented 9 months ago

Any thoughts on this?