jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
Apache License 2.0

Latency Hiding Scheduler leads to x5 memory usage if used without jax.lax.scan #20763

Open qGentry opened 6 months ago

qGentry commented 6 months ago


Hi, we're training a large (300B, 60 layers) mixture-of-experts transformer on 1000+ GPUs. We have some non-uniformity in the layers, so we can't use `jax.lax.scan` directly to stack them together; instead, we just call each layer independently. The model's structure isn't arbitrary, though: it is a pattern of 3 layers with the same structure followed by 1 layer with a different structure, repeated 15 times (60 layers in total). We would benefit a LOT from overlapping computation and communication, but when we enable the latency hiding scheduler with `--xla_gpu_enable_latency_hiding_scheduler`, memory usage increases by a factor of 4-5 (from 50 GB per GPU to 200-250 GB per GPU, which is completely unusable). My guess is that in this case the compiler doesn't reuse the buffers for async communications across different layers.

We've also tested a variant with `jax.lax.scan` and uniform layers, and it seemed to work okay from a memory-usage point of view: only 20-25% overhead from the latency hiding scheduler.
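Because the layer pattern described above repeats with period four (3 layers of one structure plus 1 of another, 15 times), one way it could still be expressed with `jax.lax.scan` is to make each scan step apply a whole group of four layers, with each layer type's weights stacked along a leading "group" axis. The following is a minimal sketch in plain JAX under stated assumptions: `layer_a`, `layer_b`, and the toy sizes are hypothetical stand-ins, not the model from this issue, and it has not been checked against the latency hiding scheduler memory behavior reported here.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy layer types standing in for the two real layer structures.
def layer_a(w, x):
    return x + jax.nn.relu(x @ w)  # residual ReLU layer

def layer_b(w, x):
    return x + jnp.tanh(x @ w)  # residual tanh layer (the "odd one out")

N_GROUPS, N_A_PER_GROUP, DIM = 15, 3, 16  # 15 groups x (3 A + 1 B) = 60 layers

def init_params(key):
    ka, kb = jax.random.split(key)
    return {
        # Leading axis = group index, so lax.scan slices one group per step.
        "a": 0.01 * jax.random.normal(ka, (N_GROUPS, N_A_PER_GROUP, DIM, DIM)),
        "b": 0.01 * jax.random.normal(kb, (N_GROUPS, DIM, DIM)),
    }

def group_step(x, group_params):
    # One heterogeneous group: three A layers, then one B layer.
    for j in range(N_A_PER_GROUP):
        x = layer_a(group_params["a"][j], x)
    x = layer_b(group_params["b"], x)
    return x, None

def forward_scan(params, x):
    # scan iterates over the leading axis of every leaf in `params`.
    x, _ = jax.lax.scan(group_step, x, params)
    return x

def forward_unrolled(params, x):
    # Reference: the same 60 layers called one by one, as in the issue.
    for g in range(N_GROUPS):
        for j in range(N_A_PER_GROUP):
            x = layer_a(params["a"][g, j], x)
        x = layer_b(params["b"][g], x)
    return x
```

Both functions compute the same result; the scanned version traces the group body once instead of 60 separate layer calls.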

Is this a known problem? Is there any workaround?

System info (python version, jaxlib version, accelerator, etc.)

tested on 0.4.25/0.4.26, 1000+ H100 GPUs
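The system info requested above can also be collected from JAX itself. A small sketch using `jax.print_environment_info` (present in the 0.4.x releases discussed here; the exact fields in the report vary by version):

```python
import jax

# Prints the jax, jaxlib, numpy and Python versions, the device list and,
# when `nvidia-smi` is available on the machine, its output.
jax.print_environment_info()

# The same report as a string, e.g. for pasting into an issue.
info = jax.print_environment_info(return_string=True)
```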

qGentry commented 2 weeks ago

Here is a toy repro, tested on JAX 0.4.34:

```python
import flax.linen as nn
import jax
import jax.ad_checkpoint
import jax.numpy as jnp
import numpy as np
from flax.linen.linear import default_kernel_init

EMB_DIM = 8192
HID_DIM = 8192

BS = 32
SEQ_LEN = 4096
N_LAYERS = 32
SCAN = False

CHECKPOINT_POLICY = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=[],
    offload_src="device",
    offload_dst="pinned_host",
)

mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(8, 1), ("data", "model"))
input_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data", None)
)
target_sharding = jax.sharding.NamedSharding(
    mesh,
    jax.sharding.PartitionSpec(
        "data",
    ),
)

rules = (
    ("batch", "data"),
    ("embedding", None),
    ("hidden", "model"),
    ("q_sequence", "model"),
)


class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x_residual = x
        h = nn.Dense(
            HID_DIM,
            kernel_init=nn.with_logical_partitioning(
                default_kernel_init,
                ("embedding", "hidden"),
            ),
            use_bias=False,
        )(x)
        h = nn.relu(h)
        x = nn.Dense(
            EMB_DIM,
            kernel_init=nn.with_logical_partitioning(
                default_kernel_init,
                ("hidden", "embedding"),
            ),
            use_bias=False,
        )(h)
        x = x_residual + x
        # Sequence parallelism
        x = nn.with_logical_constraint(x, ("batch", "q_sequence", None))
        return x


class Output(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(
            features=1,
            kernel_init=nn.with_logical_partitioning(
                default_kernel_init,
                ("hidden", None),
            ),
            use_bias=False,
        )(x)[..., 0]
        x = jnp.mean(x, axis=1)
        return x


class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        def apply_module(block, block_input, _):
            block_output = block(block_input)
            return block_output, None

        apply_module = nn.remat(
            apply_module,
            policy=CHECKPOINT_POLICY,
            prevent_cse=False,
        )

        if SCAN:
            x, _ = nn.scan(
                apply_module,
                variable_axes={"params": 0},
                split_rngs={"params": True},
                length=N_LAYERS,
                metadata_params={nn.PARTITION_NAME: "layers"},
            )(MLP(), x, None)
        else:
            for i in range(N_LAYERS):
                x = MLP(name=f"block_{i}")(x)

        preds = Output()(x)
        return preds


def loss_fn(preds, target):
    return jnp.mean((preds - target) ** 2)


def calc_loss(params, inputs, target):
    preds = Model().apply(params, inputs)
    loss = loss_fn(preds, target)
    return loss


def train_step(params, inputs, target):
    loss, grads = jax.value_and_grad(calc_loss)(params, inputs, target)
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-8 * g, params, grads)
    return params, loss


def unbox_logically_partioned(tree, apply_constraint: bool = True):
    return jax.tree_util.tree_map(
        lambda leaf: (
            leaf.unbox(apply_constraint=apply_constraint)
            if isinstance(leaf, nn.LogicallyPartitioned)
            else leaf
        ),
        tree,
        is_leaf=lambda node: isinstance(node, nn.LogicallyPartitioned),
    )


def get_gpu_memory_usage() -> dict[str, float]:
    if jax.default_backend() != "gpu":
        return {}
    num_devices = jax.local_device_count("gpu")
    gpu_memory_usage = []
    for i in range(num_devices):
        memory_stats = jax.local_devices()[i].memory_stats()
        gpu_memory_usage.append(
            memory_stats["peak_bytes_in_use"] / memory_stats["bytes_limit"] * 100
        )
    return {f"GPU{i}": val for i, val in enumerate(gpu_memory_usage)}


with mesh, nn.logical_axis_rules(rules):
    fake_inputs = jnp.empty((BS, SEQ_LEN, EMB_DIM))
    fake_inputs = jax.device_put(fake_inputs, input_sharding)

    fake_target = jnp.empty((BS,))
    fake_target = jax.device_put(fake_target, target_sharding)

    params = Model().init(jax.random.PRNGKey(0), fake_inputs)
    params = unbox_logically_partioned(params)

    train_step_fn = (
        jax.jit(
            train_step,
            in_shardings=(
                jax.tree_util.tree_map(lambda x: x.sharding, params),
                input_sharding,
                target_sharding,
            ),
            out_shardings=(
                jax.tree_util.tree_map(lambda x: x.sharding, params),
                jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()),
            ),
            donate_argnums=(0,),
        )
        .lower(params, fake_inputs, fake_target)
        .compile()
    )

    with open("compiled.txt", "w") as f:
        f.write(train_step_fn.as_text())

    memory_analysis = train_step_fn.memory_analysis()
    print(
        f"Total size device = {memory_analysis.temp_size_in_bytes / 1024 / 1024 / 1024} GB, "  # noqa E501
        f"weights = {memory_analysis.argument_size_in_bytes / 1024 / 1024 / 1024} GB, "
        f"total: {(memory_analysis.argument_size_in_bytes + memory_analysis.temp_size_in_bytes) / 1024 / 1024 / 1024} GB"
    )

    for i in range(10):
        inputs = jax.random.normal(jax.random.PRNGKey(i), (BS, SEQ_LEN, EMB_DIM))
        inputs = jax.device_put(inputs, input_sharding)

        target = jax.random.normal(jax.random.PRNGKey(0), (BS,))
        target = jax.device_put(target, target_sharding)

        if i == 3:
            # jax.tree_util.tree_map instead of the deprecated jax.tree_map alias
            jax.tree_util.tree_map(lambda x: x.block_until_ready(), params)
            jax.profiler.start_trace("./profile", create_perfetto_trace=True)
        params, loss = train_step_fn(params, inputs, target)
        if i == 3:
            jax.tree_util.tree_map(lambda x: x.block_until_ready(), params)
            jax.profiler.stop_trace()
        print(loss)
    print(get_gpu_memory_usage())
```

Latency hiding scheduler enabled, `XLA_FLAGS="--xla_gpu_graph_level=0 --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_command_buffer= --xla_gpu_enable_latency_hiding_scheduler=true"`

SCAN=False (hits OOM):

2024-10-17 10:47:03.275923: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 53.46GiB (57397316392 bytes) by rematerialization; only reduced to 70.25GiB (75430461976 bytes), down from 75.50GiB (81067606600 bytes) originally
Total size device = 54.500054121017456 GB, weights = 16.500030532479286 GB, total: 71.00008465349674 GB


SCAN=True:

Total size device = 34.5000324845314 GB, weights = 16.500030532479286 GB, total: 51.00006301701069 GB

Latency hiding scheduler disabled, `XLA_FLAGS="--xla_gpu_graph_level=0 --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_command_buffer="`


SCAN=False:

Total size device = 37.50002360343933 GB, weights = 16.500030532479286 GB, total: 54.00005413591862 GB


SCAN=True:

Total size device = 35.00000220537186 GB, weights = 16.500030532479286 GB, total: 51.50003273785114 GB
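To put the compiled-memory totals printed above side by side, the relative change from enabling the latency hiding scheduler can be computed directly from them. A small sketch that uses only the numbers in this comment, assuming that in each flag configuration the first result is the SCAN=False run and the second is the SCAN=True run:

```python
# Totals (arguments + temporaries, in GB) printed via memory_analysis() above.
# Assumption: first result per configuration = SCAN=False, second = SCAN=True.
totals_gb = {
    ("lhs_on", "scan_false"): 71.00008465349674,
    ("lhs_on", "scan_true"): 51.00006301701069,
    ("lhs_off", "scan_false"): 54.00005413591862,
    ("lhs_off", "scan_true"): 51.50003273785114,
}


def lhs_overhead(scan: str) -> float:
    """Relative change in compiled memory when the latency hiding scheduler is on."""
    return totals_gb[("lhs_on", scan)] / totals_gb[("lhs_off", scan)] - 1.0


for scan in ("scan_false", "scan_true"):
    print(f"{scan}: {lhs_overhead(scan):+.1%}")
```

This prints `scan_false: +31.5%` and `scan_true: -1.0%`: in this toy repro the scheduler's memory cost shows up only in the unrolled (SCAN=False) variant.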
sfvaroglu commented 2 days ago

@qGentry Using JAX 0.4.35 with XLA_FLAGS="--xla_gpu_graph_level=0 --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_command_buffer= " and SCAN=False, I'm seeing a failure:

*** Check failure stack trace: ***
    @     0x7f26b3b96dc4  absl::lts_20230802::log_internal::LogMessage::SendToLog()
    @     0x7f26b3b96c34  absl::lts_20230802::log_internal::LogMessage::Flush()
    @     0x7f26b3b971e9  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7f26ac056141  xla::PjRtStreamExecutorLoadedExecutable::Execute()
    @     0x7f26abfa5d71  pjrt::PJRT_LoadedExecutable_Execute()
    @     0x7f26bbd699fc  xla::PjRtCApiLoadedExecutable::Execute()
    @     0x7f26c1ab0a25  xla::ifrt::PjRtLoadedExecutable::Execute()
    @     0x7f26c1250c49  xla::(anonymous namespace)::ExecuteShardedOnLocalDevicesInternal<>()
    @     0x7f26c12528ee  xla::PyLoadedExecutable::ExecuteSharded()
    @     0x7f26bbc2fc55  xla::ValueOrThrowWrapper<>::operator()()
    @     0x7f26bbc2fabd  nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
    @     0x7f26c1a86eb8  nanobind::detail::nb_func_vectorcall_complex()
    @     0x56227ff6aabb  _PyEval_EvalFrameDefault
Aborted (core dumped)

Any chance you have other flags or env variables set?
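One way to rule out stray flags when comparing runs is to build `XLA_FLAGS` inside the script and print exactly what is set. A minimal sketch; `BASE_FLAGS`, `LHS_FLAG`, and `xla_flags` are hypothetical helper names, and the flag strings are the ones quoted in this thread. XLA reads `XLA_FLAGS` when the backend is initialized, so setting it before `import jax` is the usual safe choice.

```python
import os

# Flags quoted in this thread; the empty value for the command-buffer flag is intentional.
BASE_FLAGS = [
    "--xla_gpu_graph_level=0",
    "--xla_gpu_enable_triton_gemm=false",
    "--xla_gpu_enable_command_buffer=",
]
LHS_FLAG = "--xla_gpu_enable_latency_hiding_scheduler=true"


def xla_flags(enable_lhs: bool) -> str:
    """Hypothetical helper: join the base flags, optionally adding the LHS flag."""
    flags = BASE_FLAGS + ([LHS_FLAG] if enable_lhs else [])
    return " ".join(flags)


# Set before importing jax so the backend picks the flags up at initialization.
os.environ["XLA_FLAGS"] = xla_flags(enable_lhs=True)

# repr() makes stray whitespace or leftover flags from the shell visible.
print("XLA_FLAGS =", repr(os.environ.get("XLA_FLAGS")))
```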