rinongal / textual_inversion

MIT License
2.9k stars, 279 forks

What's the difference between sample_gs_* and sample_scaled_gs_* #136

Open · Y-ichen opened this issue 1 year ago

Y-ichen commented 1 year ago

When I was reproducing the results from the paper, I found that files named sample_gs_* and sample_scaled_gs_* are produced at the same time. The results in sample_scaled_gs_* look good:

[image: samples_scaled_gs-003000_e-000010_b-000000]

However, I don't know what the sample_gs_* files mean:

[image: samples_gs-003000_e-000010_b-000000]

So I would like to know the difference between them. Also, what makes sample_scaled_gs_* so much better than sample_gs_*? Does the input image set play a role in producing sample_scaled_gs_*? (For example, is the start_code for the diffusion model initialized from one of the images in the given set?)
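One detail that may help with reading these names: assuming this repo's image logger keeps the naming pattern of the original latent-diffusion `main.py`, i.e. `{key}_gs-{global_step:06}_e-{epoch:06}_b-{batch_idx:06}.png`, the `gs` in the filename is the global training step, not the guidance scale, and the part before `_gs-` (`samples` or `samples_scaled`) is which image was logged. The helper below (`parse_log_name` and `LoggedImage` are hypothetical names, not part of the repo) is a small sketch that splits such names under that assumption:

```python
import re
from typing import NamedTuple, Optional

# Assumed pattern: "{key}_gs-{global_step:06}_e-{epoch:06}_b-{batch_idx:06}.png"
_LOG_NAME = re.compile(
    r"^(?P<key>.+)_gs-(?P<step>\d+)_e-(?P<epoch>\d+)_b-(?P<batch>\d+)(?:\.png)?$"
)


class LoggedImage(NamedTuple):
    key: str          # which image was logged, e.g. "samples" or "samples_scaled"
    global_step: int  # training step at which the image was written
    epoch: int        # training epoch
    batch_idx: int    # batch index within the epoch


def parse_log_name(name: str) -> Optional[LoggedImage]:
    """Split a logged image filename into its parts; return None if it does not match."""
    m = _LOG_NAME.match(name)
    if m is None:
        return None
    return LoggedImage(m["key"], int(m["step"]), int(m["epoch"]), int(m["batch"]))
```

For example, `parse_log_name("samples_scaled_gs-003000_e-000010_b-000000")` yields key `samples_scaled`, step 3000, epoch 10, batch 0, so both files in the question come from the same training step.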

rinongal commented 1 year ago

Hi,

You can read more info about the outputs here: https://github.com/rinongal/textual_inversion/issues/19, https://github.com/rinongal/textual_inversion/issues/34

tl;dr: The difference is that samples_scaled_gs uses a classifier-free guidance scale of 5.0, while samples_gs uses a guidance scale of 1.0. You only care about samples_scaled_gs, since you're going to use high guidance scales at inference anyhow.
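To make the 1.0 vs 5.0 difference concrete, here is a sketch of the standard classifier-free guidance combination, eps = eps_uncond + scale * (eps_cond - eps_uncond), applied to the model's unconditional and conditional noise predictions. Plain lists stand in for tensors, and `guided_eps` is an illustrative name, not a function from this repo; in LDM-style code the unconditional prediction typically comes from conditioning on an empty prompt.

```python
from typing import List


def guided_eps(eps_uncond: List[float], eps_cond: List[float], scale: float) -> List[float]:
    """Classifier-free guidance: eps = eps_uncond + scale * (eps_cond - eps_uncond)."""
    if len(eps_uncond) != len(eps_cond):
        raise ValueError("predictions must have the same shape")
    # scale == 1.0 cancels the unconditional term and returns eps_cond;
    # scale > 1.0 extrapolates past the conditional prediction, away from the unconditional one.
    return [u + scale * (c - u) for u, c in zip(eps_uncond, eps_cond)]
```

With `scale=1.0` the result is just the conditional prediction, which is why the unscaled samples show how strongly the learned embedding alone pulls toward the concept; at 5.0 the prompt's influence is amplified at every denoising step.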

However, if your unscaled samples look too much like your concept, that's a good sign you're overfitting.

Also, if you are trying to reproduce the paper, keep in mind that it predates Stable Diffusion and uses the original LDM. You're not going to get similar results with SD.

Y-ichen commented 1 year ago

Thanks! That's a good explanation! Actually, I am using Stable Diffusion 1.5 and experimenting with it, and I think SD works well. If I want to add my own loss, where in the code should I add it? I am not familiar with pytorch_lightning and have difficulty finding where the loss is defined. The loss I want to add compares the generated image with a gt_image that I provide.

rinongal commented 1 year ago

You can add losses inside this function: https://github.com/rinongal/textual_inversion/blob/26ed44fb62c00d6a39d26212a0510466cccebd59/ldm/models/diffusion/ddpm.py#L1053

You'll probably have to understand how to pipe your data into that function, however.
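A minimal sketch of what such an extra term could look like, assuming the linked function computes the usual noise-prediction (epsilon) loss. This is a common technique, not the repo's method, and everything here is hypothetical: plain lists stand in for tensors, and `gt_latent`, `extra_weight`, and the helper names are illustrative. The idea is to turn the model's noise prediction back into an estimate of the clean sample, x0_hat = (x_t - sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_bar_t), and compare it with your target. Since LDM/SD works in latent space, your gt_image would first have to be encoded into the same latent space (or x0_hat decoded back to pixels) for the comparison to make sense.

```python
import math
from typing import Sequence, List


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared error between two equally sized flat vectors."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def predict_x0_from_eps(x_t: Sequence[float], eps_hat: Sequence[float], alpha_bar_t: float) -> List[float]:
    """Invert x_t = sqrt(a) * x0 + sqrt(1 - a) * eps to estimate x0 from predicted noise."""
    if not 0.0 < alpha_bar_t <= 1.0:
        raise ValueError("alpha_bar_t must be in (0, 1]")
    sa, s1a = math.sqrt(alpha_bar_t), math.sqrt(1.0 - alpha_bar_t)
    return [(xt - s1a * e) / sa for xt, e in zip(x_t, eps_hat)]


def total_loss(x_t, eps, eps_hat, alpha_bar_t, gt_latent, extra_weight=0.1) -> float:
    # Standard epsilon-prediction loss (what the training step already optimizes).
    diffusion_loss = mse(eps_hat, eps)
    # Hypothetical extra term: compare the model's clean-sample estimate to your own target.
    x0_hat = predict_x0_from_eps(x_t, eps_hat, alpha_bar_t)
    extra_loss = mse(x0_hat, gt_latent)
    return diffusion_loss + extra_weight * extra_loss
```

In the real training code you would do the same with torch tensors; the base `DDPM` class in `ddpm.py` appears to already provide a helper for the x0 estimate (`predict_start_from_noise`), but check your checkout before relying on it. The `extra_weight` balance is a tuning choice: too large and the embedding will tend to reproduce the target rather than generalize.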