XLabs-AI / x-flux

Apache License 2.0

Broadcasting issue when using training batch size > 1 #102

Open aravindhv10 opened 2 months ago

aravindhv10 commented 2 months ago

Hi, there is a broadcasting issue when the training batch size is greater than 1, caused by this line: https://github.com/XLabs-AI/x-flux/blob/main/train_flux_lora_deepspeed.py#L254. The shape of t should be (batch_size, 1, 1) so that it can be broadcast and multiplied with x_0 and x_t.

You can try to fix it this way:

https://github.com/aravindhv10/x-flux/blob/aravind_prodigy_dataset/train_flux_lora_deepspeed.py#L253-L256

There may be a better way, but the above method worked for me.
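A minimal PyTorch sketch of the reshape, assuming the script draws one timestep per sample and interpolates between clean latents x_1 and noise x_0 as x_t = (1 - t) * x_1 + t * x_0. The tensor names follow the issue, but the interpolation formula and the shapes below are assumptions for illustration, not copied from train_flux_lora_deepspeed.py.

```python
import torch

# Illustrative shapes only (assumption): packed latents of shape
# (batch_size, seq_len, channels).
batch_size, seq_len, channels = 4, 16, 64

x_1 = torch.randn(batch_size, seq_len, channels)  # clean latents
x_0 = torch.randn_like(x_1)                       # noise, same shape
t = torch.sigmoid(torch.randn(batch_size))        # one timestep per sample, shape (4,)

# Broken: a 1-D t is aligned with the LAST dimension (64), not the batch dimension.
try:
    _ = (1 - t) * x_1 + t * x_0
except RuntimeError as err:
    print("broadcast failed:", err)

# Fixed: add two trailing singleton dims so t broadcasts over seq_len and channels.
t_b = t.view(-1, 1, 1)                            # shape (batch_size, 1, 1)
x_t = (1 - t_b) * x_1 + t_b * x_0
print(x_t.shape)  # torch.Size([4, 16, 64])
```

With batch_size = 1, a t of shape (1,) broadcasts against any trailing dimension, which is why the problem only shows up for larger batches. t[:, None, None] and t.reshape(-1, 1, 1) are equivalent ways to add the singleton dimensions.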

aravindhv10 commented 2 months ago

These links explain the details of broadcasting:

https://stackoverflow.com/questions/65121614/pytorch-how-to-multiply-via-broadcasting-of-two-tensors-with-different-shapes

https://pytorch.org/docs/stable/notes/broadcasting.html
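The rule described in the PyTorch notes can be checked by hand. Shapes are compared starting from the trailing dimension, missing leading dimensions count as 1, and each pair of sizes must either be equal or include a 1. Below is a small pure-Python sketch of that rule; the helper name broadcast_shape is hypothetical, and the shapes are illustrative.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the broadcast shape of shapes a and b, or raise ValueError.

    Hypothetical helper mirroring the NumPy/PyTorch rule: align from the
    trailing dimension, treat missing leading dims as 1, and require each
    pair of sizes to be equal or contain a 1.
    """
    result = []
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da != db and da != 1 and db != 1:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        # A size-1 dim stretches to match the other one.
        result.append(db if da == 1 else da)
    return tuple(reversed(result))


print(broadcast_shape((1,), (1, 16, 64)))       # (1, 16, 64): batch size 1 works
print(broadcast_shape((4, 1, 1), (4, 16, 64)))  # (4, 16, 64): reshaped t works
try:
    broadcast_shape((4,), (4, 16, 64))          # 4 is compared against 64
except ValueError as err:
    print(err)
```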

aravindhv10 commented 2 months ago

Raised a PR:

https://github.com/XLabs-AI/x-flux/pull/104