foundation-model-stack / fms-fsdp

🚀 Efficiently (pre)training foundation models with native PyTorch features, including FSDP for training and SDPA implementation of Flash attention v2.
https://pytorch.org/docs/stable/fsdp.html
Apache License 2.0

Revert low_cpu_fsdp implementation #6

Closed lchu-ibm closed 7 months ago

lchu-ibm commented 7 months ago

We had two implementations of low_cpu_fsdp:

  1. Load the full model on rank 0 and use the meta device on all other ranks, then use sync_module_states to broadcast the weights.
  2. Use the meta device on all ranks, let FSDP randomly initialize the parameters during wrapping, and post-load the state dict.

We switched from 1 to 2 when we moved from plain FSDP to 2D (TP + FSDP), as TP's parallelize_module does not work with implementation 1: the parameters sit on different devices across ranks during TP parallelization.

However, since we cut TP and 2D from this open-source version of the training code, we want to revert this and go back to the first implementation, which makes post-init easier.
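
For reference, a minimal sketch of what implementation 1 looks like, assuming a generic `build_model()` factory and an already-initialized process group; the names are illustrative, not the repo's exact code:

    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def build_fsdp_model_rank0_init(build_model):
        """Implementation 1: full weights on rank 0, meta device everywhere else."""
        rank = dist.get_rank()
        if rank == 0:
            # Rank 0 holds the full set of real weights (e.g. loaded from a checkpoint).
            model = build_model()
        else:
            # Every other rank builds the model on the meta device: no memory is allocated.
            with torch.device("meta"):
                model = build_model()

        # sync_module_states=True broadcasts rank 0's weights to the other ranks while
        # wrapping, so the meta-device copies end up with identical real parameters.
        # Non-zero ranks need a param_init_fn to first materialize their meta tensors.
        return FSDP(
            model,
            device_id=torch.cuda.current_device(),
            sync_module_states=True,
            param_init_fn=(
                None
                if rank == 0
                else (lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
            ),
        )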

lchu-ibm commented 7 months ago

Now, coming to think about it, we might want to keep implementation 2:

When starting from scratch, we have

    if start_step == 0:
        print("Starting from scratch - initializing parameters")
        model.reset_parameters()

as a post-init step to guarantee the true initialization.

And when continuing training, we have the checkpointer to load from the checkpoint.
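
For contrast, a minimal sketch of implementation 2 under the same assumptions as the sketch in the previous comment (illustrative names, not the repo's exact code):

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def build_fsdp_model_meta_init(build_model):
        """Implementation 2: meta device on every rank."""
        # No rank holds real weights; nothing is allocated until FSDP wraps the model.
        with torch.device("meta"):
            model = build_model()

        # With meta-device parameters and no param_init_fn, FSDP materializes each
        # wrapped module on the target device and calls its reset_parameters(), i.e.
        # the "random init in the FSDP call" mentioned above.
        return FSDP(model, device_id=torch.cuda.current_device())

The weights only become meaningful after the post-init: either model.reset_parameters() for scratch pretraining (the snippet above) or the checkpointer's load for continued training.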

lchu-ibm commented 7 months ago

Latest update:

This is what we should do:

Default to the meta device on ALL ranks: most jobs (inference / fine-tuning / continued pretraining) need to load a checkpoint, so we can simply rely on FSDP's random init because the weights will be overwritten anyway. The only exception is pretraining from scratch at step 0, and we can simply add an if clause to capture that case.

There is a blocker on this though: when initializing the model, we don't yet know whether we are doing scratch pretraining or continued pretraining. So the optimal solution described above needs more work, and for now we can default to rank==0 init.
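
For illustration only, the target flow described above might look roughly like this, reusing the hypothetical helper from the sketch in the previous comment; the checkpoint-loading call is a placeholder, not the repo's actual checkpointer API:

    # Default: meta device on all ranks; FSDP provides the (throwaway) random init.
    model = build_fsdp_model_meta_init(build_model)

    if start_step == 0:
        # Pretraining from scratch: replace FSDP's random init with the real one.
        model.reset_parameters()
    else:
        # Inference / fine-tuning / continued pretraining: the checkpoint overwrites
        # whatever FSDP initialized, so the random init is harmless.
        load_from_checkpoint(model)  # placeholder, not the repo's actual API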