tianweiy / DMD2

(NeurIPS 2024 Oral 🔥) Improved Distribution Matching Distillation for Fast Image Synthesis
Other
446 stars 25 forks

Is this expected when training SD v1.5 in stage 2: gen_cls_loss (G loss) keeps increasing while guidance_cls_loss (D loss) decreases? #39

Open fire2323 opened 1 month ago

fire2323 commented 1 month ago

1. When using the experiment script for SD v1.5 stage 2 (laion6.25_sd_baseline_8node_guidance1.75_lr5e-7_seed10_dfake10_diffusion1000_gan1e-3_noode_resume_fixdata.sh), gen_cls_loss (G loss) keeps increasing while guidance_cls_loss (D loss) decreases. Is this also the case in your training? (PS: in stage 1, I reproduced your FID with your code.)

2. I expected the G loss to decrease along with the D loss. Even so, the FID keeps decreasing until it gets close to (slightly above) the level reported in your paper. I use a smaller batch size than yours. Could this affect the G loss?

3. I also noticed that the weighted G loss is an order of magnitude smaller than the DMD loss, e.g. a G loss of 0.001 vs. a DMD loss of 0.01. Is this also the case in your training, and is it expected?

Thanks!

tianweiy commented 1 month ago

The training curves look like the following:

[Screenshot (2024-08-06): training loss curves]

and the FID curve looks like the following:

[Screenshot (2024-08-06): FID curve]

My impression is that COCO FID in the low-guidance regime is not very meaningful, so I didn't spend much time analyzing it...

> I expected the G loss to decrease along with the D loss. Even so, the FID keeps decreasing until it gets close to (slightly above) the level reported in your paper. I use a smaller batch size than yours. Could this affect the G loss?

If the D loss decreases, the classifier is separating generated samples from real ones well, so the G loss is expected to get larger, right? To make the G loss smaller, I remember trying a larger gen_cls_loss_weight, but this led to training instability.
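To illustrate the reasoning above: with the common non-saturating (softplus) GAN losses, the D loss and the G loss depend on the same "fake" logit in opposite directions, so a discriminator that gets better at spotting generated samples lowers the D loss and raises the G loss at the same time. This is a minimal sketch assuming that loss form (check main/sd_guidance.py for the exact losses DMD2 uses); the function names and logit values are made up for illustration.

```python
import math

def softplus(z):
    # Numerically stable log(1 + exp(z))
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def d_loss(real_logit, fake_logit):
    # Discriminator (hypothetical name): wants real_logit high, fake_logit low
    return softplus(-real_logit) + softplus(fake_logit)

def g_loss(fake_logit):
    # Non-saturating generator loss (hypothetical name): wants fake_logit high
    return softplus(-fake_logit)

# As the discriminator improves, the logit it assigns to generated samples drops.
for fake_logit in [1.0, 0.0, -1.0, -2.0]:
    print(f"fake_logit={fake_logit:+.1f}  "
          f"d_loss={d_loss(2.0, fake_logit):.3f}  g_loss={g_loss(fake_logit):.3f}")
```

Both losses are driven by the same logit, so the opposite trends in the two curves are what one would expect while the discriminator is still winning.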

> I also noticed that the weighted G loss is an order of magnitude smaller than the DMD loss, e.g. a G loss of 0.001 vs. a DMD loss of 0.01. Is this also the case in your training, and is it expected?

The values are not directly comparable. Basically, we have a mean reduction here https://github.com/tianweiy/DMD2/blob/0f8a481716539af7b2795740c9763a7d0d05b83b/main/sd_guidance.py#L241, so the gradient norm corresponding to the DMD loss is actually scaled down (by the number of elements in the tensor).
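A pure-Python sketch of this scaling effect, assuming the DMD term is written in the common stop-gradient surrogate form 0.5 * MSE(x, (x - g).detach()) with mean reduction, where g is the per-element DMD gradient. This form is an assumption; see the linked line for the actual code. The logged loss value is 0.5 * mean(g^2), but the gradient that actually reaches x is g / N for a tensor with N elements, so the same logged value corresponds to a smaller and smaller gradient as the tensor grows. Lists stand in for latent tensors, and the gradient is checked by finite differences.

```python
def mean_mse(x, target):
    # 0.5 * mean((x - target)^2): an MSE with reduction="mean"
    n = len(x)
    return 0.5 * sum((xi - ti) ** 2 for xi, ti in zip(x, target)) / n

def surrogate_grad(x, g, eps=1e-6):
    # The target (x - g) is frozen, mimicking .detach() in autograd code,
    # so only the first argument of mean_mse is perturbed.
    target = [xi - gi for xi, gi in zip(x, g)]
    grad = []
    for i in range(len(x)):
        xp = list(x)
        xp[i] += eps
        xm = list(x)
        xm[i] -= eps
        grad.append((mean_mse(xp, target) - mean_mse(xm, target)) / (2 * eps))
    return grad

for n in (4, 400):
    x = [0.0] * n
    g = [1.0] * n  # same per-element DMD gradient in both cases
    target = [xi - gi for xi, gi in zip(x, g)]
    loss = mean_mse(x, target)
    grad = surrogate_grad(x, g)
    # loss stays 0.5, while each gradient entry is g / n
    print(n, loss, grad[0])
```

With a sum reduction the per-element gradient would be g itself, which is why loss magnitudes logged under different reductions are not directly comparable.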

fire2323 commented 1 month ago

Thanks for sharing! Another thing I am curious about is CFG. Did you try training the generator (and/or the fake diffusion model) with CFG, as you did for the real diffusion model (i.e., a guidance scale of 8 for SDXL)? Specifically, I am wondering about three things:

1. Does CFG training further improve performance?
2. Can a generator trained without CFG be used with CFG at inference?
3. If the generator (and/or the fake diffusion model) is trained with a specific guidance scale (e.g. 8), can the final generator be used at inference with other guidance scales (e.g. 6 or 9)? Would it be necessary to train separate models for different guidance scales, and why?

tianweiy commented 1 month ago
> 1. Does CFG training further improve performance?

I tried a few ways to train the fake diffusion model with CFG too, but I didn't manage to get better results than without CFG.

> 2. Can a generator trained without CFG be used with CFG at inference?

Probably not, because the prediction target of the distilled model no longer corresponds to the score (it is a sharp image). I might be wrong, though; I haven't followed recent developments closely.
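For context on this answer: classifier-free guidance is usually defined as a linear extrapolation between conditional and unconditional noise predictions, which for a diffusion model are (up to a scale factor) score estimates. The element-wise sketch below shows that standard form, with plain lists standing in for tensors and made-up values; 1.75 matches the guidance value in the stage-2 script name. A distilled generator is trained so that its prediction yields a sharp image rather than tracking the score, so applying the same extrapolation to its outputs would lose that interpretation. This is the concern expressed above, not a tested result.

```python
def cfg_combine(eps_uncond, eps_cond, w):
    # Standard classifier-free guidance on noise (score-like) predictions:
    # eps_uncond + w * (eps_cond - eps_uncond), applied element-wise.
    return [u + w * (c - u) for u, c in zip(eps_uncond, eps_cond)]

eps_uncond = [0.10, -0.20, 0.05]
eps_cond = [0.30, -0.10, 0.00]
print(cfg_combine(eps_uncond, eps_cond, 1.75))  # w > 1 extrapolates past eps_cond
```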

> 3. Would it be necessary to train separate models for different guidance scales, and why?

I think it might be possible to do guidance conditioning (i.e., the generator gets an extra input indicating the guidance scale to use). I briefly explored this in the past but didn't get it to work (a generator trained with varying guidance is quite a bit worse than one trained with a single guidance). It is probably challenging, considering that other few-step generators are mostly fixed to a single guidance scale. But it might be worth exploring.
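One way the "extra input" for guidance conditioning could look is a hypothetical sketch, not code from DMD2: encode the guidance scale w with a sinusoidal embedding, the way diffusion UNets commonly embed timesteps, so that it could be added to the generator's timestep embedding. The function name, dim, and the scale factor of 1000 are all assumptions made for illustration.

```python
import math

def guidance_embedding(w, dim=8, scale=1000.0, max_period=10000.0):
    # Hypothetical sinusoidal embedding of the guidance scale w.
    # scale spreads typical w values (e.g. 1.75 or 8) over a wider phase range.
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [w * scale * f for f in freqs]
    return [math.cos(a) for a in args] + [math.sin(a) for a in args]

# In training, w would be sampled per batch instead of fixed;
# at inference, the user would choose it.
emb_lo = guidance_embedding(1.75)
emb_hi = guidance_embedding(8.0)
print(len(emb_lo), emb_lo != emb_hi)
```

The embedding itself is the easy part. The difficulty described above is that one generator then has to match a different target distribution for every w.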

fire2323 commented 1 month ago

Got it, thanks. How about the performance when training the generator with a single guidance? Does it work at inference?

tianweiy commented 1 month ago

> training the generator with a single guidance? Does it work at inference?

What does this mean? Aren't we already training with a single guidance now?

fire2323 commented 1 month ago

Oh, sorry for the misunderstanding. From your description:

> the generator trained with varying guidance is quite a bit worse than trained with a single guidance

I assumed that you had trained the generator with a single guidance, which led to my question above: if that is the case, how does its performance compare to a generator trained without guidance? (To clarify, by "the generator" I mean self.feedforward_model in class SDUniModel (line 39 of the code). That is also what you mean by the generator in your reply, right?)

Thank you again for your great work! :-)

tianweiy commented 1 month ago

> If that is the case, how about the performance compared to generator trained without a guidance

Could you elaborate on this? If we don't apply guidance to the real UNet (i.e., train the generator without guidance), the images are really, really bad.

fire2323 commented 1 month ago

There are three UNets in the paper: NET 1, the generator (feedforward_model), which is the output network used at inference; NET 2, the real UNet; and NET 3, the fake UNet.

When I say "the generator" in my post above, I mean NET 1 (not NET 2), and I mean applying guidance to NET 1 (also not NET 2):

> If that is the case, how about the performance compared to generator trained without a guidance

In the code, guidance is applied to NET 2 (the real UNet) but not to NET 1. I am asking how these two options would compare for NET 1: (a) training with guidance applied to NET 1, and (b) training without guidance applied to NET 1, as the code currently does. Below is the relevant code, where no guidance is applied to NET 1: "generated_noise" is used directly to compute x0, without applying guidance to it.

generated_noise = self.feedforward_model( noisy_image, timesteps.long(), text_embedding, added_cond_kwargs=unet_added_conditions ).sample

generated_image = get_x0_from_noise( noisy_image, generated_noise.double(), self.alphas_cumprod.double(), current_timesteps ).float()
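For reference, a helper like get_x0_from_noise typically inverts the forward diffusion process under the standard epsilon parametrization, x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps. The sketch below is an assumption about what it computes: the repo version works on tensors and, as the snippet shows, in double precision. Here plain lists and a made-up alphas_cumprod schedule stand in for the real ones.

```python
import math

def get_x0_from_noise(noisy, eps, alphas_cumprod, t):
    # Solve x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps for x0,
    # element-wise; lists stand in for latent tensors.
    abar = alphas_cumprod[t]
    return [(xt - math.sqrt(1.0 - abar) * e) / math.sqrt(abar)
            for xt, e in zip(noisy, eps)]

# Round trip with a toy schedule: noise a known x0, then recover it.
alphas_cumprod = [0.99, 0.9, 0.5]
x0 = [0.3, -0.7]
eps = [1.0, -0.5]
t = 2
abar = alphas_cumprod[t]
noisy = [math.sqrt(abar) * a + math.sqrt(1.0 - abar) * e for a, e in zip(x0, eps)]
print(get_x0_from_noise(noisy, eps, alphas_cumprod, t))  # approximately x0
```

Option (a) in the question would amount to replacing generated_noise with a guided combination of conditional and unconditional generator outputs before this conversion.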

tianweiy commented 1 month ago

I see. I never tried this setting, and theoretically I don't know what applying guidance to NET 1 would mean mathematically for these one-/few-step samplers.