XLabs-AI / x-flux

Apache License 2.0

train_flux_lora_deepspeed.py Error: The size of tensor a (2) must match the size of tensor b (64) #76

[Open] AfterHAL opened this issue 3 weeks ago

AfterHAL commented 3 weeks ago

Well, this is new to me.

08/23/2024 10:47:34 - INFO - __main__ - ***** Running training *****
08/23/2024 10:47:34 - INFO - __main__ -   Num Epochs = 500
08/23/2024 10:47:34 - INFO - __main__ -   Instantaneous batch size per device = 2
08/23/2024 10:47:34 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 8
08/23/2024 10:47:34 - INFO - __main__ -   Gradient Accumulation steps = 2
08/23/2024 10:47:34 - INFO - __main__ -   Total optimization steps = 2000
Checkpoint 'latest' does not exist. Starting a new training run.
Steps:   0%|          | 0/2000 [00:00<?, ?it/s]
[rank1]: Traceback (most recent call last):
[rank1]:   File "/workspace/x-flux/train_flux_lora_deepspeed.py", line 301, in <module>
[rank1]:     main()
[rank1]:   File "/workspace/x-flux/train_flux_lora_deepspeed.py", line 226, in main
[rank1]:     x_t = (1 - t) * x_1 + t * x_0
[rank1]: RuntimeError: The size of tensor a (2) must match the size of tensor b (64) at non-singleton dimension 2
[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/x-flux/train_flux_lora_deepspeed.py", line 301, in <module>
[rank0]:     main()
[rank0]:   File "/workspace/x-flux/train_flux_lora_deepspeed.py", line 226, in main
[rank0]:     x_t = (1 - t) * x_1 + t * x_0
[rank0]: RuntimeError: The size of tensor a (2) must match the size of tensor b (64) at non-singleton dimension 2
Steps:   0%|          | 0/2000 [00:01<?, ?it/s]
[rank0]:[W823 10:47:36.712866358 ProcessGroupNCCL.cpp:1168] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())
E0823 10:47:40.287000 131962042499072 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: 1) local_rank: 0 (pid: 3328) of binary: /workspace/x-flux/xflux_env/bin/python3
Traceback (most recent call last):
  File "/workspace/x-flux/xflux_env/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/workspace/x-flux/xflux_env/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 46, in main
    args.func(args)
  File "/workspace/x-flux/xflux_env/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1067, in launch_command
    deepspeed_launcher(args)
  File "/workspace/x-flux/xflux_env/lib/python3.10/site-packages/accelerate/commands/launch.py", line 771, in deepspeed_launcher
    distrib_run.run(args)
  File "/workspace/x-flux/xflux_env/lib/python3.10/site-packages/torch/distributed/run.py", line 892, in run
    elastic_launch(
  File "/workspace/x-flux/xflux_env/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 133, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/workspace/x-flux/xflux_env/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
train_flux_lora_deepspeed.py FAILED
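The batch-size figures in the log are consistent with two training processes, which also matches the two ranks ([rank0], [rank1]) in the traceback. A quick check of the arithmetic, assuming the script computes the logged total the usual way (per-device batch size × number of processes × accumulation steps); `effective_batch_size` is a hypothetical helper, not part of x-flux:

```python
# Values copied from the log above. num_processes = 2 is an assumption
# inferred from the two ranks that appear in the traceback.
per_device_batch_size = 2
num_processes = 2
gradient_accumulation_steps = 2


def effective_batch_size(per_device: int, processes: int, accumulation: int) -> int:
    """Samples consumed per optimizer step across all processes (hypothetical helper)."""
    return per_device * processes * accumulation


total = effective_batch_size(per_device_batch_size, num_processes, gradient_accumulation_steps)
print(total)  # 8, matching "Total train batch size ... = 8"
```

The value that matters for the crash is the per-device one: each rank builds tensors whose leading (batch) dimension is 2, not 1.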
stazizov commented 2 weeks ago

Seems like a bug in the reshaping when batch size > 1.
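The failure can be reproduced in isolation. The sketch below assumes that `t` is sampled as a 1-D tensor of per-sample timesteps with shape `(bs,)` and that `x_1` holds packed latents of shape `(bs, seq_len, 64)`. Neither shape is printed in the thread, but both are consistent with the error message: tensor a has size 2 (the per-device batch) and tensor b has size 64 at dimension 2. Broadcasting aligns trailing dimensions, so `(1 - t)` with shape `(2,)` is matched against the last axis of size 64 and fails, while a batch of 1 broadcasts without complaint. `try_batch` and `interpolate_flat_t` are hypothetical helpers:

```python
import torch


def interpolate_flat_t(x_1: torch.Tensor, x_0: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Same expression as line 226, with t left as a 1-D (bs,) tensor.
    return (1 - t) * x_1 + t * x_0


def try_batch(bs: int, seq_len: int = 16, channels: int = 64) -> str:
    # Assumed packed-latent shape (bs, seq_len, 64); seq_len is arbitrary here.
    x_1 = torch.randn(bs, seq_len, channels)
    x_0 = torch.randn_like(x_1)
    t = torch.sigmoid(torch.randn(bs))  # assumed: one timestep per sample, shape (bs,)
    try:
        interpolate_flat_t(x_1, x_0, t)
        return "ok"
    except RuntimeError as err:
        return str(err)


print(try_batch(1))  # ok: a size-1 dimension broadcasts against anything
print(try_batch(2))  # reports a size mismatch between 2 and 64, as in the log
```

The same rule has a quieter failure mode: if the per-device batch size happened to equal the last dimension (64), no error would be raised and `t` would scale latent channels instead of samples.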

YujiaKCL commented 2 weeks ago

Hi, I have encountered this issue too. Have you found a solution?
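No fix is posted in this thread. A minimal sketch of one possible workaround, assuming `t` is a 1-D tensor of per-sample timesteps and `x_1`/`x_0` share a batch-first shape: reshape `t` to `(bs, 1, ..., 1)` before the interpolation so each sample's timestep broadcasts over all of its non-batch dimensions. `broadcast_t` and `interpolate` are hypothetical names, not x-flux functions. Per the maintainer's comment that the bug only appears when batch size > 1, setting the per-device batch size to 1 may also sidestep the error.

```python
import torch


def broadcast_t(t: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
    """Reshape a per-sample timestep of shape (bs,) to (bs, 1, ..., 1) so it
    broadcasts against `like` along every non-batch dimension."""
    return t.view(t.shape[0], *([1] * (like.ndim - 1)))


def interpolate(x_1: torch.Tensor, x_0: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Same formula as line 226, but t is reshaped per sample first.
    # The caller's 1-D t is left untouched.
    t_b = broadcast_t(t, x_1)
    return (1 - t_b) * x_1 + t_b * x_0


bs, seq_len, channels = 2, 16, 64  # hypothetical packed-latent shape
x_1 = torch.randn(bs, seq_len, channels)
x_0 = torch.randn_like(x_1)
t = torch.sigmoid(torch.randn(bs))
x_t = interpolate(x_1, x_0, t)
print(x_t.shape)  # torch.Size([2, 16, 64])
```

Returning a reshaped view rather than overwriting `t` keeps the original 1-D tensor available for anything else that consumes it; the model's timestep input presumably still expects shape `(bs,)`.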