huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

FLUX.1-dev dreambooth training problem on multigpu #9790

Open jyy-1998 opened 2 weeks ago

jyy-1998 commented 2 weeks ago

Describe the bug

I tried to train FLUX with accelerate + DeepSpeed, but every run crashes with an error after about a dozen steps. Can anyone help?

Reproduction

accelerate launch --config_file config.yaml train_flux.py \
  --pretrained_model_name_or_path="./FLUX.1-dev" \
  --resolution=1024 \
  --train_batch_size=1 \
  --output_dir="output0" \
  --num_train_epochs=10 \
  --checkpointing_steps=5000 \
  --validation_steps=100 \
  --max_train_steps=40001 \
  --learning_rate=4e-05 \
  --seed=12345 \
  --mixed_precision="fp16" \
  --revision="fp16" \
  --use_8bit_adam \
  --gradient_accumulation_steps=1 \
  --gradient_checkpointing

compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
gpu_ids: 0,1
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Logs

the logs:
Steps: 0%| | 0/40001 [00:00<?, ?it/s]Passing txt_ids 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor
[2024-10-29 06:19:05,281] [INFO] [loss_scaler.py:183:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 4294967296, reducing to 2147483648
Steps: 0%| | 1/40001 [00:14<165:03:11, 14.85s/it, loss=0.707, lr=0]Passing txt_ids 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor
[2024-10-29 06:19:18,776] [INFO] [loss_scaler.py:183:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 2147483648, reducing to 1073741824
Steps: 0%| | 2/40001 [00:28<156:09:31, 14.05s/it, loss=0.428, lr=0]Passing txt_ids 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor
[2024-10-29 06:19:31,289] [INFO] [loss_scaler.py:183:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 1073741824, reducing to 536870912
Steps: 0%| | 3/40001 [00:40<148:20:21, 13.35s/it, loss=0.652, lr=0]Passing txt_ids 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor
[2024-10-29 06:19:43,569] [INFO] [loss_scaler.py:183:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 536870912, reducing to 268435456
Steps: 0%| | 4/40001 [00:53<143:38:06, 12.93s/it, loss=0.919, lr=0]Passing txt_ids 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor
[2024-10-29 06:19:55,578] [INFO] [loss_scaler.py:183:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 268435456, reducing to 134217728
Steps: 0%| | 5/40001 [01:05<139:57:04, 12.60s/it, loss=0.439, lr=0]Passing txt_ids 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor
[2024-10-29 06:20:07,647] [INFO] [loss_scaler.py:183:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 134217728, reducing to 67108864
Steps: 0%| | 6/40001 [01:17<137:57:02, 12.42s/it, loss=0.56, lr=0]Passing txt_ids 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor
[2024-10-29 06:20:20,092] [INFO] [loss_scaler.py:183:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 67108864, reducing to 33554432
Steps: 0%| | 7/40001 [01:29<138:03:03, 12.43s/it, loss=0.694, lr=0]Passing txt_ids 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor
[2024-10-29 06:20:31,513] [INFO] [loss_scaler.py:183:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 33554432, reducing to 16777216
Steps: 0%| | 8/40001 [01:41<134:29:31, 12.11s/it, loss=0.577, lr=0]Passing txt_ids 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor
[2024-10-29 06:20:42,977] [INFO] [loss_scaler.py:183:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 16777216, reducing to 8388608
Steps: 0%| | 9/40001 [01:52<132:15:29, 11.91s/it, loss=0.363, lr=0]Passing txt_ids 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor
[2024-10-29 06:20:54,396] [INFO] [loss_scaler.py:183:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 8388608, reducing to 4194304
Steps: 0%| | 10/40001 [02:03<130:35:01, 11.76s/it, loss=0.64, lr=0]Passing txt_ids 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor
[2024-10-29 06:21:06,229] [INFO] [loss_scaler.py:183:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 4194304, reducing to 2097152
Steps: 0%| | 11/40001 [02:15<130:50:48, 11.78s/it, loss=0.574, lr=0]Passing txt_ids 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor
[2024-10-29 06:21:17,911] [INFO] [loss_scaler.py:183:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 2097152, reducing to 1048576
Steps: 0%| | 12/40001 [02:27<130:30:42, 11.75s/it, loss=0.44, lr=0]Passing txt_ids 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor
[2024-10-29 06:21:29,945] [INFO] [loss_scaler.py:183:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 1048576, reducing to 524288
Steps: 0%| | 13/40001 [02:39<131:28:04, 11.84s/it, loss=0.387, lr=0]Passing txt_ids 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 11739 closing signal SIGTERM
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -9) local_rank: 0 (pid: 11738) of binary: /opt/conda/bin/python
Traceback (most recent call last):
  File "/opt/conda/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/opt/conda/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 48, in main
    args.func(args)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1082, in launch_command
    deepspeed_launcher(args)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/commands/launch.py", line 786, in deepspeed_launcher
    distrib_run.run(args)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
train_flux.py FAILED
Failures:
<NO_OTHER_FAILURES>
Root Cause (first observed failure):
[0]:
time : 2024-10-29_06:22:05
host : task-20241029103910-30817
rank : 0 (local_rank: 0)
exitcode : -9 (pid: 11738)
error_file: <N/A>
traceback : Signal 9 (SIGKILL) received by PID 11738
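
For readers following the OVERFLOW lines in the log above: each one reports that the fp16 loss scale was halved and the optimizer step skipped. The sketch below only reproduces that halving arithmetic. `simulate_loss_scale` is a hypothetical helper written for illustration, not DeepSpeed's actual scaler, which also raises the scale again after a window of overflow-free steps (omitted here).

```python
def simulate_loss_scale(initial_scale, overflow_flags, factor=2.0, min_scale=1.0):
    """Return the loss scale after each step.

    Illustrative only: on overflow the scale is divided by `factor`
    (never dropping below `min_scale`); otherwise it is left unchanged.
    """
    scale = float(initial_scale)
    history = []
    for overflowed in overflow_flags:
        if overflowed:
            scale = max(scale / factor, min_scale)
        history.append(scale)
    return history


# The log shows 13 consecutive overflows starting from 4294967296 (2**32).
history = simulate_loss_scale(2**32, [True] * 13)
print(int(history[0]))   # scale after the first skipped step
print(int(history[-1]))  # scale after the thirteenth skipped step
```

Starting from 2**32 and halving 13 times gives 2**19 = 524288, which matches the last "reducing to" value in the log. Every step logged so far was skipped, so no weight update had happened before the process was killed.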

System Info

deepspeed==0.14.4 accelerate==0.33.0 transformers==4.41.2

Who can help?

No response

a-r-r-o-w commented 2 weeks ago

Could you enable verbose logging with Accelerate (ref) and paste the logs? The current output does not seem to contain any information that would help identify the issue.

lyb369 commented 1 week ago

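Maybe you are running out of CPU memory. The process was killed with SIGKILL (exit code -9), and your config offloads both the optimizer and the parameters to CPU.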
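
A quick way to sanity-check the CPU memory hypothesis is to decode the exit code and do a back-of-the-envelope estimate of host RAM needed for offloaded state. The sketch below makes two loudly stated assumptions: roughly 12 billion trainable parameters (the approximate size of the FLUX.1-dev transformer) and about 16 bytes of host memory per parameter for fp32 weights, gradients, and Adam states. Both numbers are rough, both helper names are hypothetical, and real usage depends on the DeepSpeed version, the optimizer actually used, and buffers.

```python
import signal


def describe_exit_code(exitcode: int) -> str:
    """Map a negative subprocess exit code to the signal name that caused it."""
    if exitcode < 0:
        return signal.Signals(-exitcode).name
    return f"exit status {exitcode}"


def rough_cpu_offload_bytes(num_params: int, bytes_per_param: int = 16) -> int:
    """Rule-of-thumb host memory for offloaded training state.

    `bytes_per_param=16` is an assumption, not a measured value.
    """
    return num_params * bytes_per_param


ASSUMED_PARAMS = 12_000_000_000  # assumption: approximate transformer size

print(describe_exit_code(-9))
print(rough_cpu_offload_bytes(ASSUMED_PARAMS) / 1e9, "GB (rough estimate)")
```

If the estimate is near or above the machine's physical RAM, the kernel's out-of-memory killer is a plausible source of the SIGKILL. On Linux, checking `dmesg` for OOM messages around the crash time would confirm or rule this out.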

sayakpaul commented 1 week ago

I have investigated this before and I can confirm it works. See: https://github.com/huggingface/diffusers/issues/9278#issuecomment-2410113103

leisuzz commented 1 week ago

Can you try what is described in #9829? I saved memory by implementing it :)