kohya-ss / sd-scripts

Apache License 2.0

SD3.5 finetuning crashes due to dtype difference in vae call when not caching latents (potentially DDP related?) #1758

Open yoinked-h opened 3 weeks ago

yoinked-h commented 3 weeks ago
[rank1]: Traceback (most recent call last):
[rank1]:   File "/dockercontainer/sd-scripts/sd3_train.py", line 1200, in <module>
[rank1]:     train(args)
[rank1]:   File "/dockercontainer/sd-scripts/sd3_train.py", line 871, in train
[rank1]:     latents = vae.encode(batch["images"])
[rank1]:   File "/dockercontainer/sd-scripts/library/sd3_models.py", line 1435, in encode
[rank1]:     hidden = self.encoder(image)
[rank1]:   File "/dockercontainer/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/dockercontainer/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/dockercontainer/sd-scripts/library/sd3_models.py", line 1333, in forward
[rank1]:     hs = [self.conv_in(x)]
[rank1]:   File "/dockercontainer/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/dockercontainer/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/dockercontainer/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 554, in forward
[rank1]:     return self._conv_forward(input, self.weight, self.bias)
[rank1]:   File "/dockercontainer/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 549, in _conv_forward
[rank1]:     return F.conv2d(
 RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same
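
For readers hitting the same error: the traceback shows a float32 image batch reaching a convolution whose parameters are bfloat16. The snippet below is a minimal sketch that reproduces that kind of mismatch on a stand-in layer and shows one common workaround, casting the input to the module's parameter dtype before the call. It is not the fix applied in the repository. The helper `encode_with_matching_dtype` and the layer sizes are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn


def encode_with_matching_dtype(module: nn.Module, images: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: cast `images` to the module's parameter
    device and dtype before running the forward pass."""
    param = next(module.parameters())
    return module(images.to(device=param.device, dtype=param.dtype))


# Stand-in for the first VAE encoder layer (sizes are assumptions).
conv_in = nn.Conv2d(3, 8, kernel_size=3, padding=1).to(torch.bfloat16)

# float32 batch, as a data loader typically yields when latents are not cached.
images = torch.randn(2, 3, 16, 16)

try:
    conv_in(images)  # float32 input against bfloat16 weight and bias
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print(f"mismatch: {err}")

# Casting the input first lets the same call go through.
out = encode_with_matching_dtype(conv_in, images)
print(out.dtype, tuple(out.shape))
```

In a training script the equivalent change would presumably be a cast of `batch["images"]` to the VAE's dtype just before `vae.encode(...)`, though whether that matches the actual fix is an assumption.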

yoinked-h commented 3 weeks ago

I think this also happens with T5 dropout.

kohya-ss commented 3 weeks ago

Fixed the VAE issue. I can't reproduce the issue with T5XXL, so please let me know how to reproduce it.

yoinked-h commented 2 weeks ago

@kohya-ss The T5XXL dtype mismatch happens with batch size > 1 and T5XXL dropout enabled. DDP may be involved too, though that's a stretch. I'll try to get error logs later today.