kohya-ss / sd-scripts

Apache License 2.0

FSDP support #1775

Open ljleb opened 2 weeks ago

ljleb commented 2 weeks ago

I tried using an FSDP config like this with accelerate (taken from https://github.com/kohya-ss/sd-scripts/issues/1480#issuecomment-2301283660) to finetune SDXL. The UI is bmaltais/kohya_ss.

compute_environment: LOCAL_MACHINE
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: SIZE_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
  fsdp_min_num_params: 100000000
machine_rank: 0
main_training_function: main
mixed_precision: fp32
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
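A side note on the config: if I understand accelerate's FSDP plugin correctly, `fsdp_sharding_strategy` is a 1-based index into PyTorch's `ShardingStrategy` enum, so `1` selects `FULL_SHARD` (parameters, gradients, and optimizer state are all sharded across ranks). A quick way to check the mapping:

```python
from torch.distributed.fsdp import ShardingStrategy

# fsdp_sharding_strategy: 1 in the accelerate config indexes this enum,
# so it selects FULL_SHARD (shard parameters, gradients, and optimizer state).
print(ShardingStrategy(1))  # ShardingStrategy.FULL_SHARD
print(ShardingStrategy(3))  # ShardingStrategy.NO_SHARD
```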

But it gives me this error:

[...]
number of trainable parameters: 3385184004
prepare optimizer, data loader etc.
running training / 学習開始
  num examples / サンプル数: 100
  num batches per epoch / 1epochのバッチ数: 100
  num epochs / epoch数: 1
  batch size per device / バッチサイズ: 1
  gradient accumulation steps / 勾配を合計するステップ数 = 1
  total optimization steps / 学習ステップ数: 100
steps:   0%|                                                                                                 | 0/100 [00:00<?, ?it/s]
epoch 1/1
2024-11-11 08:08:31 INFO     epoch is incremented. current_epoch: 0, epoch: 1                                       train_util.py:703
Traceback (most recent call last):
  File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/sdxl_train.py", line 822, in <module>
    train(args)
  File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/sdxl_train.py", line 614, in train
    noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 1104, in forward
    h = call_module(module, h, emb, context)
  File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 1097, in call_module
    x = layer(x)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 750, in forward
    hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 669, in forward
    output = torch.utils.checkpoint.checkpoint(
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 451, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 230, in forward
    outputs = run_function(*args)
  File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 665, in custom_forward
    return func(*inputs)
  File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 652, in forward_body
    hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 445, in forward
    return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
  File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 524, in forward_memory_efficient_mem_eff
    k = self.to_k(context)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (576x1280 and 2048x1280)
steps:   0%|                                                                                                 | 0/100 [00:00<?, ?it/s]
[2024-11-11 08:08:34,959] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 232842) of binary: /home/ljleb/src/kohya_ss/venv/bin/python
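The shape error at the bottom of the traceback can be reproduced in isolation. Going by the sizes in the error message, `to_k` here is a `Linear` with `in_features=2048` (SDXL's text context dimension), and `F.linear` computes `input @ weight.T`, so feeding it a 1280-dim tensor fails exactly as above. A minimal sketch (layer sizes inferred from the traceback, not taken from running sd-scripts):

```python
import torch
import torch.nn as nn

# to_k in SDXL cross-attention projects the 2048-dim text context to 1280 dims.
to_k = nn.Linear(2048, 1280, bias=False)

# A 1280-dim tensor (e.g. the UNet's own hidden states) instead of the
# 2048-dim text context reproduces the mismatch from the traceback:
wrong_context = torch.randn(576, 1280)

msg = ""
try:
    to_k(wrong_context)  # F.linear: (576x1280) @ weight.T (2048x1280) has no common inner dim
except RuntimeError as e:
    msg = str(e)
print(msg)  # mat1 and mat2 shapes cannot be multiplied (576x1280 and 2048x1280)
```

This suggests the cross-attention layer received a tensor with the wrong feature dimension, which would be consistent with FSDP's size-based auto-wrap interfering with how the context is routed through the wrapped submodules.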

I have 2 P40s, which means only fp32 is practical, and 24 GB on a single card appears to be insufficient. I would like to distribute the parameters of a single SDXL model across multiple GPUs to reduce the per-card memory usage of traditional finetuning (not any kind of PEFT).

I am not very familiar with multi-GPU training. Can FSDP put the first half of SDXL on cuda:0 and the other half on cuda:1 with the existing code?

ljleb commented 2 weeks ago

I figured out that I can put the text encoders on a different device manually. It's a very ad hoc solution, but it works. I think it would be great if there were a way to partition the load over a larger number of devices.
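The manual split described above can be sketched roughly as follows. This is a minimal stand-in, not the real sd-scripts code: the device names, module shapes, and variable names are all assumptions (with two cards the devices would be "cuda:1" and "cuda:0"; "cpu" lets the sketch run anywhere):

```python
import torch
import torch.nn as nn

# Hypothetical placement: text encoders on one device, UNet on another.
te_device = torch.device("cpu")
unet_device = torch.device("cpu")

# Stand-ins for the SDXL text encoder and one UNet cross-attention projection
# (dimensions chosen to match SDXL's 2048-dim text context; illustrative only).
text_encoder = nn.Embedding(49408, 2048).to(te_device)
unet_to_k = nn.Linear(2048, 1280, bias=False).to(unet_device)

tokens = torch.randint(0, 49408, (1, 77), device=te_device)
with torch.no_grad():
    context = text_encoder(tokens)  # runs on the text-encoder device
context = context.to(unet_device)   # single tensor hop between devices
out = unet_to_k(context)
print(out.shape)  # torch.Size([1, 77, 1280])
```

The cost of this scheme is one cross-device copy of the text embeddings per step, which is small next to the UNet's activations, so it mainly buys back the memory the text encoders would otherwise occupy on the UNet's card.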