jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

pjit uses too much memory #16679

Open gianlucadetommaso opened 1 year ago

gianlucadetommaso commented 1 year ago

Description

Suppose I have a large state stored on CPU, and a tree of shardings matching the state structure. Let's assume the partitioning is fairly uniform, so that roughly the same amount of memory should be allocated on each device after sharding.

On my multi-GPU setup, the following code goes out of memory:

from jax.experimental.pjit import pjit
state = pjit(lambda: state, in_shardings=(), out_shardings=shardings)()

while the following does not:

from jax import device_put
state = device_put(state, shardings)

Why? The result should be the same. In both cases, I should obtain a state sharded across my GPU devices.

Some context: the example above is a proof of concept. More specifically, I am trying to instantiate on GPU the state of GPT-J, an LLM with 6B parameters. The state mainly comprises the model parameters, plus two additional copies of the same shape corresponding to the mu and nu parameters of an optax.adam optimizer, so 18B parameters in total. In half precision (float16, i.e. 2 bytes per parameter), this gives 36 billion bytes, that is, 36 GB of memory.

My Amazon EC2 instance has 8 GPUs, each with 16 GiB of GPU memory. I have created a sharding tree that partitions the state fairly uniformly across the 8 devices. Since 36 GB / 8 = 4.5 GB per device, the sharded state should comfortably fit in GPU memory.

And yet, when initializing the state with pjit, I run out of memory. I then stripped my example down to the bone and discovered, as in the example above, that device_put works as expected.

Do you have an explanation for why this is happening? Is there anything I can do to make pjit work? Ultimately, I really want to use pjit rather than device_put, because I want my state to be sharded directly on GPU, rather than having to store it on CPU first and then put it on the devices.

What jax/jaxlib version are you using?

jax 0.4.13, jaxlib 0.4.13+cuda11.cudnn86

Which accelerator(s) are you using?

GPU

Additional system info

OS

NVIDIA GPU info

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  On   | 00000000:00:17.0 Off |                    0 |
| N/A   32C    P0    43W / 300W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  On   | 00000000:00:18.0 Off |                    0 |
| N/A   33C    P0    45W / 300W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-SXM2...  On   | 00000000:00:19.0 Off |                    0 |
| N/A   30C    P0    42W / 300W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   3  Tesla V100-SXM2...  On   | 00000000:00:1A.0 Off |                    0 |
| N/A   31C    P0    43W / 300W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   4  Tesla V100-SXM2...  On   | 00000000:00:1B.0 Off |                    0 |
| N/A   33C    P0    43W / 300W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   5  Tesla V100-SXM2...  On   | 00000000:00:1C.0 Off |                    0 |
| N/A   32C    P0    42W / 300W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   6  Tesla V100-SXM2...  On   | 00000000:00:1D.0 Off |                    0 |
| N/A   30C    P0    43W / 300W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   7  Tesla V100-SXM2...  On   | 00000000:00:1E.0 Off |                    0 |
| N/A   31C    P0    41W / 300W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
yashk2810 commented 1 year ago

I think you will have to provide more information regarding the pjit case. Can you dump the HLO please?

Also, if there are no arguments, there is no need to specify in_shardings; it's an optional argument.
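
For example, with no arguments the call from the original report simplifies to (same state and shardings as above):

state = pjit(lambda: state, out_shardings=shardings)()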

gianlucadetommaso commented 1 year ago

@yashk2810 Here is something perhaps better. I managed to create an example reproducing the issue.

from transformers import FlaxAutoModelForCausalLM
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.pjit import pjit
from jax import eval_shape, device_put, local_device_count
from copy import deepcopy
import numpy as np
import re
from jax.tree_util import DictKey, FlattenedIndexKey, GetAttrKey, SequenceKey, tree_map_with_path, tree_map

def path_to_string(path, separator: str = None):
    keys = []
    for key in path:
        if isinstance(key, SequenceKey):
            keys.append(str(key.idx))
        elif isinstance(key, DictKey):
            keys.append(str(key.key))
        elif isinstance(key, GetAttrKey):
            keys.append(str(key.name))
        elif isinstance(key, FlattenedIndexKey):
            keys.append(str(key.key))
        else:
            keys.append(str(key))
    if separator is None:
        return tuple(keys)
    return separator.join(keys)

def named_tree_map(f, tree, *rest, is_leaf=None, separator=None):
    return tree_map_with_path(
        lambda string_path, x, *r: f(path_to_string(string_path, separator=separator), x, *r),
        tree,
        *rest,
        is_leaf=is_leaf,
    )

def match_partition_specs(specs, tree):
    def get_partition_spec(path, shape_leaf):
        if len(shape_leaf.shape) == 0 or np.prod(shape_leaf.shape) == 1:
            return PartitionSpec()
        for rule, ps in specs.items():
            if re.search(rule, path) is not None:
                return ps
        return PartitionSpec()
    return named_tree_map(get_partition_spec, tree, separator="/")

if __name__ == "__main__":
    mesh = Mesh(create_device_mesh((1, 1, local_device_count())), ("dp", "fsdp", "mp"))
    partition_specs = {
        'transformer/wte/embedding': PartitionSpec('mp', 'fsdp'),
        'attn/(k_proj|q_proj|v_proj)/kernel': PartitionSpec('fsdp', 'mp'),
        'attn/out_proj/kernel': PartitionSpec('mp', 'fsdp'),
        'mlp/fc_in/kernel': PartitionSpec('fsdp', 'mp'),
        'mlp/fc_in/bias': PartitionSpec('mp'),
        'mlp/fc_out/kernel': PartitionSpec('mp', 'fsdp'),
        'lm_head/kernel': PartitionSpec('fsdp', 'mp'),
        'lm_head/bias': PartitionSpec('mp'),
    }

    # Load GPT-J weights on the host and cast floating-point leaves to float16.
    model, params = FlaxAutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", _do_init=False)
    params = tree_map(lambda v: v.astype("float16") if (v.dtype == "float32" or v.dtype == "float64") else v, params)
    # Several copies of the params stand in for a full train state (params plus optimizer state).
    state = {k: deepcopy(params) for k in [1, 2, 3, 4]}

    def init_state_fn():
        return state

    # Build a sharding tree matching the state structure without materializing anything.
    shapes_dtypes = eval_shape(init_state_fn)
    partitions = match_partition_specs(partition_specs, shapes_dtypes)
    shardings = tree_map(lambda p: NamedSharding(mesh=mesh, spec=p), partitions)

    # This is the line that goes out of memory on 8x16GiB GPUs:
    sharded_state = pjit(init_state_fn, out_shardings=shardings)()
    # sharded_state = device_put(state, shardings)
    print(sharded_state.keys())

As described above, I'm running this across 8 GPU devices. When using pjit, I get the following warning,

W external/xla/xla/service/hlo_rematerialization.cc:2218] Can't reduce memory use below 11.83GiB (12701564928 bytes) by rematerialization; only reduced to 16.91GiB

and then an error message like Execution of replica 0 failed: INTERNAL: Failed to allocate XXX bytes for new constant. I don't understand why, since in theory the sharded state should fit in GPU memory.

On the other hand, when I comment the line with pjit and uncomment the line with device_put, it works as expected.

Any ideas? This is blocking me quite a bit. Thanks!

yashk2810 commented 1 year ago

The error message tells me that it is the state you are closing over that is causing the problem.

Try creating the state inside the pjitted function instead of materializing it outside. That way the state will be materialized directly in sharded form.
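
Something like this, as a rough sketch (model, the dummy input, and a shardings tree matching the init output are placeholders, not a confirmed recipe):

import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit

def init_params_fn(rng):
    # The params are created inside the traced function, so the compiler can
    # materialize each leaf directly on its target shard instead of first
    # building a fully replicated constant on every device.
    return model.init(rng, jnp.ones((1, 128), dtype=jnp.int32))

sharded_params = pjit(init_params_fn, out_shardings=shardings)(jax.random.PRNGKey(0))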

Also, if device_put works, then why are you using pjit to do this? device_put seems like a better solution to me instead of running an XLA computation.

gianlucadetommaso commented 1 year ago

The initial state is on CPU, not on GPU. I expect pjit to materialize the state as sharded directly on GPU.

The code above is just a simplification; eventually I would need to instantiate the state using pjit. The state will be created with flax.training.train_state.TrainState.create, which internally instantiates the parameters of an optax optimizer. While it might be possible to instantiate the whole state on CPU first (currently I wouldn't know how to do that, as optax initializes its parameters directly on GPU when possible), it seems like a hassle compared to just using pjit.
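
For concreteness, this is roughly what I have in mind; just a sketch, where model, dummy_inputs, and a state_shardings tree matching the TrainState structure are placeholders:

import jax
import optax
from flax.training import train_state
from jax.experimental.pjit import pjit

def init_train_state_fn(rng):
    variables = model.init(rng, dummy_inputs)
    # TrainState.create instantiates the optimizer state (mu and nu for adam)
    # alongside the params, so everything would be born already sharded.
    return train_state.TrainState.create(
        apply_fn=model.apply, params=variables["params"], tx=optax.adam(1e-4)
    )

sharded_train_state = pjit(init_train_state_fn, out_shardings=state_shardings)(jax.random.PRNGKey(0))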

Independently of the best practice for this though, pjit seems to take far too much memory. I would love to understand why, and how to circumvent this. If pjit takes so much memory to materialize the state, it might also take a lot of memory when distributing a training step, making it very hard to train large models.

gianlucadetommaso commented 1 year ago

Hi, bringing this up again as I haven't yet managed to solve the issue. Any chance you might have time to look into this?

ayaka14732 commented 1 year ago

Here is a common solution:

import jax

shard_model_params_to_multihost = lambda params: ...  # write your own sharding function here
cpu_device = jax.devices('cpu')[0]
with jax.default_device(cpu_device):
    params = load_params(...)  # load the pretrained weights; they stay on CPU
params = shard_model_params_to_multihost(params)

Besides, from my understanding, if the params are sharded, the train_state created from the params would also be sharded.
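
A minimal sketch of such a sharding function, reusing mesh, partition_specs, and match_partition_specs from the repro above (so only meaningful in that context):

import jax
from jax.sharding import NamedSharding
from jax.tree_util import tree_map

def shard_model_params_to_multihost(params):
    # Build a NamedSharding for every leaf, then let device_put move each leaf
    # from host memory to its target shards one array at a time.
    specs = match_partition_specs(partition_specs, params)
    shardings = tree_map(lambda spec: NamedSharding(mesh=mesh, spec=spec), specs)
    return jax.device_put(params, shardings)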

hamzamerzic commented 1 year ago

Hi, we are observing the same issue. We have a model that runs fine with pmap but fails at initialization with pjit. We tested this with pjit batch parallelism and ZeRO sharding, and both fail.

We are doing exactly the same thing: loading pretrained numpy params and closing over them with pjit. Our function only takes the rng_key as input; inside the network it checks whether each param is present in the outer pretrained params and calls jnp.asarray on it, otherwise it creates a new param to train from scratch.

With pmap this works fine, as it is somehow able to free up the memory, but with pjit even simple batch parallelism does not. It compiles fine but then fails with:

XlaRuntimeError: RESOURCE_EXHAUSTED: Error loading program: Attempting to allocate 11.49G. That was not possible. There are 3.59G free.; (1x1x0_HBM0):

gianlucadetommaso commented 1 year ago

@ayaka14732 thanks for the answer. Loading the parameters on CPU is not really the bottleneck here: in my example script above, params is already loaded on CPU. The problem arises later, when sharding, basically at the last line of your snippet.

@hamzamerzic if you happen to find a solution, would you mind sharing it here? Thanks!

hamzamerzic commented 1 year ago

A workaround that we ended up going for is to create a new function for initializing parameters that takes the pretrained parameters as input instead of taking them from the global scope. This way there are no values to close over, so pjit should not OOM.
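
A stripped-down sketch of the idea, reusing the state and shardings names from the repro earlier in the thread (a real init function would also create the fresh params here):

from jax.experimental.pjit import pjit

def init_state_fn(pretrained_params):
    # The pretrained params arrive as an argument rather than a closed-over
    # constant, so they are not baked into the compiled program.
    return pretrained_params

sharded_state = pjit(
    init_state_fn,
    in_shardings=(shardings,),   # shard the host params as they are transferred
    out_shardings=shardings,
)(state)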

The fact that pjit does not handle the memory of closed-over values efficiently still worries me, though.