Open brycegoh opened 1 week ago
@brycegoh, applying the weights would make the code more complex. Since this feature is only used for image editing tasks and is unnecessary for most other use cases, we left it out to keep the code simple. We plan to add it later.
@staoxiao I see. Then, were the released model weights trained using the weighted loss function or just MSE?
Yes, the model used the weighted loss for the editing task during training.
@staoxiao, in that case, would it be possible to release that version of the training script? Or at least release the code for the weighted loss function that your team has already implemented as a reference?
Sure, I will add this function within this week.
@staoxiao, thank you!
Hi @brycegoh, I updated loss.py and train.py, and you can refer to that code. To run it, you'll also need to modify data.py to select which data to weight; you can add a special string to your task_type for filtering.
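A minimal sketch of the task_type filtering idea, assuming a hypothetical marker string `_weighted` and a hypothetical helper `needs_weighted_loss`; neither name comes from data.py, and the real data format may differ:

```python
# Hypothetical marker appended to task_type for examples that should
# receive the weighted loss. The actual string is up to you.
WEIGHT_MARKER = "_weighted"


def needs_weighted_loss(example, marker=WEIGHT_MARKER):
    """Return True if this example's task_type carries the marker."""
    return marker in example.get("task_type", "")


# Toy examples; real dataset entries would carry images and instructions too.
examples = [
    {"task_type": "image_edit_weighted"},
    {"task_type": "text_to_image"},
    {},  # missing task_type is treated as unweighted
]

flags = [needs_weighted_loss(e) for e in examples]
print(flags)
```

The training loop could then compute per-element weights only for the flagged examples and fall back to the plain loss for everything else.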
@staoxiao With reference to your code in train.py: when choosing whether to reinforce, zero out, or ignore reinforcement entirely, was this an empirical decision, or did the team just come up with this approach and stick with it?
We tried implementing this while training an image editing diffusion model and found that the model would sometimes generate output almost identical to the input, which would fall under your zero-out scenario. I will give this a try and see if it improves things.
@rphly, these hyperparameters were chosen simply by observing a few examples; since the results met our expectations, I did not try other settings.
As mentioned in section 2.2 of the research paper, the loss function for image editing tasks was modified to amplify the difference between the input and target images.
However, it seems like the released code did not include that loss function when the task is image_edit, or am I misunderstanding the code?
Code snippet for reference: https://github.com/VectorSpaceLab/OmniGen/blob/2ef5c32fa9b96993cacea34cd8e24ea04c800c59/OmniGen/train_helper/loss.py#L51-L59
Loss function in question:
Thanks!
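For readers following along, here is a rough sketch of what a loss that up-weights the edited regions could look like. It is written in plain Python over flattened values so it runs without a framework. The function name `edit_weighted_mse`, the `amplify` factor, and the `tol` threshold are all hypothetical, and the weighting rule is a simplification for illustration, not the formula from the paper or the code in loss.py:

```python
def edit_weighted_mse(pred, target, source, amplify=2.0, tol=1e-6):
    """Mean squared error with a larger weight where source and target differ.

    pred, target, source: equal-length sequences of floats (flattened images
    or latents). Elements where |target - source| > tol are treated as
    "edited" and get weight `amplify`; all others get weight 1.0.
    """
    if not (len(pred) == len(target) == len(source)):
        raise ValueError("pred, target and source must have the same length")
    total = 0.0
    for p, t, s in zip(pred, target, source):
        weight = amplify if abs(t - s) > tol else 1.0
        total += weight * (p - t) ** 2
    return total / len(pred)


# Toy example: only index 1 changes between source and target.
pred = [0.0, 0.0, 0.0, 0.0]
target = [1.0, 1.0, 0.0, 0.0]
source = [1.0, 0.0, 0.0, 0.0]

weighted = edit_weighted_mse(pred, target, source)             # edited element counts double
plain = edit_weighted_mse(pred, target, source, amplify=1.0)   # reduces to ordinary MSE
print(weighted, plain)
```

With `amplify=1.0` this reduces to ordinary MSE, which makes it easy to compare the two settings on the same batch.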