huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Stuck on Initializing Transformers Model with FSDP (Fully Sharded Data Parallel) using meta device #31278

Open jiangjiadi opened 1 month ago

jiangjiadi commented 1 month ago

System Info

Who can help?

text model: @ArthurZucker and @younesbelkada

Reproduction

Run Command: torchrun --nproc_per_node 2 test_fsdp.py

import torch
import os
import torch.distributed as dist
import functools
import transformers
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer

local_rank = int(os.environ["LOCAL_RANK"])
print("local_rank:", local_rank)
torch.cuda.set_device(local_rank)

dist.init_process_group("nccl", init_method="env://")

config = Qwen2Config(
    hidden_size=1024,
    intermediate_size=2816,
    num_hidden_layers=24,
    num_attention_heads=16,
    num_key_value_heads=16,
    max_window_layers=21,
    rope_theta=1000000.0,
    tie_word_embeddings=True,
)
if local_rank == 0:
    print(config)

config.use_cache = False

# Materialize real weights only on rank 0; the other ranks build the model on
# the meta device and rely on sync_module_states to receive rank 0's weights.
if local_rank != 0:
    with torch.device("meta"):
        model = Qwen2ForCausalLM._from_config(config)
else:
    model = Qwen2ForCausalLM._from_config(config)
print(f"rank {local_rank}: Model is defined.")
model = FSDP(
    model,
    auto_wrap_policy=functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            Qwen2DecoderLayer,
        },
    ),
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=CPUOffload(offload_params=True),
    device_id=torch.cuda.current_device(),
    limit_all_gathers=False,
    sync_module_states=True,  # broadcast rank 0's weights to the meta-initialized ranks
    param_init_fn=(
        (lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
        if local_rank != 0
        else None
    ),
    use_orig_params=True,
)

if local_rank != 0:
    print(">>>> Created Model.")
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f">>> The model has {trainable_params / 1e6} M  trainable parameters")

Expected behavior

When tie_word_embeddings=False is set, the code behaves normally. However, when I set tie_word_embeddings=True, rank 0 exits normally but rank 1 gets stuck. The point where it gets stuck is shown in the screenshot below. (The behavior is the same when using accelerate.)

[screenshot: traceback showing where rank 1 hangs]
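
For reference, here is a quick check separate from the reproducer above (tiny config, single process) showing that tie_word_embeddings=True makes the input embeddings and the lm_head share a single Parameter object:

from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM

tiny_config = Qwen2Config(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    tie_word_embeddings=True,
)
tiny_model = Qwen2ForCausalLM._from_config(tiny_config)
# Prints True: embed_tokens.weight and lm_head.weight are the same tensor object.
print(tiny_model.get_input_embeddings().weight is tiny_model.get_output_embeddings().weight)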

younesbelkada commented 1 month ago

Thanks for the reproducer, looking into it now

younesbelkada commented 1 month ago

Hi @jiangjiadi, I have spent some time looking into the issue and was able to reproduce it. Interestingly, the script works if you never initialize the model on the meta device.

Also note this from the official PyTorch docs:

As of PyTorch 1.12, FSDP only offers limited support for shared parameters (for example, setting one Linear layer’s weight to another’s). In particular, modules that share parameters must be wrapped as part of the same FSDP unit. If enhanced shared parameter support is needed for your use case, please ping https://github.com/pytorch/pytorch/issues/77724
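
In the meantime, here is a minimal sketch of the meta-free path that worked in my test, assuming every rank has enough host memory to materialize the full model before wrapping:

# Every rank builds the real model; no meta device and no param_init_fn needed.
model = Qwen2ForCausalLM._from_config(config)
model = FSDP(
    model,
    auto_wrap_policy=functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={Qwen2DecoderLayer},
    ),
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=CPUOffload(offload_params=True),
    device_id=torch.cuda.current_device(),
    sync_module_states=True,
    use_orig_params=True,
)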

I will keep investigating and let you know.

jiangjiadi commented 1 month ago

Hi @younesbelkada, thank you for looking into this issue. I appreciate your prompt response and look forward to any updates.

Additionally, I've noticed that when the from_config method is called with DeepSpeed's ZeRO-3 enabled, the model gets pre-partitioned. Could a similar approach be adopted for FSDP initialization? Pre-partitioning the model at definition time could help mitigate OOM issues when training large models.
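
For reference, a rough sketch of the ZeRO-3 behavior I'm referring to. This calls deepspeed.zero.Init directly rather than going through transformers' integration, the config dict is a minimal placeholder, and it assumes torch.distributed is already initialized as in the reproducer:

import deepspeed
from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM

# Placeholder ZeRO-3 config; a real run would use the full DeepSpeed config.
ds_config = {"train_micro_batch_size_per_gpu": 1, "zero_optimization": {"stage": 3}}

config = Qwen2Config(tie_word_embeddings=True)
# Parameters are partitioned across ranks as each submodule is constructed,
# so no single rank ever holds the full model in memory.
with deepspeed.zero.Init(config_dict_or_path=ds_config):
    model = Qwen2ForCausalLM._from_config(config)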

amyeroberts commented 1 week ago

cc @muellerzr @SunMarc