jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Memory issue when randomly initializing large parameters, sharding cannot help #19893

Open hr0nix opened 8 months ago

hr0nix commented 8 months ago

Description

Consider the following code snippet:

import jax
import flax.linen as nn
from jax.sharding import Mesh
import functools

class Model(nn.Module):
    output_dim = 32768 * 8

    @nn.compact
    def __call__(self, inputs):
        block = nn.Dense(features=self.output_dim, use_bias=False)
        return block(inputs)

class ShardedModel(nn.Module):
    output_dim = 32768 * 8

    @nn.compact
    def __call__(self, inputs):
        init_fn = nn.initializers.lecun_normal()
        # init_fn = nn.initializers.zeros
        block = nn.Dense(
            features=self.output_dim,
            use_bias=False,
            kernel_init=nn.with_logical_partitioning(
                init_fn, ("logical_axis", "unmodelled")
            ),
        )
        return block(inputs)

def test_model(model: nn.Module):
    key = jax.random.PRNGKey(0)
    input_shape = (1, 32768)
    inputs = jax.random.normal(key, input_shape)

    devices = jax.devices()
    mesh = Mesh(devices, ("mesh_axis",))
    print(f"Device mesh: {mesh}")
    sharding_rules = [
        ("logical_axis", "mesh_axis"),
    ]

    abstract_params = jax.eval_shape(model.init, key, inputs)
    params_partition_spec = nn.get_partition_spec(abstract_params)
    params_sharding = nn.logical_to_mesh_sharding(
        params_partition_spec,
        mesh,
        rules=sharding_rules,
    )
    print(f"Intended sharding: {params_sharding}")

    init_fn = functools.partial(model.init, key, inputs)
    init_fn = jax.jit(
        init_fn,
        out_shardings=params_sharding,
    )
    params = init_fn()

    actual_sharding = jax.tree_util.tree_map(lambda leaf: leaf.sharding, params)
    print(f"Actual sharding: {actual_sharding}")

def main():
    test_model(Model())
    # test_model(ShardedModel())

if __name__ == "__main__":
    main()

Here I'm trying to initialize a very large dense layer. Although the layer weights need only 32 GiB of memory (and I'm running on an 80 GB H100), this code fails because JAX tries to keep five 16 GiB buffers of random bits (the threefry2x32 outputs in the log below) alive at the same time, bringing the total memory consumption to 112 GiB!

jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 85899347204 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:         0B
              constant allocation:         8B
        maybe_live_out allocation:   32.00GiB
     preallocated temp allocation:   80.00GiB
  preallocated temp fragmentation:       124B (0.00%)
                 total allocation:  112.00GiB
              total fragmentation:   16.00GiB (14.29%)
Peak buffers:
        Buffer 1:
                Size: 32.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/Model/Dense_0/mul" source_file="/usr/local/lib/python3.10/dist-packages/flax/core/scope.py" source_line=968
                XLA Label: fusion
                Shape: f32[32768,262144]
                ==========================

        Buffer 2:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/Model/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/core/scope.py" source_line=968
                XLA Label: custom-call
                Shape: u32[2,2147483648]
                ==========================

        Buffer 3:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/Model/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/core/scope.py" source_line=968
                XLA Label: custom-call
                Shape: u32[2,2147483648]
                ==========================

        Buffer 4:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/Model/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/core/scope.py" source_line=968
                XLA Label: fusion
                Shape: u32[2,2147483648]
                ==========================

        Buffer 5:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/Model/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/core/scope.py" source_line=968
                XLA Label: fusion
                Shape: u32[2,2147483648]
                ==========================

        Buffer 6:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/Model/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/core/scope.py" source_line=968
                XLA Label: fusion
                Shape: u32[2,2147483648]
                ==========================
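The sizes in the log above follow directly from the array shapes. A quick back-of-the-envelope check in plain Python (the `nbytes` helper is hypothetical, written only for this check):

```python
import math

GiB = 2**30


def nbytes(shape, itemsize):
    """Bytes for a dense array (hypothetical helper for this check)."""
    return math.prod(shape) * itemsize


# f32[32768,262144]: the Dense kernel (Buffer 1 in the log)
kernel = nbytes((32768, 262144), 4)
# u32[2,2147483648]: one threefry2x32 buffer (Buffers 2-6 in the log)
rng_buf = nbytes((2, 2147483648), 4)

print(kernel / GiB)                  # 32.0
print(rng_buf / GiB)                 # 16.0
print((kernel + 5 * rng_buf) / GiB)  # 112.0, the "total allocation" in the log
```

So the kernel alone fits comfortably on an 80 GB card; it is the five simultaneously live 16 GiB RNG buffers (the 80 GiB of preallocated temp allocation) that push the program over the limit.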

Is this intended? Do we really need to store these buffers in memory simultaneously to initialize the layer?

In any case, we can try to fix the problem by sharding the layer over the available devices (8x H100 80 GB: in main(), comment out test_model(Model()) and uncomment test_model(ShardedModel())). Interestingly, while this change reduces the size of the parameter tensor on each device as intended, the RNG buffers are still allocated in full!

jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 103079216656 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:         0B
              constant allocation:        40B
        maybe_live_out allocation:    4.00GiB
     preallocated temp allocation:   96.00GiB
  preallocated temp fragmentation:       124B (0.00%)
                 total allocation:  100.00GiB
              total fragmentation:    4.00GiB (4.00%)
Peak buffers:
        Buffer 1:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/ShardedModel/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/linen/spmd.py" source_line=350
                XLA Label: custom-call
                Shape: u32[2,2147483648]
                ==========================

        Buffer 2:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/ShardedModel/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/linen/spmd.py" source_line=350
                XLA Label: custom-call
                Shape: u32[2,2147483648]
                ==========================

        Buffer 3:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/ShardedModel/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/linen/spmd.py" source_line=350
                XLA Label: fusion
                Shape: u32[2,2147483648]
                ==========================

        Buffer 4:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/ShardedModel/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/linen/spmd.py" source_line=350
                XLA Label: fusion
                Shape: u32[2,2147483648]
                ==========================

        Buffer 5:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/ShardedModel/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/linen/spmd.py" source_line=350
                XLA Label: fusion
                Shape: u32[2,2147483648]
                ==========================

        Buffer 6:
                Size: 16.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/ShardedModel/Dense_0/jit(_truncated_normal)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/flax/linen/spmd.py" source_line=350
                XLA Label: fusion
                Shape: u32[2,2147483648]
                ==========================

        Buffer 7:
                Size: 4.00GiB
                Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/ShardedModel/Dense_0/mul" source_file="/usr/local/lib/python3.10/dist-packages/flax/linen/spmd.py" source_line=350
                XLA Label: fusion
                Shape: f32[4096,262144]
                ==========================

This seems to be a bug: why does JAX materialize the full RNG tensor on each shard when each shard only needs its own slice of it?
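The same accounting for the sharded run makes the asymmetry concrete: the kernel shard shrinks by the device count, but every threefry2x32 buffer keeps the full u32[2,2147483648] shape. A plain-Python sketch, assuming the 8-device mesh from the issue:

```python
GiB = 2**30
n_devices = 8
rows, cols = 32768, 262144

# Each device holds rows // n_devices rows of the f32 kernel: f32[4096,262144]
shard_rows = rows // n_devices
shard = shard_rows * cols * 4
# Each threefry2x32 buffer is still u32[2,2147483648], i.e. not divided by 8
rng_buf = 2 * 2147483648 * 4

print(shard_rows)                   # 4096
print(shard / GiB)                  # 4.0
print(6 * rng_buf / GiB)            # 96.0, the "preallocated temp allocation"
print((shard + 6 * rng_buf) / GiB)  # 100.0, the "total allocation"
# For comparison: one RNG buffer if it were split 8 ways like the kernel
print(rng_buf / n_devices / GiB)    # 2.0
```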

Finally, if I use all-zeros initialization (uncomment the "# init_fn = nn.initializers.zeros" line in ShardedModel above), the issue goes away.

So, to summarize, I have the following questions:

The example above, while artificial, is inspired by a real problem that we've encountered while trying to initialize a large model.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.20
jaxlib: 0.4.20
numpy:  1.24.3
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)]
process_count: 1

$ nvidia-smi
Tue Feb 20 16:53:25 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:8D:00.0 Off |                    0 |
| N/A   34C    P0             117W / 700W |    539MiB / 81559MiB |      1%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  | 00000000:91:00.0 Off |                    0 |
| N/A   30C    P0             113W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  | 00000000:95:00.0 Off |                    0 |
| N/A   33C    P0             112W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  | 00000000:99:00.0 Off |                    0 |
| N/A   30C    P0             115W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  | 00000000:AB:00.0 Off |                    0 |
| N/A   34C    P0             119W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  | 00000000:AF:00.0 Off |                    0 |
| N/A   30C    P0             113W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  | 00000000:B3:00.0 Off |                    0 |
| N/A   33C    P0             112W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  | 00000000:B7:00.0 Off |                    0 |
| N/A   30C    P0             114W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
hr0nix commented 8 months ago

Can confirm that the sharding-based solution works when using jax_default_prng_impl=rbg.
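For reference, a scaled-down sketch of the workaround reported above, assuming pure JAX without Flax. The `jax_default_prng_impl` flag is set via `jax.config.update` before any key is created; a plain truncated-normal draw stands in for `lecun_normal` (its variance scaling is omitted); and the shape is tiny so it runs on any device count, including a single CPU. The `init_kernel` name is hypothetical. The sketch only shows how to opt in to RBG and request a sharded output; it does not reproduce or measure the memory behaviour.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Switch the default PRNG implementation to RBG. Do this before any
# keys are created in the process.
jax.config.update("jax_default_prng_impl", "rbg")


def init_kernel(key, shape):
    # Hypothetical initializer: truncated normal on [-2, 2], standing in
    # for lecun_normal from the issue (scaling omitted for brevity).
    return jax.random.truncated_normal(key, -2.0, 2.0, shape, jnp.float32)


devices = jax.devices()
mesh = Mesh(np.array(devices), ("mesh_axis",))
# Shard the first (input) dimension across devices, like ShardedModel does.
sharding = NamedSharding(mesh, P("mesh_axis", None))

# Tiny shape; the first dimension must divide evenly by the device count.
shape = (8 * len(devices), 16)
init = jax.jit(functools.partial(init_kernel, shape=shape), out_shardings=sharding)

key = jax.random.PRNGKey(0)  # a uint32[4] array under the RBG implementation
params = init(key)
print(params.shape, params.sharding)
```

One caveat of this choice: RBG keys have a different shape (uint32[4]) from the default threefry keys (uint32[2]) and produce a different random stream, so the same seed will not reproduce threefry-initialized parameters.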