Open gianlucadetommaso opened 1 year ago
I think you will have to provide more information regarding the pjit case. Can you dump the HLO please?
Also, if there are no arguments, then there is no need to specify in_shardings. It's an optional argument.
@yashk2810 Here is something perhaps better. I managed to create an example reproducing the issue.
from transformers import FlaxAutoModelForCausalLM
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.pjit import pjit
from jax import eval_shape, device_put, local_device_count
from copy import deepcopy
import numpy as np
import re
from jax.tree_util import DictKey, FlattenedIndexKey, GetAttrKey, SequenceKey, tree_map_with_path, tree_map


def path_to_string(path, separator: str = None):
    keys = []
    for key in path:
        if isinstance(key, SequenceKey):
            keys.append(str(key.idx))
        elif isinstance(key, DictKey):
            keys.append(str(key.key))
        elif isinstance(key, GetAttrKey):
            keys.append(str(key.name))
        elif isinstance(key, FlattenedIndexKey):
            keys.append(str(key.key))
        else:
            keys.append(str(key))
    if separator is None:
        return tuple(keys)
    return separator.join(keys)


def named_tree_map(f, tree, *rest, is_leaf=None, separator=None):
    return tree_map_with_path(
        lambda string_path, x, *r: f(path_to_string(string_path, separator=separator), x, *r),
        tree,
        *rest,
        is_leaf=is_leaf,
    )


def match_partition_specs(specs, tree):
    def get_partition_spec(path, shape_leaf):
        if len(shape_leaf.shape) == 0 or np.prod(shape_leaf.shape) == 1:
            return PartitionSpec()
        for rule, ps in specs.items():
            if re.search(rule, path) is not None:
                return ps
        return PartitionSpec()

    return named_tree_map(get_partition_spec, tree, separator="/")


if __name__ == "__main__":
    mesh = Mesh(create_device_mesh((1, 1, local_device_count())), ("dp", "fsdp", "mp"))
    partition_specs = {
        'transformer/wte/embedding': PartitionSpec('mp', 'fsdp'),
        'attn/(k_proj|q_proj|v_proj)/kernel': PartitionSpec('fsdp', 'mp'),
        'attn/out_proj/kernel': PartitionSpec('mp', 'fsdp'),
        'mlp/fc_in/kernel': PartitionSpec('fsdp', 'mp'),
        'mlp/fc_in/bias': PartitionSpec('mp'),
        'mlp/fc_out/kernel': PartitionSpec('mp', 'fsdp'),
        'lm_head/kernel': PartitionSpec('fsdp', 'mp'),
        'lm_head/bias': PartitionSpec('mp'),
    }
    model, params = FlaxAutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", _do_init=False)
    params = tree_map(lambda v: v.astype("float16") if (v.dtype == "float32" or v.dtype == "float64") else v, params)
    state = {k: deepcopy(params) for k in [1, 2, 3, 4]}

    def init_state_fn():
        return state

    shapes_dtypes = eval_shape(init_state_fn)
    partitions = match_partition_specs(partition_specs, shapes_dtypes)
    shardings = tree_map(lambda p: NamedSharding(mesh=mesh, spec=p), partitions)
    sharded_state = pjit(init_state_fn, out_shardings=shardings)()
    # sharded_state = device_put(state, shardings)
    print(sharded_state.keys())
As described above, I'm running this across 8 GPU devices. When using pjit, I get the following warning,
W external/xla/xla/service/hlo_rematerialization.cc:2218] Can't reduce memory use below 11.83GiB (12701564928 bytes) by rematerialization; only reduced to 16.91GiB
and then an error message like Execution of replica 0 failed: INTERNAL: Failed to allocate XXX bytes for new constant. I don't understand why, since in theory the sharded state should fit in GPU memory.
On the other hand, when I comment out the line with pjit and uncomment the line with device_put, it works as expected.
Any ideas? This is blocking me quite a bit. Thanks!
The error message tells me that it is the state that you are closing over which is causing a problem.
Try creating the state inside the pjitted function instead of materializing it outside. This way the state will be materialized as sharded directly.
Also, if device_put works, then why are you using pjit to do this? device_put seems like a better solution to me than running an XLA computation.
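A minimal sketch of that suggestion, using small zero-filled arrays as a hypothetical stand-in for a real state. It assumes a JAX version where jax.jit accepts out_shardings (used here in place of pjit) and a one-axis mesh named "mp" built from whatever devices are available; the array shapes and dictionary keys are made up for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-axis mesh over all available devices (a single CPU device also works).
devices = np.array(jax.devices())
mesh = Mesh(devices, ("mp",))
n = len(devices)

# Shard the first axis of the kernel over "mp"; replicate the bias.
shardings = {
    "kernel": NamedSharding(mesh, PartitionSpec("mp", None)),
    "bias": NamedSharding(mesh, PartitionSpec()),
}


def init_state_fn():
    # The arrays are created inside the traced function, so nothing large
    # is captured from the enclosing scope.
    return {
        "kernel": jnp.zeros((2 * n, 4), dtype=jnp.float16),
        "bias": jnp.zeros((4,), dtype=jnp.float16),
    }


sharded_state = jax.jit(init_state_fn, out_shardings=shardings)()
print({k: v.shape for k, v in sharded_state.items()})
```

This only covers state that can be computed inside the function, such as freshly initialized parameters. Pretrained values loaded on the host still have to reach the devices some other way.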
The initial state is on CPU, not on GPU. I expect pjit to materialize the state as sharded directly on GPU.
The code above is just a simplification, but eventually I would need to instantiate the state using pjit. The state will be created using flax.train_state.create. This will internally instantiate the parameters of an optax optimizer. While it might be possible to instantiate the whole state on CPU first (currently, I wouldn't know how to do it, as optax directly initializes its parameters on GPU, if possible), it seems like a hassle compared to just using pjit.
Independently of the best practice for this though, pjit seems to take far too much memory. I would love to understand why, and how to circumvent this. If pjit takes so much memory to materialize the state, it might also take a lot of memory when distributing a training step, making it very hard to train large models.
Hi, bringing this up again as I haven't yet managed to solve the issue. Any chance you might have time to look into this?
Here is a common solution:
import jax

shard_model_params_to_multihost = lambda params: ...  # write your own sharding function here
cpu_device = jax.devices('cpu')[0]
with jax.default_device(cpu_device):
    params = load_params(...)
params = shard_model_params_to_multihost(params)
Besides, from my understanding, if the params are sharded, the train_state created from the params would also be sharded.
Hi, we are observing the same issue. We have a model that runs fine with pmap but fails at initialization with pjit. We tested this with pjit batch parallelism and ZeRO sharding and both fail.
We are doing exactly the same: loading pretrained numpy params and closing over them with pjit. Our function only takes the rng_key as input, and inside the network it checks whether params are present in the outer pretrained params and calls jnp.asarray on them; otherwise it creates a new param to train from scratch.
With pmap this works fine as it's somehow able to free up the memory, but with pjit simple batch parallelism fails. It compiles fine but then fails with:
XlaRuntimeError: RESOURCE_EXHAUSTED: Error loading program: Attempting to allocate 11.49G. That was not possible. There are 3.59G free.; (1x1x0_HBM0):
@ayaka14732 thanks for the answer. Loading parameters on CPU is not really the bottleneck here - in my example script above, params is loaded on CPU. The problem arises later when sharding, basically the last line of your little script.
@hamzamerzic if you happen to find a solution, would you mind sharing it here? Thanks!
A workaround that we ended up going for is to create a new function for initializing parameters that takes pretrained parameters as input instead of taking them from the global scope. This way there are no values to close over so pjit should not OOM.
The fact that pjit does not efficiently handle memory of values closed over still worries me though.
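A rough sketch of that workaround, not the commenter's actual code. The names pretrained and init_from_pretrained are hypothetical, the arrays are tiny placeholders, and jax.jit is used in place of pjit. Passing in_shardings explicitly is an extra assumption made here so that the host arrays are placed according to the same layout as the outputs.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-axis mesh over all available devices (a single CPU device also works).
devices = np.array(jax.devices())
mesh = Mesh(devices, ("mp",))
n = len(devices)

# Hypothetical stand-in for pretrained parameters held in host memory.
pretrained = {
    "kernel": np.ones((2 * n, 4), dtype=np.float16),
    "bias": np.zeros((4,), dtype=np.float16),
}

shardings = {
    "kernel": NamedSharding(mesh, PartitionSpec("mp", None)),
    "bias": NamedSharding(mesh, PartitionSpec()),
}


def init_from_pretrained(params):
    # `params` is a function argument, so it is traced as an input instead
    # of being captured from the enclosing scope.
    return jax.tree_util.tree_map(jnp.asarray, params)


init = jax.jit(
    init_from_pretrained,
    in_shardings=(shardings,),  # one entry per positional argument
    out_shardings=shardings,
)
sharded_params = init(pretrained)
print({k: v.shape for k, v in sharded_params.items()})
```

The difference from the failing script is only in how the values enter the computation: as arguments rather than as closed-over values, which the error message earlier in the thread reports as constants.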
Description
Suppose I have a large state stored on CPU, and a tree of shardings matching the state structure. Let's assume the partition is fairly uniform, so that about the same amount of memory should supposedly be allocated per device after sharding.
On my multi-GPU setup, materializing the sharded state with pjit goes out of memory, while calling device_put with the same shardings does not.
Why? The result should be the same. In both cases, I should obtain a state sharded across my GPU devices.
Some context
The example above is a proof of concept. More specifically, I am trying to instantiate on GPU the state of GPT-J, an LLM with 6b parameters. The state is mainly comprised of the model parameters, as well as two more replicas corresponding to the mu and nu parameters of an optax.adam optimizer. So, 18b parameters in total. In half-precision (float16, i.e. 2 bytes), this gives me 36b bytes, that is 36 GB of memory.
My Amazon EC2 instance has 8 GPUs, each with 16 GiB of GPU memory. I have created a sharding tree that partitions the state fairly uniformly across the 8 devices. Since 36 / 8 GB = 4.5 GB, the sharded state should comfortably fit in GPU memory.
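The estimate above can be checked with a few lines of arithmetic. This uses the rounded figures from the text (6b parameters, 3 copies, 2 bytes each, 8 devices of 16 GiB) and assumes a perfectly uniform split, which a real sharding only approximates.

```python
params_per_copy = 6_000_000_000  # GPT-J parameter count, rounded as in the text
copies = 3                       # model parameters + Adam mu + Adam nu
bytes_per_param = 2              # float16
n_devices = 8
device_memory_gib = 16

total_bytes = params_per_copy * copies * bytes_per_param
per_device_bytes = total_bytes // n_devices

print(f"total: {total_bytes / 1e9:.1f} GB")
print(f"per device: {per_device_bytes / 1e9:.1f} GB ({per_device_bytes / 2**30:.2f} GiB)")
print("fits in device memory:", per_device_bytes < device_memory_gib * 2**30)
```

The per-device share is also well below the 11.83GiB limit reported in the rematerialization warning, so the sharded result alone does not account for the failure.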
And yet, when initializing the state with pjit, I go out of memory. I then stripped down my example to the bone, and discovered, like in the example above, that device_put works instead, as it should.
Do you have an explanation for why this is happening? Is there anything I can do to make pjit work? Eventually, I really want to work with pjit rather than device_put, because I want my state to be directly sharded on GPU, rather than having to store it on CPU first and then putting it on the devices.
What jax/jaxlib version are you using?
jax 0.4.13, jaxlib 0.4.13+cuda11.cudnn86
Which accelerator(s) are you using?
GPU
Additional system info
OS
NVIDIA GPU info