foundation-model-stack / fms-fsdp

🚀 Efficiently (pre)training foundation models with native PyTorch features, including FSDP for training and SDPA implementation of Flash attention v2.
https://pytorch.org/docs/stable/fsdp.html
Apache License 2.0

A write-up on Meta Device Init x Pretraining #64

Open lchu-ibm opened 3 months ago

lchu-ibm commented 3 months ago

Scope

This write-up only applies to "initial model init". For cases that require loading a checkpoint (continued pretraining, fine-tuning, and inference), none of this is needed, as any init would be overwritten by the checkpoint. Therefore, this mainly targets "pretraining from scratch".

Background on meta device init

There are two ways to leverage meta device to save cpu memory during model init:

  1. create a full model copy on rank0 while putting the model on the meta device on all other ranks.
  2. put the model on the meta device on all ranks, including rank0.

The first method initializes a full model on rank0 and uses sync_module_states=True during the FSDP call to broadcast the model from rank0 to all other ranks. This cuts cpu memory from world_size copies in total down to a single copy. The second method puts the model on the meta device on all ranks (including rank0) and relies on a proper param_init_fn and/or post-FSDP init. Compared to the first method, the second one not only saves more cpu memory (zero copies), but also greatly reduces model init time, since it avoids initializing a full model on cpu (for large models like 70b, this can take 20 minutes).
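To make the two setups concrete, here is a rough sketch; `LLaMA(llama_config)` stands in for however the FMS model actually gets built, and most FSDP kwargs (wrapping policy, mixed precision, etc.) are elided:

    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    rank = dist.get_rank()

    # Method 1: one real copy on rank0, meta tensors everywhere else.
    # sync_module_states=True broadcasts rank0's weights during the FSDP call.
    # (Depending on the model, non-zero ranks may also need a minimal
    # param_init_fn that just materializes the meta modules.)
    if rank == 0:
        model = LLaMA(llama_config)          # full copy, sits on cpu
    else:
        with torch.device("meta"):
            model = LLaMA(llama_config)      # no real storage allocated
    model = FSDP(model, sync_module_states=True, device_id=torch.cuda.current_device())

    # Method 2: meta tensors on every rank, rank0 included. Materialization and
    # init happen inside the FSDP call via param_init_fn (discussed below).
    with torch.device("meta"):
        model = LLaMA(llama_config)
    model = FSDP(model, param_init_fn=param_init_fn, device_id=torch.cuda.current_device())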

Method 2 is thus both faster and better at saving cpu memory; however, it can be very tricky to set up properly for pretraining, and it can cause silent problems. Unlike continued pretraining, fine-tuning, or inference, where model init doesn't matter because it is overwritten by the loaded checkpoint, pretraining depends on a proper model init, and getting that init right with method 2 is tricky no matter which stage you apply it at:

pre-FSDP init

This isn't possible with method 2, as all ranks are on the meta device before the FSDP call. This is also why method 1 is much safer: you do everything you need before the FSDP call while the model is still a full copy sitting on cpu, perform any init you want, and it gets properly broadcast to the other ranks during the FSDP call. But again, we want method 2 and we don't want any cpu copy, so we will pass on this.

during-FSDP init

This is achieved by leveraging param_init_fn, which is applied to the "to be materialized" modules. Since the full model is on the meta device, we need to materialize each module and put it on device first, so a param_init_fn typically looks like:

    def param_init_fn(module):
        module.to_empty(device=torch.cuda.current_device())
        module.reset_parameters()  # or whatever "init" the module defines

Here comes the tricky part where we can run into silent problems. param_init_fn is applied to every to-be-materialized module, and these are popped/dequeued in a top-down, parent-then-children fashion (reference). Although this is already a great improvement over how things worked when we started this effort (this has a very detailed explanation of some old issues, which we also observed and had to work around), the current design still relies on an implicit, user-agreed contract: "param_init_fn should only initialize a module's own parameters/buffers, not those of any sub-module". Another implicit requirement is that such an "init" must be defined on every possible module. So what happens if we don't follow these rules strictly, which is exactly where FMS is today?

Sub-modules get re-initialized multiple times. Our reset_parameters() is designed so that calling model.reset_parameters() initializes the full model with the true/desired init; similarly, LLaMABlock.reset_parameters() initializes the full block. That is what we want, since a single-line, model-wide init is usually the goal, and it works well for method 1. But imagine using it as param_init_fn: recall that the "to be materialized" modules will be something like [LLaMABlock, MultiHeadAttention, Linear, Linear, Linear, etc.], so child modules like Linear get re-initialized multiple times, and this can be problematic:

  1. the issues discussed in the reference I shared above.
  2. more importantly: silent problems if our init does not provide FULL coverage. Again, recall that we defined our "init" at the model level (llama.reset_parameters()) and at the "key module" levels (attn_block, embed, mlp_block, layer_norm), as that was typically sufficient; but these inits get "silently" overwritten by lower-level modules (e.g. Linear), because basic modules like Linear ship their own implementation of reset_parameters(). So during this "re-init" of the "leaf nodes", the wrong init overwrites our true init, and it is silent! (A toy illustration follows this list.)
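To make point 2 concrete, here is a toy, self-contained illustration (a hypothetical ToyBlock, not FMS code) of how a recursive block-level init gets silently undone once the naive param_init_fn later visits the leaf Linear:

    import torch
    import torch.nn as nn

    class ToyBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(16, 16)

        def reset_parameters(self):
            # the "true" init, defined at the block level and covering children
            nn.init.zeros_(self.proj.weight)
            nn.init.zeros_(self.proj.bias)

    block = ToyBlock()
    block.reset_parameters()              # desired init: all zeros
    block.proj.reset_parameters()         # what the per-module pass does next:
                                          # nn.Linear's own kaiming-uniform init
    print(block.proj.weight.abs().sum())  # non-zero -> the desired init is gone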

post-FSDP init

This can be even trickier. It is less preferred than using param_init_fn, so I won't go into too much detail. But post-FSDP init means manipulating model params outside forward/backward, where you run into issues like "illegal memory access" because the model is already sharded. You could technically leverage FSDP.summon_full_params() with writeback to achieve some of this, but that is less efficient and less user-friendly than leveraging param_init_fn. So this is also not what we want.
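Purely for reference, such a summon-based init would look roughly like the sketch below, assuming identical RNG state on every rank and that the wrapped model's recursive reset_parameters() is reachable through the FSDP wrapper; again, this is not the route we take:

    # Sketch of post-FSDP init via summon_full_params; less efficient since it
    # gathers full (unsharded) params on every rank just to run the init.
    torch.manual_seed(2023)  # identical init on every rank before writeback
    with FSDP.summon_full_params(model, writeback=True):
        with torch.no_grad():
            model.reset_parameters()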

What to do with FMS

So it seems "during-FSDP init with param_init_fn" is the way to go, but then we would have to meet the contract:

  1. rewrite ALL of our init logic (reset_parameters) to be non-recursive.
  2. provide FULL init coverage, i.e. define an init for every single module type (sketched right after this list).
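Meeting the contract literally would mean something like the following: the earlier param_init_fn, but materializing with recurse=False (available on recent PyTorch) and relying on every module type defining its own non-recursive reset_parameters():

    def param_init_fn(module):
        # materialize only this module's own params/buffers, not its children
        module.to_empty(device=torch.cuda.current_device(), recurse=False)
        with torch.no_grad():
            # assumes EVERY module type defines a NON-recursive reset_parameters()
            module.reset_parameters()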

Is there a way to avoid doing so, and potentially reuse our existing recursive version? The answer turns out to be yes, and the trick is simple: we just add a "filter" so that our recursive reset_parameters() is only applied to a set of modules that are mutually exclusive but together cover 100% of the model params. This way, no re-init ever happens.

    def param_init_fn(module):
        # these module types are mutually exclusive but together cover 100% of the model params
        if isinstance(
            module,
            (MultiHeadAttention, WordEmbedding, GatedLinearUnit, LayerNormParameterized),
        ):
            # materialize this whole sub-tree off the meta device, then run our recursive init
            module.to_empty(device=torch.cuda.current_device())
            with torch.no_grad():
                module.reset_parameters()
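For completeness, this filter-style param_init_fn is then handed to the FSDP constructor in the usual way; a rough sketch, where wrapping_policy and the other kwargs shown are placeholders rather than the repo's exact settings:

    model = FSDP(
        model,                                # built on the meta device on all ranks
        auto_wrap_policy=wrapping_policy,     # assumed: wraps at the LLaMABlock level
        param_init_fn=param_init_fn,
        device_id=torch.cuda.current_device(),
        sync_module_states=False,             # method 2: no rank0 copy to broadcast
    )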