aravindhv10 opened this issue 2 months ago (status: Open)
Hi, there is a broadcasting issue when the batch size is greater than 1, caused by this line:
https://github.com/XLabs-AI/x-flux/blob/main/train_flux_lora_deepspeed.py#L254
For `t` to broadcast correctly when it is multiplied with `x_0` and `x_t`, its shape should be `(batch_size, 1, 1)`.
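A minimal reproduction, assuming the affected line interpolates between latents and noise as `(1 - t) * x_1 + t * x_0`, with latents of shape `(batch_size, seq_len, channels)` and one timestep per sample in `t` of shape `(batch_size,)`. The sizes and variable names are illustrative, not copied from the script:

```python
import torch

batch_size, seq_len, channels = 4, 16, 64        # illustrative sizes
x_1 = torch.randn(batch_size, seq_len, channels)  # stand-in for the clean latents
x_0 = torch.randn_like(x_1)                       # stand-in for the noise
t = torch.sigmoid(torch.randn(batch_size))        # one timestep per sample, shape (4,)

# Broadcasting aligns trailing dimensions, so t's only axis (size 4) is
# matched against the channel axis (size 64) instead of the batch axis.
try:
    x_t = (1 - t) * x_1 + t * x_0
except RuntimeError as err:
    print("without reshape:", err)

# Reshaping t to (batch_size, 1, 1) lines it up with the batch axis,
# so each sample is blended with its own timestep.
t_b = t.view(batch_size, 1, 1)
x_t = (1 - t_b) * x_1 + t_b * x_0
print(x_t.shape)  # torch.Size([4, 16, 64])
```

With a batch size of 1, `t` has shape `(1,)` and broadcasts without complaint, which is why the problem only shows up for larger batches. If the batch size happened to equal the channel count, there would be no error at all, just a silently wrong blend along the channel axis.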
You can fix it this way:
https://github.com/aravindhv10/x-flux/blob/aravind_prodigy_dataset/train_flux_lora_deepspeed.py#L253-L256
There may be a better way, but this change worked for me.
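One possibly more general alternative (a sketch, not the change in the linked branch) is a small helper, hypothetically named `expand_like`, that appends as many singleton axes to `t` as the target tensor has non-batch dimensions. For 3-D latents it is equivalent to `t.view(-1, 1, 1)` or `t[:, None, None]`; it just avoids hard-coding the rank:

```python
import torch


def expand_like(t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Reshape a per-sample tensor of shape (batch_size,) so that it
    broadcasts against x of shape (batch_size, ...), whatever x's rank."""
    # Keep the batch axis and add one size-1 axis per remaining dimension of x.
    return t.view(-1, *([1] * (x.dim() - 1)))


x_1 = torch.randn(2, 8, 3)
x_0 = torch.randn_like(x_1)
t = torch.tensor([0.0, 1.0])

t_b = expand_like(t, x_1)  # shape (2, 1, 1)
x_t = (1 - t_b) * x_1 + t_b * x_0
# Sample 0 has t = 0, so it equals x_1[0]; sample 1 has t = 1, so it equals x_0[1].
```

Using `view` rather than `reshape` assumes `t` is contiguous, which holds for a freshly sampled 1-D tensor.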
These links explain broadcasting in detail:
https://stackoverflow.com/questions/65121614/pytorch-how-to-multiply-via-broadcasting-of-two-tensors-with-different-shapes
https://pytorch.org/docs/stable/notes/broadcasting.html
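The rule those pages describe can be sketched in plain Python: shapes are compared from the trailing dimension backwards, a missing dimension counts as size 1, and two sizes are compatible only when they are equal or one of them is 1. The function name `broadcast_shape` is illustrative; PyTorch provides `torch.broadcast_shapes` for real use:

```python
from itertools import zip_longest


def broadcast_shape(a: tuple, b: tuple) -> tuple:
    """Return the broadcast shape of two tensor shapes, or raise ValueError."""
    result = []
    # Walk both shapes right to left; a dimension missing from the shorter
    # shape is treated as size 1.
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            raise ValueError(f"incompatible sizes {da} and {db} in {a} vs {b}")
    return tuple(reversed(result))


print(broadcast_shape((4, 1, 1), (4, 16, 64)))  # (4, 16, 64)
print(broadcast_shape((1,), (4, 16, 64)))       # (4, 16, 64): batch size 1 happens to work

try:
    broadcast_shape((4,), (4, 16, 64))  # t of shape (batch_size,) vs the latents
except ValueError as err:
    print(err)
```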
I have raised a PR:
https://github.com/XLabs-AI/x-flux/pull/104