Open leeruibin opened 12 hours ago
Another bug: when I use `accelerator.save_state()` to save the checkpoint, it raises an OOM error:
```python
if accelerator.sync_gradients:
    global_step += 1
    if global_step % 5 == 0:
        if accelerator.is_main_process:
            save_path = os.path.join(output_dir, f"checkpoint_{global_step}")
            accelerator.save_state(save_path)
            print(f"Saved Unet to {save_path}")
```
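One possible cause (my assumption, not confirmed here): under FSDP, `accelerator.save_state()` performs collective communication across ranks (the traceback below ends in `dist.gather_object`), so every process must call it. Guarding the call with `is_main_process` means the other ranks never enter the collective, which can leave rank 0 blocked or desynchronize a later collective. A minimal sketch of the all-rank call pattern (`checkpoint_step` is a hypothetical helper, not from the script above):

```python
import os

def checkpoint_step(accelerator, global_step, output_dir, save_every=5):
    # Sketch under my assumption: accelerator.save_state() must run on
    # EVERY rank under FSDP; only the logging is guarded by is_main_process.
    if global_step % save_every == 0:
        save_path = os.path.join(output_dir, f"checkpoint_{global_step}")
        accelerator.save_state(save_path)  # all ranks participate
        if accelerator.is_main_process:
            print(f"Saved Unet to {save_path}")
        return save_path
    return None
```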
The output log is:
```
/miniconda3/envs/SDXL/lib/python3.10/site-packages/accelerate/utils/fsdp_utils.py:90: FutureWarning: `save_state_dict` is deprecated and will be removed in future versions.Please use `save` instead.
dist_cp.save_state_dict(
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/ubuntu/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/pydevd.py", line 3713, in <module>
[rank0]:     main()
[rank0]:   File "/home/ubuntu/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/pydevd.py", line 3706, in main
[rank0]:     globals = debugger.run(setup["file"], None, None, is_module)
[rank0]:   File "/home/ubuntu/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/pydevd.py", line 2704, in run
[rank0]:     return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
[rank0]:   File "/home/ubuntu/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/pydevd.py", line 2712, in _exec
[rank0]:     globals = pydevd_runpy.run_path(file, globals, "__main__")
[rank0]:   File "/home/ubuntu/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 310, in run_path
[rank0]:     return _run_module_code(code, init_globals, run_name, pkg_name=pkg_name, script_name=fname)
[rank0]:   File "/home/ubuntu/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 127, in _run_module_code
[rank0]:     _run_code(code, mod_globals, init_globals, mod_name, mod_spec, pkg_name, script_name)
[rank0]:   File "/home/ubuntu/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 118, in _run_code
[rank0]:     exec(code, run_globals)
[rank0]:   File "tmp_test_FSDP.py", line 147, in <module>
[rank0]:     main()
[rank0]:   File "tmp_test_FSDP.py", line 139, in main
[rank0]:     accelerator.save_state(save_path)
[rank0]:   File "/miniconda3/envs/SDXL/lib/python3.10/site-packages/accelerate/accelerator.py", line 2947, in save_state
[rank0]:     save_fsdp_model(self.state.fsdp_plugin, self, model, output_dir, i)
[rank0]:   File "/miniconda3/envs/SDXL/lib/python3.10/site-packages/accelerate/utils/fsdp_utils.py", line 90, in save_fsdp_model
[rank0]:     dist_cp.save_state_dict(
[rank0]:   File "/miniconda3/envs/SDXL/lib/python3.10/site-packages/typing_extensions.py", line 2853, in wrapper
[rank0]:     return arg(*args, **kwargs)
[rank0]:   File "/miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 47, in save_state_dict
[rank0]:     return _save_state_dict(
[rank0]:   File "/miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 316, in _save_state_dict
[rank0]:     central_plan: SavePlan = distW.reduce_scatter("plan", local_step, global_step)
[rank0]:   File "/miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 168, in reduce_scatter
[rank0]:     all_data = self.gather_object(local_data)
[rank0]:   File "/miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 107, in gather_object
[rank0]:     dist.gather_object(
[rank0]:   File "/miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 83, in wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/miniconda3/envs/SDXL/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2828, in gather_object
[rank0]:     input_tensor.resize_(max_object_size)
[rank0]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate more than 1EB memory.
W1122 07:19:41.246000 498926 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 499033 closing signal SIGTERM
```
System Info

Information

Tasks

- `no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)

Reproduction
I am trying to use FSDP to accelerate my training with `accelerator`. The task is similar to SDXL-inpainting. However, when I try to save the intermediate checkpoints, the training script hangs in the main process after the checkpoint is saved with the FSDP config. Here is my example code to reproduce the bug; I use the UNet2DConditionModel from diffusers as the training model.
The FSDP_config.yaml is:
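(The config file contents were not included above; the following is a representative two-GPU Accelerate FSDP config of the kind the launch command expects — my assumption, not the author's actual file.)

```yaml
# Hypothetical example only -- not the author's actual FSDP_config.yaml
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
mixed_precision: fp16
num_machines: 1
num_processes: 2
```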
I use two GPUs to test the script, and I add some print statements to check where the model gets stuck. The final output (I save the model after 5 training steps) is:
The script seems to hang after I call accelerator.get_state_dict() (which is recommended by the official documentation), and after being stuck for a long time it returns the following log:
I launch the training with the following command:
```shell
accelerate launch --config_file FSDP_config.yaml tmp_test_FSDP.py
```
Expected behavior
I want to save the intermediate checkpoints during training. Meanwhile, I want to train the UNet with LoRA parameters, so in the example I filter for parameters with 'lora' in the key. I don't actually integrate LoRA into the UNet in this example code, so the final saved state_dict is empty, but I don't think that affects the bug.
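The key filtering described above can be sketched as follows (a minimal sketch; `filter_lora_params` and the exact key naming are my assumptions, not the author's code):

```python
def filter_lora_params(state_dict):
    # Keep only LoRA parameters, identified by 'lora' in the key
    # (as described above). With no LoRA layers injected into the UNet,
    # this returns an empty dict, matching the empty saved checkpoint.
    return {k: v for k, v in state_dict.items() if "lora" in k}
```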