huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0
7.97k stars 970 forks source link

Accelerate + FSDP plugin hang on after model save intermediate checkpoint #3250

Open leeruibin opened 12 hours ago

leeruibin commented 12 hours ago

System Info

- `Accelerate` version: 0.33.0
- `accelerate` bash location: /miniconda3/envs/SDXL/bin/accelerate
- Python version: 3.10.14
- Numpy version: 1.24.4
- PyTorch version (GPU?): 2.5.1+cu121 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- PyTorch MUSA available: False
- System RAM: 1121.82 GB
- GPU type: NVIDIA A100-SXM4-80GB
- `Accelerate` default config:
        Not found

Python version:
Python 3.10.14

diffusers Version:
diffusers                 0.30.0

pytorch version
torch                     2.5.1+cu121

transformers version
transformers              4.44.0

Information

Tasks

Reproduction

I am trying to use FSDP to accelerate my training with accelerator. The task is similar with SDXL-inpainting. However, when I try to save the intermediate checkpoints, the training script will hang on in the main thread after the checkpoint is saved with FSDP config. Here is my example code for reproduce the bug, I use the UNet2DConditionModel in the diffusers as the training model.

#!/usr/bin/env python
import os
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from diffusers import UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.utils import is_wandb_available
from diffusers.utils.torch_utils import is_compiled_module
from accelerate import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig

def main():
    output_dir = "tmp/test_FSDP"
    os.makedirs(output_dir,exist_ok=True)
    logging_dir = os.path.join(output_dir, "logs")
    accelerator_project_config = ProjectConfiguration(project_dir="test_FSDP", logging_dir=logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=1,
        mixed_precision="no",
        project_config=accelerator_project_config,
    )

    unet = UNet2DConditionModel.from_pretrained(
        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", subfolder="unet", revision=None, variant=None
    )
    unet.requires_grad_(True)
    gradient_checkpointing = False

    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model

    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    params_to_optimize = list(unet.parameters())

    optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        # unet.parameters(),
        filter(lambda p: p.requires_grad, params_to_optimize),
        lr=0.00001,
        betas=(0.9, 0.999),
        weight_decay=1e-2,
        eps=1e-08,
    )

    # Set UNet to trainable.
    unet.train()

    lr_scheduler = get_scheduler(
        "constant",
        optimizer=optimizer,
    )

    # Prepare everything with our `accelerator`.
    unet, optimizer, lr_scheduler = accelerator.prepare(
        unet, optimizer, lr_scheduler
    )

    global_step = 0
    for step in range(0, 1000):
        with accelerator.accumulate(unet):
            # [2,9,64,64] [2] [2,77,2048] text_embeds=[2,1280] time_ids[2,6]
            unet_input = torch.randn([2,9,64,64]).to(unet.device,dtype=unet.dtype)
            timesteps = torch.randint(1,100,[2]).to(unet.device).long()
            encoder_hidden_states = torch.randn([2,77,2048]).to(unet.device,dtype=unet.dtype)
            text_embeds = torch.randn([2,1280]).to(unet.device,dtype=unet.dtype)
            time_ids = torch.randn([2,6]).to(unet.device,dtype=unet.dtype)
            added_cond_kwargs = {
                "text_embeds":text_embeds,
                "time_ids":time_ids
            }

            print(f"This is thread {accelerator.local_process_index} with training step {step} command 1" )

            model_pred = unet(
                unet_input,
                timesteps,
                encoder_hidden_states,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            print(f"This is thread {accelerator.local_process_index} with training step {step} command 2" )

            target = torch.randn_like(model_pred)

            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(unet.parameters(), 1.0)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            accelerator.wait_for_everyone()

            print(f"This is thread {accelerator.local_process_index} with training step {step} command 3" )

            if accelerator.sync_gradients:
                global_step += 1
                if global_step % 5 == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(output_dir, f"checkpoint_{global_step}_lora_dict.pkl")
                        # accelerator.save_state(save_path)
                        # unet_lora_layers_to_save = get_peft_model_state_dict(unwrap_model(unet))
                        state_dict = accelerator.get_state_dict(unet,unwrap=False)
                        unet_lora_layers_to_save = {}
                        for key,value in state_dict.items():
                            if 'lora' in key or 'Lora' in key:
                                unet_lora_layers_to_save[key] = value
                        lora_dict = {"unet_lora":unet_lora_layers_to_save}
                        torch.save(lora_dict,save_path)
                        print(f"Saved Unet to {save_path}")

    accelerator.wait_for_everyone()
    accelerator.end_training()

if __name__ == "__main__":
    main()

The FSDP FSDP_config.yaml is:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: false
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: SHARD_GRAD_OP
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: CrossAttnDownBlock2D,UNetMidBlock2DCrossAttn,CrossAttnUpBlock2D
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

I use two GPU to test the script and I try to use print to output some information to check where the model is stuck, and the final output information is (I save the model after 5 training steps):

This is thread 1 with training step 5 command 1
This is thread 1 with training step 5 command 2
Saved Unet to tmp/test_FSDP/checkpoint_5_lora_dict.pkl
This is thread 0 with training step 5 command 1
This is thread 0 with training step 5 command 2
This is thread 1 with training step 5 command 3
This is thread 1 with training step 6 command 1

The model seems to hang on after I call accelerator.get_state_dict() (which is recommanded by the official document), and after being stuck for a long time, it will return the following log info:

[rank0]:[E1122 06:40:00.504052851 ProcessGroupNCCL.cpp:616] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=122, OpType=_REDUCE_SCATTER_BASE, NumelIn=106343680, NumelOut=53171840, Timeout(ms)=600000) ran for 600080 milliseconds before timing out.
[rank0]:[E1122 06:40:00.504952916 ProcessGroupNCCL.cpp:1785] [PG ID 0 PG GUID 0(default_pg) Rank 0] Exception (either an error or timeout) detected by watchdog at work: 122, last enqueued NCCL work: 127, last completed NCCL work: 121.
[rank1]:[E1122 06:40:01.902604829 ProcessGroupNCCL.cpp:616] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=126, OpType=_ALLGATHER_BASE, NumelIn=378702080, NumelOut=757404160, Timeout(ms)=600000) ran for 600079 milliseconds before timing out.
[rank1]:[E1122 06:40:01.903028278 ProcessGroupNCCL.cpp:1785] [PG ID 0 PG GUID 0(default_pg) Rank 1] Exception (either an error or timeout) detected by watchdog at work: 126, last enqueued NCCL work: 128, last completed NCCL work: 125.
[rank1]:[E1122 06:40:02.586033171 ProcessGroupNCCL.cpp:1834] [PG ID 0 PG GUID 0(default_pg) Rank 1] Timeout at NCCL work: 126, last enqueued NCCL work: 128, last completed NCCL work: 125.
[rank1]:[E1122 06:40:02.586050386 ProcessGroupNCCL.cpp:630] [Rank 1] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank1]:[E1122 06:40:02.586057304 ProcessGroupNCCL.cpp:636] [Rank 1] To avoid data inconsistency, we are taking the entire process down.
[rank1]:[E1122 06:40:02.587199000 ProcessGroupNCCL.cpp:1595] [PG ID 0 PG GUID 0(default_pg) Rank 1] Process group watchdog thread terminated with exception: [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=126, OpType=_ALLGATHER_BASE, NumelIn=378702080, NumelOut=757404160, Timeout(ms)=600000) ran for 600079 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:618 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f88bc3a1446 in /miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x282 (0x7f885f02fa92 in /miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x7f885f036ed3 in /miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7f885f03893d in /miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x145c0 (0x7f88c4ba45c0 in /miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/lib/libtorch.so)
frame #5: <unknown function> + 0x8609 (0x7f88c675c609 in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #6: clone + 0x43 (0x7f88c6525353 in /lib/x86_64-linux-gnu/libc.so.6)

terminate called after throwing an instance of 'c10::DistBackendError'
  what():  [PG ID 0 PG GUID 0(default_pg) Rank 1] Process group watchdog thread terminated with exception: [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=126, OpType=_ALLGATHER_BASE, NumelIn=378702080, NumelOut=757404160, Timeout(ms)=600000) ran for 600079 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:618 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f88bc3a1446 in /miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x282 (0x7f885f02fa92 in /miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x7f885f036ed3 in /miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7f885f03893d in /miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x145c0 (0x7f88c4ba45c0 in /miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/lib/libtorch.so)
frame #5: <unknown function> + 0x8609 (0x7f88c675c609 in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #6: clone + 0x43 (0x7f88c6525353 in /lib/x86_64-linux-gnu/libc.so.6)

Exception raised from ncclCommWatchdog at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1601 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f88bc3a1446 in /miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe7eb1b (0x7f885ecadb1b in /miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0x145c0 (0x7f88c4ba45c0 in /miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/lib/libtorch.so)
frame #3: <unknown function> + 0x8609 (0x7f88c675c609 in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #4: clone + 0x43 (0x7f88c6525353 in /lib/x86_64-linux-gnu/libc.so.6)

W1122 06:40:02.533000 481990 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 482092 closing signal SIGTERM
E1122 06:40:03.600000 481990 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: -6) local_rank: 1 (pid: 482093) of binary: /miniconda3/envs/SDXL/bin/python

I launch the training with the following command: accelerate launch --config_file FSDP_config.yaml tmp_test_FSDP.py

Expected behavior

I want the save the intermediate checkpoint during the training process. Meanwhile, I want to training the UNet with LoRA parameters so you can see in the example I try to filter the parameters with 'lora' in the key. Actually, I don't integrate LoRA into the Unet in this example code, so the final saved state_dict is empty. But I guess it may not effect the bug.

leeruibin commented 11 hours ago

Another bug is that when I use accelerator.save_state() to save the checkpoint, it will raise the OOM error

if accelerator.sync_gradients:
    global_step += 1
    if global_step % 5 == 0:
        if accelerator.is_main_process:
            save_path = os.path.join(output_dir, f"checkpoint_{global_step}")
            accelerator.save_state(save_path)
            print(f"Saved Unet to {save_path}")

The output log is:

/miniconda3/envs/SDXL/lib/python3.10/site-packages/accelerate/utils/fsdp_utils.py:90: FutureWarning: `save_state_dict` is deprecated and will be removed in future versions.Please use `save` instead.
  dist_cp.save_state_dict(
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/ubuntu/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/pydevd.py", line 3713, in <module>
[rank0]:     main()
[rank0]:   File "/home/ubuntu/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/pydevd.py", line 3706, in main
[rank0]:     globals = debugger.run(setup["file"], None, None, is_module)
[rank0]:   File "/home/ubuntu/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/pydevd.py", line 2704, in run
[rank0]:     return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
[rank0]:   File "/home/ubuntu/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/pydevd.py", line 2712, in _exec
[rank0]:     globals = pydevd_runpy.run_path(file, globals, "__main__")
[rank0]:   File "/home/ubuntu/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 310, in run_path
[rank0]:     return _run_module_code(code, init_globals, run_name, pkg_name=pkg_name, script_name=fname)
[rank0]:   File "/home/ubuntu/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 127, in _run_module_code
[rank0]:     _run_code(code, mod_globals, init_globals, mod_name, mod_spec, pkg_name, script_name)
[rank0]:   File "/home/ubuntu/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 118, in _run_code
[rank0]:     exec(code, run_globals)
[rank0]:   File "tmp_test_FSDP.py", line 147, in <module>
[rank0]:     main()
[rank0]:   File "tmp_test_FSDP.py", line 139, in main
[rank0]:     accelerator.save_state(save_path)
[rank0]:   File "/miniconda3/envs/SDXL/lib/python3.10/site-packages/accelerate/accelerator.py", line 2947, in save_state
[rank0]:     save_fsdp_model(self.state.fsdp_plugin, self, model, output_dir, i)
[rank0]:   File "/miniconda3/envs/SDXL/lib/python3.10/site-packages/accelerate/utils/fsdp_utils.py", line 90, in save_fsdp_model
[rank0]:     dist_cp.save_state_dict(
[rank0]:   File "/miniconda3/envs/SDXL/lib/python3.10/site-packages/typing_extensions.py", line 2853, in wrapper
[rank0]:     return arg(*args, **kwargs)
[rank0]:   File "/miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 47, in save_state_dict
[rank0]:     return _save_state_dict(
[rank0]:   File "/miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 316, in _save_state_dict
[rank0]:     central_plan: SavePlan = distW.reduce_scatter("plan", local_step, global_step)
[rank0]:   File "/miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 168, in reduce_scatter
[rank0]:     all_data = self.gather_object(local_data)
[rank0]:   File "/miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 107, in gather_object
[rank0]:     dist.gather_object(
[rank0]:   File "/miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 83, in wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2828, in gather_object
[rank0]:     input_tensor.resize_(max_object_size)
[rank0]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate more than 1EB memory.
W1122 07:19:41.246000 498926 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 499033 closing signal SIGTERM