ExplainableML / ReNO

[NeurIPS 2024] ReNO: Enhancing One-step Text-to-Image Models through Reward-based Noise Optimization
MIT License

About Loss Function #8

Open Oguzhanercan opened 2 months ago

Oguzhanercan commented 2 months ago

Hi, I have a question about the choice of loss function. As far as I can see, you use the dot product to measure feature similarity. Did you experiment with L1, L2 (MSE), other losses, or a combination of them? I am asking because I do not have enough GPU memory to use more than one reward function in this pipeline. If the features are nearly orthogonal (which is expected in a high-dimensional space), then in my experiments it is hard to optimize the dot product or cosine similarity. I tested this with rectifid, a work similar to yours: when I used more than one face recognition network, it failed to generate images with the same identity. If feature-space orthogonality is not the problem, is there a key to solving this? I thought the regularization in your work might address it, but if I understand correctly, you did not use it for this purpose. @sgk98

sgk98 commented 1 month ago

Thanks for the suggestion! All the models we used (e.g. CLIP, PickScore, HPSv2.1) are mostly based on the CLIP architecture, where it is meaningful to compute scores/similarities by taking the dot product of the prompt text features and the generated image features. While something else (e.g. L1/L2) might also work, these models weren't really trained to work with any other kind of similarity computation; they are trained with a contrastive objective built on the dot product.
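One point that may help when comparing the dot product with L2: for L2-normalized features, the squared L2 distance is an affine function of the dot product, so the two rank pairs identically. The sketch below checks this numerically with pure Python. The feature values are hypothetical stand-ins, not outputs of CLIP or any other reward model, and the helper names are made up for illustration.

```python
import math

def normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

def squared_l2(a, b):
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

# Toy stand-ins for text and image embeddings (hypothetical values).
text_feat = normalize([0.2, -1.0, 0.5, 0.3])
image_feat = normalize([0.1, -0.8, 0.9, -0.2])

cosine = dot(text_feat, image_feat)
dist_sq = squared_l2(text_feat, image_feat)

# For unit vectors: ||a - b||^2 = 2 - 2 * (a . b)
identity_gap = abs(dist_sq - (2.0 - 2.0 * cosine))
```

Under this assumption of unit-norm features, switching from the dot product to squared L2 only rescales the gradient. A real difference would come from L1 or from unnormalized features.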

I also looked into rectifid based on your suggestion. I see that they also use a cosine loss based on the dot product, but they seem to add an L1 loss, which they say helps in some cases. It might simply be that L1 is a good choice for DINOv2/ArcFace features.
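A minimal sketch of combining a cosine loss with an L1 term follows. It is not rectifid's actual code: the function names, the weight `l1_weight`, and the feature values are all hypothetical. It only illustrates that the L1 term penalizes a magnitude mismatch that the cosine term ignores.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def mean_abs_diff(a, b):
    """Mean absolute difference (L1 distance averaged over dimensions)."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)

def combined_loss(generated, reference, l1_weight=0.1):
    """Cosine loss plus a weighted L1 term; l1_weight is a hypothetical knob."""
    cosine_loss = 1.0 - cosine_similarity(generated, reference)
    return cosine_loss + l1_weight * mean_abs_diff(generated, reference)

reference = [0.5, -0.2, 0.8]

# Identical features: both terms vanish (up to floating-point error).
loss_same = combined_loss(reference, reference)

# Same direction, doubled magnitude: cosine term ~0, only L1 contributes.
loss_scaled = combined_loss([2 * x for x in reference], reference)

# Different direction: the cosine term dominates.
loss_other = combined_loss([0.8, 0.5, -0.2], reference)
```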

About the regularization: we added it to ensure that the noise does not deviate too much from its original distribution. Otherwise, it is very easy to increase the score by simply making the image noisy. I am not sure how helpful or necessary this is in your case.
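To illustrate the idea of keeping the noise close to its original distribution, here is a sketch of one possible norm-based regularizer. It assumes the initial noise is a standard Gaussian, whose norm follows a chi distribution with d degrees of freedom. The exact regularizer is not spelled out in this thread, so the form below, the function names, and `reg_weight` are assumptions made for illustration.

```python
import math
import random

def noise_log_prior(noise):
    """Log-density, up to a constant, of the norm of a d-dim standard Gaussian.

    The norm follows a chi distribution with d degrees of freedom, whose
    log-density is (d - 1) * log(r) - r^2 / 2 + const.
    """
    d = len(noise)
    norm = math.sqrt(sum(x * x for x in noise))
    return (d - 1) * math.log(norm) - 0.5 * norm * norm

def regularized_objective(reward, noise, reg_weight=0.01):
    """Reward to maximize plus a weighted prior term; reg_weight is hypothetical."""
    return reward + reg_weight * noise_log_prior(noise)

random.seed(0)
d = 1024
typical = [random.gauss(0.0, 1.0) for _ in range(d)]  # drawn from the prior
inflated = [3.0 * x for x in typical]                 # drifted far from the prior

prior_typical = noise_log_prior(typical)
prior_inflated = noise_log_prior(inflated)
```

With equal reward, the objective prefers `typical` over `inflated`, which discourages raising the score by pushing the noise away from the distribution the generator was trained on.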

Once again, please share your codebase for this experiment if possible; I can try a few things and see whether anything is particularly helpful for you.