[...]
number of trainable parameters: 3385184004
prepare optimizer, data loader etc.
running training / 学習開始
num examples / サンプル数: 100
num batches per epoch / 1epochのバッチ数: 100
num epochs / epoch数: 1
batch size per device / バッチサイズ: 1
gradient accumulation steps / 勾配を合計するステップ数 = 1
total optimization steps / 学習ステップ数: 100
steps: 0%| | 0/100 [00:00<?, ?it/s]
epoch 1/1
2024-11-11 08:08:31 INFO epoch is incremented. current_epoch: 0, epoch: 1 train_util.py:703
Traceback (most recent call last):
File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/sdxl_train.py", line 822, in <module>
train(args)
File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/sdxl_train.py", line 614, in train
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
output = self._fsdp_wrapped_module(*args, **kwargs)
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 1104, in forward
h = call_module(module, h, emb, context)
File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 1097, in call_module
x = layer(x)
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
output = self._fsdp_wrapped_module(*args, **kwargs)
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 750, in forward
hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 669, in forward
output = torch.utils.checkpoint.checkpoint(
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
return fn(*args, **kwargs)
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 451, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 230, in forward
outputs = run_function(*args)
File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 665, in custom_forward
return func(*inputs)
File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 652, in forward_body
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 445, in forward
return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
File "/home/ljleb/src/kohya_ss_fisher/sd-scripts/library/sdxl_original_unet.py", line 524, in forward_memory_efficient_mem_eff
k = self.to_k(context)
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ljleb/src/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (576x1280 and 2048x1280)
steps: 0%| | 0/100 [00:00<?, ?it/s]
[2024-11-11 08:08:34,959] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 232842) of binary: /home/ljleb/src/kohya_ss/venv/bin/python
I have 2 P40s (which means only fp32 is practical; 24 GB on a single card does not appear to be enough memory), and I would like to distribute the parameters of a single SDXL model across multiple GPUs to reduce the per-card memory usage of traditional finetuning, not any kind of PEFT.
I am not very familiar with using multiple GPUs to train models. Can FSDP put the first half of SDXL on cuda:0 and the other half on cuda:1 with the existing code?
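From what I can tell from the PyTorch docs (I may be wrong), FSDP does not place the first half of the model on cuda:0 and the second half on cuda:1; it shards each wrapped submodule's parameters across all ranks, and every rank still runs the full forward pass, gathering the shards it needs on the fly. A minimal toy sketch of the wrapping (this is not the sd-scripts integration; the tiny Sequential below only stands in for the SDXL UNet):

```python
# Toy FSDP sketch, launched with: torchrun --nproc_per_node=2 fsdp_toy.py
# The Sequential only stands in for the SDXL UNet; the point is that FSDP shards the
# parameters of each wrapped submodule across both ranks instead of assigning whole
# halves of the network to different devices.
import functools

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy


def main():
    dist.init_process_group("nccl")          # torchrun sets RANK/WORLD_SIZE/MASTER_ADDR
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    model = nn.Sequential(*[nn.Linear(2048, 2048) for _ in range(8)])  # stand-in model

    model = FSDP(
        model,
        device_id=rank,
        # wrap each sufficiently large submodule so its params are sharded individually
        auto_wrap_policy=functools.partial(size_based_auto_wrap_policy, min_num_params=1_000_000),
        use_orig_params=True,
    )

    x = torch.randn(4, 2048, device=rank)
    loss = model(x).pow(2).mean()
    loss.backward()                          # gradients are reduced/resharded across ranks

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```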
I figured out that I can put the text encoders on a different device manually. It's a very ad-hoc solution, but it works. I think it would be great if there were a way to partition the load over a larger number of devices.
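Roughly what I mean by the ad-hoc split (the modules below are just stand-ins, not the actual sd-scripts objects): the text encoders live on cuda:1, the UNet stays on cuda:0, and the conditioning is copied to the UNet's device before the UNet forward.

```python
# Minimal self-contained sketch of the manual split; none of this is the real sd-scripts code.
import torch
import torch.nn as nn

te_device = torch.device("cuda:1")    # text encoders live here
unet_device = torch.device("cuda:0")  # UNet (and the rest of training) lives here

text_encoder = nn.Embedding(49408, 2048).to(te_device)  # stand-in for the SDXL text encoders
unet_to_k = nn.Linear(2048, 1280).to(unet_device)       # stand-in for a cross-attention to_k

token_ids = torch.randint(0, 49408, (1, 77), device=te_device)
with torch.no_grad():
    context = text_encoder(token_ids)                   # computed on cuda:1

# the only extra step: copy the conditioning to the UNet's device before its forward
k = unet_to_k(context.to(unet_device))
print(k.shape)  # torch.Size([1, 77, 1280])
```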
I tried using an FSDP config like this for accelerate (taken from https://github.com/kohya-ss/sd-scripts/issues/1480#issuecomment-2301283660) to finetune SDXL. The UI is bmaltais/kohya_ss. But it gives me the error shown above.
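For what it's worth, the config in that linked comment is an `accelerate config` YAML; as far as I understand accelerate's API, the rough programmatic equivalent (just my sketch of what I believe it maps to, not what sd-scripts actually does) is passing an FSDP plugin to the Accelerator:

```python
# Hedged sketch: roughly what I believe the FSDP part of the accelerate config maps to
# programmatically. sd-scripts builds its Accelerator from the accelerate config file,
# not like this.
from torch.distributed.fsdp import ShardingStrategy
from accelerate import Accelerator, FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin(
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # shard params, grads, optimizer state
    use_orig_params=True,
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

# the model, optimizer and dataloader would then go through accelerator.prepare(...)
```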