luosiallen / latent-consistency-model

Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference
MIT License

w_embedding = guidance_scale_embedding is missing in the train_lcm_distill_sdxl_wds.py #81

Open PetrByvsh opened 8 months ago

PetrByvsh commented 8 months ago

In train_lcm_distill_sd_wds.py, step 20.4.6 samples a random guidance scale w from U[w_min, w_max] and embeds it:

            w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
            w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
            w = w.reshape(bsz, 1, 1, 1)
            # Move to U-Net device and dtype
            w = w.to(device=latents.device, dtype=latents.dtype)
            w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)

The corresponding block in train_lcm_distill_sdxl_wds.py samples w but never computes w_embedding:

            w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
            w = w.reshape(bsz, 1, 1, 1)
            w = w.to(device=latents.device, dtype=latents.dtype)

Any reason for this? The XL script does not work without it: it calls

            noise_pred = unet(
                noisy_model_input,
                start_timesteps,
                timestep_cond=None,
                encoder_hidden_states=prompt_embeds.float(),
                added_cond_kwargs=encoded_text,
            ).sample

so timestep_cond is None, even though unet_time_cond_proj_dim is still required, as raised in the other issue.
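For reference, a minimal sketch of how the missing embedding could be mirrored from train_lcm_distill_sd_wds.py into the SDXL training loop. It assumes the same guidance_scale_embedding helper and the variable names from the snippets above, and is only how I would expect the block to look, not an official fix:

            # Sample w ~ U[w_min, w_max] and embed it, as in train_lcm_distill_sd_wds.py
            w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
            w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
            w = w.reshape(bsz, 1, 1, 1)
            # Move to U-Net device and dtype
            w = w.to(device=latents.device, dtype=latents.dtype)
            w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)

            # Pass the embedding as timestep_cond instead of None
            noise_pred = unet(
                noisy_model_input,
                start_timesteps,
                timestep_cond=w_embedding,
                encoder_hidden_states=prompt_embeds.float(),
                added_cond_kwargs=encoded_text,
            ).sample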

shuminghu commented 8 months ago

If you use the latest code from diffusers, the error is fixed by this one-line change:

diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
index ee86def6..a49e5f26 100644
--- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
@@ -948,7 +948,7 @@ def main(args):
     # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
     if teacher_unet.config.time_cond_proj_dim is None:
         teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
-    time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim
+    time_cond_proj_dim = teacher_unet.config["time_cond_proj_dim"]
     unet = UNet2DConditionModel(**teacher_unet.config)
     # load teacher_unet weights into unet
     unet.load_state_dict(teacher_unet.state_dict(), strict=False)
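As far as I can tell, this one-line change matters because the config object mirrors each key as an attribute when it is constructed, so assigning through config["time_cond_proj_dim"] updates the dict entry but not the mirrored attribute. A self-contained stand-in (not the actual diffusers FrozenDict) that shows the pattern:

    from collections import OrderedDict

    # Illustrative stand-in: mirrors every key as an attribute at construction time.
    class ConfigLike(OrderedDict):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            for key, value in self.items():
                setattr(self, key, value)

    config = ConfigLike(time_cond_proj_dim=None)
    config["time_cond_proj_dim"] = 256   # item assignment updates the dict entry ...
    print(config["time_cond_proj_dim"])  # 256
    print(config.time_cond_proj_dim)     # ... but the attribute from construction is still None

Reading the value back with config["time_cond_proj_dim"], as the patched line does, therefore picks up the newly assigned dimension.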
Neville0302 commented 3 months ago

How did you solve this problem?