tianweiy / DMD2

(NeurIPS 2024 Oral 🔥) Improved Distribution Matching Distillation for Fast Image Synthesis

multi-step model for sdv1.5 #36

Open kmpartner opened 4 months ago

kmpartner commented 4 months ago

Is it possible to train a 4-step (multi-step) model for sd-v1.5 in this repo? I see this in the sdxl experiment docs, but not in the sdv1.5 experiment docs.

tianweiy commented 4 months ago

It is possible, but it requires some coding.

kmpartner commented 4 months ago

Thank you for the response. Which part of the code needs to change, and what kind of modification is required, to enable multi-step training for sdv1.5? Is it the denoising part?

tianweiy commented 4 months ago

I think you would need to modify this function: https://github.com/tianweiy/DMD2/blob/0f8a481716539af7b2795740c9763a7d0d05b83b/main/sd_unified_model.py#L166

First remove the assert in that function, then adapt the text encoder and the backward simulation code (if you need the latter). See this line for how to modify the text encoder: https://github.com/tianweiy/DMD2/blob/0f8a481716539af7b2795740c9763a7d0d05b83b/main/sd_unified_model.py#L222
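
A minimal sketch of what adapting the text encoder could look like, not the repo's actual code. It assumes the SDXL wrapper takes the whole batch dict and returns an `(embedding, pooled_embedding)` tuple, while the SD v1.5 encoder is a CLIP-style model that takes token ids and returns a mapping with `last_hidden_state` and `pooler_output`. The helper `encode_text`, the stub encoders, and the `text_input_ids_one` key are illustrative assumptions; tensor shape handling is left out.

```python
from typing import Any, Mapping, Tuple


def encode_text(text_encoder: Any, batch: Mapping[str, Any], is_sdxl: bool) -> Tuple[Any, Any]:
    """Return (text_embedding, pooled_text_embedding) for either model family."""
    if is_sdxl:
        # Assumed SDXL wrapper: whole batch dict in, 2-tuple out.
        return text_encoder(batch)
    # Assumed SD v1.5 encoder: token ids in, mapping of named outputs out.
    output = text_encoder(batch["text_input_ids_one"])
    return output["last_hidden_state"], output["pooler_output"]


# Stub encoders standing in for the real models (hypothetical, for illustration only).
def fake_sdxl_encoder(batch):
    return "sdxl_embedding", "sdxl_pooled"


def fake_sd15_encoder(token_ids):
    return {"last_hidden_state": "sd15_embedding", "pooler_output": "sd15_pooled"}


batch = {"text_input_ids_one": [101, 102]}
sdxl_result = encode_text(fake_sdxl_encoder, batch, is_sdxl=True)
sd15_result = encode_text(fake_sd15_encoder, batch, is_sdxl=False)
```

Routing every text-encoder call through one helper like this means the generated-data batch and the real-data batch take the same branch, so neither is left on the SDXL-only path.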

tianweiy commented 4 months ago

I can help review your changes if you are interested in opening a pull request.

kmpartner commented 4 months ago

I updated how text_embedding and pooled_text_embedding are computed from text_encoder in the prepare_denoising_data function. Is this the right way to enable multi-step training for sdv1.5?

Is denoising_timestep 250 in the sdxl case? Is there any reason to use 250 other than to limit sampling to 4 steps?

def prepare_denoising_data(self, denoising_dict, real_train_dict, noise):
    # assert self.sdxl, "Denoising is only supported for SDXL"

    indices = torch.randint(
        0, self.num_denoising_step, (noise.shape[0],), device=noise.device, dtype=torch.long
    )
    timesteps = self.denoising_step_list.to(noise.device)[indices]

    # text_embedding, pooled_text_embedding = self.text_encoder(denoising_dict)

    if self.sdxl:
        text_embedding, pooled_text_embedding = self.text_encoder(denoising_dict)
    else:
        text_embedding_dict = self.text_encoder(denoising_dict["text_input_ids_one"].squeeze(1))
        text_embedding = text_embedding_dict["last_hidden_state"]
        pooled_text_embedding = text_embedding_dict["pooler_output"]

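    if real_train_dict is not None:
        # NOTE: this branch still uses the SDXL-style call (batch dict in, tuple out)
        # and self.add_time_ids; for sdv1.5 it presumably needs the same branching as above.
        real_text_embedding, real_pooled_text_embedding = self.text_encoder(real_train_dict)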

        real_train_dict['text_embedding'] = real_text_embedding

        real_unet_added_conditions = {
            "time_ids": self.add_time_ids.repeat(len(real_text_embedding), 1),
            "text_embeds": real_pooled_text_embedding
        }
        real_train_dict['unet_added_conditions'] = real_unet_added_conditions

    if self.backward_simulation:
        # we overwrite the denoising timesteps 
        # note: we also use uncorrelated noise 
        clean_images, timesteps = self.sample_backward(torch.randn_like(noise), text_embedding, pooled_text_embedding) 
    else:
        clean_images = denoising_dict['images'].to(noise.device)

    noisy_image = self.noise_scheduler.add_noise(
        clean_images, noise, timesteps
    )

    # set last timestep to pure noise
    pure_noise_mask = (timesteps == (self.num_train_timesteps-1))
    noisy_image[pure_noise_mask] = noise[pure_noise_mask]

    return timesteps, text_embedding, pooled_text_embedding, real_train_dict, noisy_image
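
On the 250 question, the arithmetic can be sketched as follows. This assumes 1000 training timesteps and a step list built by striding over them so that the last entry is the pure-noise step (`num_train_timesteps - 1`, which the mask above checks for). `build_denoising_step_list` is a hypothetical helper for illustration, not necessarily how the repo constructs `denoising_step_list`.

```python
from typing import List


def build_denoising_step_list(num_train_timesteps: int, denoising_timestep: int) -> List[int]:
    """Timesteps spaced `denoising_timestep` apart, ending at the final (pure-noise) step."""
    return list(range(denoising_timestep - 1, num_train_timesteps, denoising_timestep))


steps = build_denoising_step_list(num_train_timesteps=1000, denoising_timestep=250)
print(steps)       # [249, 499, 749, 999]
print(len(steps))  # 4

two_step = build_denoising_step_list(num_train_timesteps=1000, denoising_timestep=500)
print(two_step)    # [499, 999]
```

Under these assumptions, 250 is just 1000 / 4: the stride sets the number of denoising steps, so a different stride such as 500 would give a 2-step model.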