google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Unexpected scaling of sharding/pmap #21411

lockwo opened 1 month ago

lockwo commented 1 month ago

Description

If I run a function with a constant amount of work and scale it linearly across more and more CPUs/GPUs (just running the same function on the same input again on each additional device), I would expect the execution time to stay roughly the same, since each device is doing the same computation. There might be some overhead from moving arrays around in memory, dispatching, and collecting results, but with arrays of O(10) floats I would expect this to be tiny. However, when I do this in JAX with pmap or sharding, on either CPU or GPU, the time noticeably increases with more devices. Is this something to do with how JAX manages distribution, or am I doing something wrong?

Here is the example code:

```python
import multiprocessing
import os
import time

# XLA_FLAGS is read when JAX initializes its backends, so set it before
# importing jax. It only affects the CPU (host) platform.
devices_to_use = multiprocessing.cpu_count()
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(devices_to_use)

import jax
import jax.numpy as jnp
import jax.sharding as jshard
import matplotlib.pyplot as plt

print(jax.devices())


def solve(init, var):
    return jnp.mean(init @ var)


def vectorized_solve(y0, v):
    # Apply solve to every row (trajectory) of y0; v is shared by all rows.
    return jax.vmap(solve, in_axes=(0, None))(y0, v)


n_per_device = 100
problem_size = 20
all_devices_to_use = list(range(1, 7))

all_timings = []      # jit + sharding
all_timings_aux = []  # pmap

vectorized_jit = jax.jit(vectorized_solve)

key = jax.random.PRNGKey(42)
reps = 500

for n_devices in all_devices_to_use:
    devices = jax.devices("gpu")[:n_devices]  # use "cpu" for the CPU run
    print(len(devices))
    # Split y0 by rows across the devices; replicate var on every device.
    sharding = jshard.PositionalSharding(devices).reshape((len(devices), 1))
    sharding_replicate = sharding.replicate()

    # Total work grows with the device count: each device gets n_per_device rows.
    n_traj_total = len(devices) * n_per_device

    y_0 = jnp.ones((n_traj_total, problem_size))
    y_0 = y_0 + jax.random.uniform(key, shape=y_0.shape)
    y0_shard = jax.device_put(y_0, sharding).block_until_ready()

    var = jax.random.uniform(key, shape=(problem_size, 5))
    var_shard = jax.device_put(var, sharding_replicate).block_until_ready()

    # Warm-up call so compilation is excluded from the timing.
    vectorized_jit(y0_shard, var_shard).block_until_ready()
    tot = 0.0
    for _ in range(reps):
        start_time = time.perf_counter()
        vectorized_jit(y0_shard, var_shard).block_until_ready()
        tot += time.perf_counter() - start_time
    all_timings.append(tot / reps)

    # Same computation with pmap: one leading-axis slice per device.
    # y_0 and var are not pre-placed per device here, so each pmap call
    # may also include copying them to the devices.
    pmapped = jax.pmap(vectorized_solve, in_axes=(0, None), devices=devices)
    y_0 = y_0.reshape((len(devices), n_per_device, problem_size))
    pmapped(y_0, var).block_until_ready()

    tot = 0.0
    for _ in range(reps):
        start_time = time.perf_counter()
        pmapped(y_0, var).block_until_ready()
        tot += time.perf_counter() - start_time
    all_timings_aux.append(tot / reps)

fig, axs = plt.subplots()
axs.plot(all_devices_to_use, all_timings, linestyle="--", marker="o", label="shard eval")
axs.plot(all_devices_to_use, all_timings_aux, linestyle="--", marker="o", label="pmap eval")
axs.legend()
axs.set_xlabel("Number of devices")
axs.set_ylabel("Mean time per call (s)")
axs.set_yscale("log")
fig.tight_layout()
plt.show()
```
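For reference, here is a sketch of the same jit-plus-sharding measurement written with `jax.sharding.Mesh` and `NamedSharding`, which newer JAX releases recommend over the deprecated `PositionalSharding`. It assumes 8 simulated CPU devices (via `XLA_FLAGS`, set before importing JAX); `time_sharded` is a hypothetical helper name, and the shapes mirror the code above.

```python
import os

# Assumption: 8 simulated CPU devices; must be set before JAX initializes.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import time

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def solve(init, var):
    return jnp.mean(init @ var)


@jax.jit
def vectorized_solve(y0, v):
    return jax.vmap(solve, in_axes=(0, None))(y0, v)


def time_sharded(n_devices, n_per_device=100, problem_size=20, reps=100):
    """Mean seconds per call with y0's rows split over the first n_devices."""
    devices = np.array(jax.devices()[:n_devices])
    mesh = Mesh(devices, axis_names=("batch",))
    row_sharding = NamedSharding(mesh, P("batch", None))  # split rows
    replicated = NamedSharding(mesh, P())                  # full copy on each device

    key = jax.random.PRNGKey(42)
    y0 = jax.random.uniform(key, (n_devices * n_per_device, problem_size))
    var = jax.random.uniform(key, (problem_size, 5))
    y0 = jax.device_put(y0, row_sharding)
    var = jax.device_put(var, replicated)

    vectorized_solve(y0, var).block_until_ready()  # compile outside the timed loop
    start = time.perf_counter()
    for _ in range(reps):
        vectorized_solve(y0, var).block_until_ready()
    return (time.perf_counter() - start) / reps


for n in range(1, len(jax.devices()) + 1):
    print(n, time_sharded(n))
```

Because the inputs are placed once before the loop, only dispatch, execution, and any collectives end up in the timed region.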

GPU:

[Screenshot: shard eval and pmap eval timings vs. number of devices on GPU]

CPU:

[Screenshot: shard eval and pmap eval timings vs. number of devices on CPU]

In my mind, these lines should be almost flat. The arrays getting returned are very small and the arrays getting sharded are also very small.
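One way to check whether any cross-device communication happens at all is to look for collective ops in the compiled program. This is a sketch that assumes 4 simulated CPU devices. `collectives_in`, `per_row`, and `global_mean` are hypothetical names. Matching op names in the HLO text is a rough heuristic, not an official API for this purpose.

```python
import os

# Assumption: 4 simulated CPU devices; must be set before JAX initializes.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# HLO op names for XLA collectives (matched as substrings of the compiled text).
COLLECTIVES = ("all-reduce", "all-gather", "all-to-all", "collective-permute", "reduce-scatter")


def collectives_in(fn, *args):
    """Names of collective ops found in the compiled HLO of jit(fn)(*args)."""
    hlo = jax.jit(fn).lower(*args).compile().as_text() or ""
    return sorted(op for op in COLLECTIVES if op in hlo)


def per_row(y, w):
    # Row-wise, like vectorized_solve: each device only needs its own rows.
    return jnp.mean(y @ w, axis=1)


def global_mean(y, w):
    # Reduces over the sharded row axis, so partial results must be combined.
    return jnp.mean(y @ w)


n_dev = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ("batch",))
x = jax.device_put(jnp.ones((n_dev * 100, 20)), NamedSharding(mesh, P("batch", None)))
v = jax.device_put(jnp.ones((20, 5)), NamedSharding(mesh, P()))

print("per_row:", collectives_in(per_row, x, v))
print("global_mean:", collectives_in(global_mean, x, v))
```

On more than one device, a mean over the sharded axis typically compiles to an all-reduce. A row-wise computation like `vectorized_solve` should not need one.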

As an even more trivial example, I took the code from the docs on sharding. If I do 8 arrays on 8 devices, it is noticeably slower than doing 4 arrays on 4 devices. I find this very surprising.
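This is a sketch of a 4-vs-8-device comparison in a `timeit`-style "best of" form. It keeps the total array size fixed. It assumes 8 simulated CPU devices. `best_of`, `time_on`, and the `jnp.sin(x) * 2.0` workload are hypothetical stand-ins, not the code from the sharding docs.

```python
import os

# Assumption: 8 simulated CPU devices; must be set before JAX initializes.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import time

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def best_of(fn, *args, repeat=5, number=5):
    """Best mean seconds per call over `repeat` runs of `number` calls, like timeit."""
    fn(*args).block_until_ready()  # compile outside the timed region
    best = float("inf")
    for _ in range(repeat):
        start = time.perf_counter()
        for _ in range(number):
            fn(*args).block_until_ready()
        best = min(best, (time.perf_counter() - start) / number)
    return best


@jax.jit
def op(x):
    # Hypothetical stand-in for the docs example's computation.
    return jnp.sin(x) * 2.0


def time_on(k, size=1024):
    """Shard a (size, size) array row-wise over the first k devices and time `op`."""
    mesh = Mesh(np.array(jax.devices()[:k]), ("rows",))
    x = jax.device_put(jnp.ones((size, size)), NamedSharding(mesh, P("rows", None)))
    return best_of(op, x)


for k in (4, 8):
    if k <= len(jax.devices()):
        print(f"{k} devices: {time_on(k) * 1e6:.0f} us per call")
```

With a small `size`, the per-call time is mostly fixed dispatch and launch cost. Adding devices then shrinks each shard without shrinking that floor.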

[Screenshot: sharding-docs example timed on 8 devices vs. 4 devices]

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (6 total, 6 local): [cuda(id=0) cuda(id=1) ... cuda(id=4) cuda(id=5)]
process_count: 1

release='5.15.0-89-generic', version='#99-Ubuntu SMP Mon Oct 30 20:42:41 UTC 2023', machine='x86_64')

justinjfu commented 1 month ago

My hypothesis is that in this case the main bottleneck is communication rather than compute, since your input sizes are so small. Communication cost grows with the number of devices, so the more devices there are, the slower the operation. For example, if the devices are connected in a ring topology, an all-reduce that computes the mean over all values takes O(# devices) time.
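The O(# devices) claim can be made concrete with the standard alpha-beta cost model for a ring all-reduce. With p devices and an n-byte payload, a ring runs p - 1 reduce-scatter steps and p - 1 all-gather steps. Each step moves n/p bytes per device, giving roughly 2(p - 1)α + 2((p - 1)/p)nβ. The α (per-message latency) and β (per-byte time) values in the usage lines are made-up placeholders, not measurements of any real interconnect.

```python
def ring_allreduce_time(p, n_bytes, alpha, beta):
    """Alpha-beta estimate (seconds) for a ring all-reduce over p devices.

    A ring all-reduce runs p - 1 reduce-scatter steps and p - 1 all-gather
    steps; in each step every device sends one chunk of n_bytes / p bytes.
    alpha: per-message latency (s); beta: per-byte transfer time (s/byte).
    """
    if p == 1:
        return 0.0  # nothing to communicate
    steps = 2 * (p - 1)
    return steps * (alpha + (n_bytes / p) * beta)


# Hypothetical link parameters: 5 us latency per message, ~10 GB/s bandwidth.
alpha, beta = 5e-6, 1e-10
for p in (1, 2, 4, 6, 8):
    small = ring_allreduce_time(p, 40, alpha, beta)            # ~10 float32 values
    large = ring_allreduce_time(p, 4 * 8192**2, alpha, beta)  # 8192 x 8192 float32
    print(f"p={p}: small {small * 1e6:.1f} us, large {large * 1e3:.2f} ms")
```

For tiny payloads, the latency term 2(p - 1)α dominates and grows linearly with p. For large payloads, the bandwidth term approaches 2nβ, which is roughly independent of p.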

If you scale up the input size or perform more computation per device, the problem would likely become compute-bound and the effect would go away. For example, with the sharding-docs example you pasted above, running on a pod of 8 TPUs gives the following numbers:

Input size (1024 * 8, 1024 * 8):
- 8 devices: 5 loops, best of 5: 510 µs per loop
- 4 devices: 5 loops, best of 5: 662 µs per loop

Input size (1024 * 8 * 10, 1024 * 8):
- 8 devices: 5 loops, best of 5: 2.39 ms per loop
- 4 devices: 5 loops, best of 5: 4.41 ms per loop

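As you can see, with the larger input, using 4 devices takes about twice as long as using 8 devices (4.41 ms vs. 2.39 ms), which is expected. With the smaller input, the gap is much narrower (662 µs vs. 510 µs).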
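The ratios behind that comparison can be computed directly from the quoted timings. Here, strong-scaling efficiency means speedup divided by the 2x increase in device count. `scaling` is a hypothetical helper name.

```python
def scaling(t_few, t_many, n_few=4, n_many=8):
    """Speedup and parallel efficiency when going from n_few to n_many devices."""
    speedup = t_few / t_many
    efficiency = speedup / (n_many / n_few)  # 1.0 = perfect strong scaling
    return speedup, efficiency


# Timings quoted above for 4 and 8 TPUs, in seconds: (t_4_devices, t_8_devices).
quoted = {
    "(1024 * 8, 1024 * 8)": (662e-6, 510e-6),
    "(1024 * 8 * 10, 1024 * 8)": (4.41e-3, 2.39e-3),
}
for size, (t4, t8) in quoted.items():
    s, e = scaling(t4, t8)
    print(f"{size}: speedup {s:.2f}x, efficiency {e:.0%}")
```

For the smaller input, this gives roughly 1.30x (about 65% efficiency). For the larger input, it gives 1.85x (about 92%), consistent with the larger problem being closer to compute-bound.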