openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

Unable to use residual offloading with scan and remat #17541

Open qGentry opened 2 hours ago

qGentry commented 2 hours ago

Description

Hi guys, I'm very excited about the recent activations offloading mechanism introduced in JAX/XLA:GPU, but I'm unable to make it work with scan. My setup is the following: I'm training a classic transformer whose transformer block is scanned over the inputs "number of layers" times. I'm also using rematerialization to reduce the memory footprint of the model. I basically wrap the apply_block function in jax.remat with the "nothing_saveable" policy and then scan this block over the inputs to achieve the desired behavior: the only activations saved during the forward pass are the residual stream (embeddings) between the scanned blocks.

With the recent introduction of the "save_and_offload_only_these_names" policy, I thought it would be enough to mark the output of the scanned block with jax.ad_checkpoint.checkpoint_name(x, "output") and then specify names_which_can_be_offloaded=["output"], but it didn't work.
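For reference, here is a minimal plain-JAX sketch of the pattern I mean (a toy block and toy shapes instead of the real transformer; names like `block` and `model` are only illustrative):

```
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Offloading policy: save nothing on device, allow the tensor named "output"
# to be offloaded to pinned host memory.
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["output"],
    offload_src="device",
    offload_dst="pinned_host",
)

def block(x, w):
    # Toy stand-in for the transformer block.
    h = jnp.tanh(x @ w)
    x = checkpoint_name(x + h, "output")  # mark the residual stream
    return x, None

def model(ws, x):
    # Remat the block with the policy above, then scan it over the stacked
    # per-layer weights, as described in the setup.
    remat_block = jax.checkpoint(block, policy=policy, prevent_cse=False)
    x, _ = jax.lax.scan(remat_block, x, ws)
    return jnp.mean(x)

loss, grads = jax.jit(jax.value_and_grad(model))(
    jnp.ones((4, 128, 128)),  # ws: (n_layers, dim, dim)
    jnp.ones((8, 128)),       # x:  (batch, dim)
)
```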

I've put together a repro to showcase what is going on:

```
import flax.linen as nn
import jax
import jax.ad_checkpoint
import jax.numpy as jnp
import numpy as np
from flax.linen.linear import default_kernel_init

EMB_DIM = 2048
HID_DIM = 2048
BS = 64
SEQ_LEN = 8192
N_LAYERS = 32

CHECKPOINT_POLICY = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=[],
    offload_src="device",
    offload_dst="pinned_host",
)

mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(4, 2), ("data", "model"))

input_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data", None)
)
target_sharding = jax.sharding.NamedSharding(
    mesh,
    jax.sharding.PartitionSpec(
        "data",
    ),
)

rules = (
    ("batch", "data"),
    ("embedding", None),
    ("hidden", "model"),
    ("q_sequence", "model"),
)


class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x_residual = x
        h = nn.Dense(
            HID_DIM,
            kernel_init=nn.with_logical_partitioning(
                default_kernel_init,
                ("embedding", "hidden"),
            ),
            use_bias=False,
        )(x)
        h = jax.ad_checkpoint.checkpoint_name(h, "hidden")
        h = nn.relu(h)
        x = nn.Dense(
            EMB_DIM,
            kernel_init=nn.with_logical_partitioning(
                default_kernel_init,
                ("hidden", "embedding"),
            ),
            use_bias=False,
        )(h)
        x = x_residual + x
        # Sequence parallelism
        x = nn.with_logical_constraint(x, ("batch", "q_sequence", None))
        x = jax.ad_checkpoint.checkpoint_name(x, "residual")
        return x


class Output(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(
            features=1,
            kernel_init=nn.with_logical_partitioning(
                default_kernel_init,
                ("hidden", None),
            ),
            use_bias=False,
        )(x)[..., 0]
        x = jnp.mean(x, axis=1)
        return x


class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        def apply_module(block, block_input, _):
            block_output = block(block_input)
            return block_output, None

        apply_module = nn.remat(
            apply_module,
            policy=CHECKPOINT_POLICY,
            prevent_cse=False,
        )
        x, _ = nn.scan(
            apply_module,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            length=N_LAYERS,
            metadata_params={nn.PARTITION_NAME: "layers"},
        )(MLP(), x, None)
        preds = Output()(x)
        return preds


def loss_fn(preds, target):
    return jnp.mean((preds - target) ** 2)


def calc_loss(params, inputs, target):
    preds = Model().apply(params, inputs)
    loss = loss_fn(preds, target)
    return loss


def train_step(params, inputs, target):
    loss, grads = jax.value_and_grad(calc_loss)(params, inputs, target)
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-8 * g, params, grads)
    return params, loss


def unbox_logically_partioned(tree, apply_constraint: bool = True):
    return jax.tree_util.tree_map(
        lambda leaf: (
            leaf.unbox(apply_constraint=apply_constraint)
            if isinstance(leaf, nn.LogicallyPartitioned)
            else leaf
        ),
        tree,
        is_leaf=lambda node: isinstance(node, nn.LogicallyPartitioned),
    )


def get_gpu_memory_usage() -> dict[str, float]:
    if jax.default_backend() != "gpu":
        return {}

    num_devices = jax.local_device_count("gpu")
    gpu_memory_usage = []
    for i in range(num_devices):
        memory_stats = jax.local_devices()[i].memory_stats()
        gpu_memory_usage.append(
            memory_stats["peak_bytes_in_use"] / memory_stats["bytes_limit"] * 100
        )
    return {f"GPU{i}": val for i, val in enumerate(gpu_memory_usage)}


with mesh, nn.logical_axis_rules(rules):
    fake_inputs = jnp.empty((BS, SEQ_LEN, EMB_DIM))
    fake_inputs = jax.device_put(fake_inputs, input_sharding)
    fake_target = jnp.empty((BS,))
    fake_target = jax.device_put(fake_target, target_sharding)

    params = Model().init(jax.random.PRNGKey(0), fake_inputs)
    params = unbox_logically_partioned(params)

    train_step_fn = (
        jax.jit(
            train_step,
            in_shardings=(
                jax.tree_util.tree_map(lambda x: x.sharding, params),
                input_sharding,
                target_sharding,
            ),
            out_shardings=(
                jax.tree_util.tree_map(lambda x: x.sharding, params),
                jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()),
            ),
            donate_argnums=(0,),
        )
        .lower(params, fake_inputs, fake_target)
        .compile()
    )

    jax.ad_checkpoint.print_saved_residuals(
        train_step, params, fake_inputs, fake_target
    )

    with open("compiled.txt", "w") as f:
        f.write(train_step_fn.as_text())

    memory_analysis = train_step_fn.memory_analysis()
    print(
        f"Total size device = {memory_analysis.temp_size_in_bytes / 1024 / 1024 / 1024} GB, "  # noqa E501
        f"host = {memory_analysis.host_temp_size_in_bytes / 1024 / 1024 / 1024} GB"
    )

    for i in range(10):
        inputs = jax.random.normal(jax.random.PRNGKey(i), (BS, SEQ_LEN, EMB_DIM))
        inputs = jax.device_put(inputs, input_sharding)

        target = jax.random.normal(jax.random.PRNGKey(0), (BS,))
        target = jax.device_put(target, target_sharding)

        params, loss = train_step_fn(params, inputs, target)
        print(loss)
        print(get_gpu_memory_usage())
```

First of all, I wanted to make sure that offloading works at all. With

    names_which_can_be_saved=[],
    names_which_can_be_offloaded=[],

I'm getting the following result:

    Total size device = 20.26562874764204 GB, host = 0.0 GB

Quite a reasonable value.

Then I wanted to check how much it would cost to save "h" on the GPU, so I set

    names_which_can_be_saved=["hidden"],
    names_which_can_be_offloaded=[],

and I get:

    Total size device = 35.2968789935112 GB, host = 0.0 GB

This is also totally expected: "h" is f32[32,64,8192,2048] sharded across 8 GPUs, which comes to 16 GB per GPU.
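As a quick sanity check on that number (assuming f32 and even sharding across the 8 GPUs):

```
# "hidden" is f32[N_LAYERS=32, BS=64, SEQ_LEN=8192, HID_DIM=2048]
hidden_bytes = 32 * 64 * 8192 * 2048 * 4    # 137_438_953_472 bytes = 128 GiB in total
per_gpu_gib = hidden_bytes / 8 / 1024**3    # sharded across 8 GPUs -> 16.0 GiB each
```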

Ok, let's try to offload "h" and see what happens.

    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["hidden"],

    Total size device = 19.75000447779894 GB, host = 16.0 GB

Also totally expected: instead of keeping 16 GB on the GPU, we offload the activations to the host and save device memory. Iterations also become a lot slower, which is expected as well.

Now that we're sure offloading is indeed working properly, I tried to offload the "residual" tensor (the output of the scanned block).

    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["residual"],

Aaaand nothing happens:

    Total size device = 20.26562874764204 GB, host = 0.0 GB

No change in memory usage, and iteration time is the same as with no offloading at all.

System info (python version, jaxlib version, accelerator, etc.)

Python 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax; jax.print_environment_info()
jax:    0.4.33
jaxlib: 0.4.33
numpy:  1.24.3
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (8 total, 8 local): [CudaDevice(id=0) CudaDevice(id=1) ... CudaDevice(id=6) CudaDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='end-llm-computeinstance-e00yhypr7caccaxmct.priv.hw.nebius.yt', release='5.15.0-119-generic', version='#129-Ubuntu SMP Fri Aug 2 19:25:20 UTC 2024', machine='x86_64')

$ nvidia-smi
Tue Sep 24 10:15:19 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:8D:00.0 Off |                    0 |
| N/A   34C    P0            114W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  |   00000000:91:00.0 Off |                    0 |
| N/A   31C    P0            117W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  |   00000000:95:00.0 Off |                    0 |
| N/A   36C    P0            119W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  |   00000000:99:00.0 Off |                    0 |
| N/A   30C    P0            113W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  |   00000000:AB:00.0 Off |                    0 |
| N/A   34C    P0            118W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  |   00000000:AF:00.0 Off |                    0 |
| N/A   31C    P0            116W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  |   00000000:B3:00.0 Off |                    0 |
| N/A   35C    P0            115W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  |   00000000:B7:00.0 Off |                    0 |
| N/A   30C    P0            114W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+

JAX issue https://github.com/jax-ml/jax/issues/23869

qGentry commented 2 hours ago

I've also tried the following implementation, inspired by https://github.com/jax-ml/jax/issues/23614#issuecomment-2350773816, which wraps the entire apply_module in Flax's custom_vjp, but it doesn't work properly.

```
class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        def apply_module(block, block_input, _):
            block_output = block(block_input)
            return block_output, None

        def apply_module_fwd(block, block_input, _):
            res, vjp_fn = nn.vjp(apply_module, block, block_input, _)
            emb, _ = res
            emb = jax.device_put(emb, TransferToMemoryKind("pinned_host"))
            return (emb, None), vjp_fn

        def apply_module_bwd(vjp_fn, res):
            emb, _ = res
            emb = jax.device_put(emb, TransferToMemoryKind("device"))
            res = (emb, None)
            return vjp_fn(res)

        apply_module_vjp = nn.custom_vjp(
            apply_module, forward_fn=apply_module_fwd, backward_fn=apply_module_bwd
        )

        apply_module_vjp = nn.remat(
            apply_module_vjp,
            policy=CHECKPOINT_POLICY,
            prevent_cse=False,
        )

        x, _ = nn.scan(
            apply_module_vjp,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            length=N_LAYERS,
            metadata_params={nn.PARTITION_NAME: "layers"},
        )(MLP(), x, None)

        x = jax.device_put(x, TransferToMemoryKind("device"))
        preds = Output()(x)
        return preds
```

Total size device = 21.00781624764204 GB, host = 0.5 GB.

Here's part of the trace

[Screenshot: profiler trace, 2024-09-24 14:53:29]

As you can see, the activations are indeed offloaded during the forward pass, but they are not loaded back onto the devices during the backward pass. It looks like the offloaded copies are immediately dropped on the CPU and the GPU activations are saved instead; that's why host memory is only 0.5 GB, enough for the activations of a single layer only.

I've also noticed that this approach produces wrong loss & grad values, but if I comment out all of the jax.device_put transfers, everything works as expected again.
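For comparison, here is the behaviour I would expect from such a forward/backward pair, sketched with plain jax.custom_vjp on a toy layer (TransferToMemoryKind is used as in the snippet above, whose import is likewise not shown; the layer itself is just a stand-in):

```
import jax
import jax.numpy as jnp
# TransferToMemoryKind is assumed to be importable from the installed JAX,
# as in the snippet above (its import is not shown there either).

def layer(w, x):
    # Toy stand-in for the scanned transformer block.
    return jnp.tanh(x @ w)

@jax.custom_vjp
def offloaded_layer(w, x):
    return layer(w, x)

def offloaded_layer_fwd(w, x):
    y = layer(w, x)
    # Forward pass: park the residuals needed for the backward pass on the host.
    w_host = jax.device_put(w, TransferToMemoryKind("pinned_host"))
    x_host = jax.device_put(x, TransferToMemoryKind("pinned_host"))
    return y, (w_host, x_host)

def offloaded_layer_bwd(res, g):
    # Backward pass: bring the residuals back to device memory, rematerialize
    # the layer, and apply its VJP to the incoming cotangent.
    w_host, x_host = res
    w = jax.device_put(w_host, TransferToMemoryKind("device"))
    x = jax.device_put(x_host, TransferToMemoryKind("device"))
    _, vjp_fn = jax.vjp(layer, w, x)
    return vjp_fn(g)

offloaded_layer.defvjp(offloaded_layer_fwd, offloaded_layer_bwd)
```

The scanned Flax version above is meant to achieve the same effect, but judging by the trace the offloaded activations are never brought back.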