kmpartner opened this issue 4 months ago
It is possible, but it requires some coding.
Thank you for the response. Which part of the code needs to be modified, and how, to enable multi-step training for SD v1.5? Is it the denoising part?
I think you would need to modify this function: https://github.com/tianweiy/DMD2/blob/0f8a481716539af7b2795740c9763a7d0d05b83b/main/sd_unified_model.py#L166
First, remove this assert, then adapt the text encoder and the backward simulation code (if you need the latter). See this line for how to modify the text encoder: https://github.com/tianweiy/DMD2/blob/0f8a481716539af7b2795740c9763a7d0d05b83b/main/sd_unified_model.py#L222
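For context on the backward simulation mentioned above: as we understand it from the DMD2 paper, a multi-step generator is trained on inputs produced by running the generator itself from pure noise for a few steps (as at inference), instead of on noised real images. Below is a toy, pure-Python sketch of that idea under loudly stated assumptions: `backward_simulate` is a hypothetical name, vectors are plain lists, text conditioning is omitted, and `alphas_cumprod` is supplied by the caller. It is not the repo's `sample_backward`.

```python
import math
import random
from typing import Callable, List, Tuple

Vector = List[float]

def backward_simulate(
    generator: Callable[[Vector, int], Vector],
    step_list: List[int],
    alphas_cumprod: List[float],
    dim: int,
    rng: random.Random,
) -> Tuple[Vector, int]:
    """Hypothetical sketch of backward simulation (no text conditioning).

    Starts from pure noise at step_list[0], runs the generator for a random
    number of steps k, and returns the noisy input the generator would see
    at step_list[k] during multi-step inference, plus that timestep.
    """
    k = rng.randrange(len(step_list))               # how many steps to simulate
    x = [rng.gauss(0.0, 1.0) for _ in range(dim)]   # pure noise at the first timestep
    for i in range(k):
        x0 = generator(x, step_list[i])             # generator predicts the clean sample
        a = alphas_cumprod[step_list[i + 1]]        # re-noise to the next timestep
        fresh = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # fresh (uncorrelated) noise
        x = [math.sqrt(a) * c + math.sqrt(1.0 - a) * n for c, n in zip(x0, fresh)]
    return x, step_list[k]

# Usage with a toy generator that always predicts zeros.
toy_alphas_cumprod = [1.0 - (t + 1) / 1001 for t in range(1000)]
x, t = backward_simulate(lambda x, t: [0.0] * len(x), [999, 749, 499, 249],
                         toy_alphas_cumprod, dim=3, rng=random.Random(0))
```

In the real training loop the simulated steps would presumably run without gradients, and only the final generator call at the returned timestep would be trained.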
I can help review your changes if you are interested in opening a pull request.
I updated `text_embedding` and `pooled_text_embedding` using `text_encoder` in the `prepare_denoising_data` function (below). Is this the right way to enable multi-step training for SD v1.5?
Also, is `denoising_timestep` 250 in the SDXL case? Is there any reason to use 250 other than limiting training to 4 steps?
```python
def prepare_denoising_data(self, denoising_dict, real_train_dict, noise):
    # assert self.sdxl, "Denoising is only supported for SDXL"

    # pick one of the denoising timesteps at random for each sample in the batch
    indices = torch.randint(
        0, self.num_denoising_step, (noise.shape[0],), device=noise.device, dtype=torch.long
    )
    timesteps = self.denoising_step_list.to(noise.device)[indices]

    # text_embedding, pooled_text_embedding = self.text_encoder(denoising_dict)
    if self.sdxl:
        text_embedding, pooled_text_embedding = self.text_encoder(denoising_dict)
    else:
        # SD v1.5: single CLIP text encoder, read hidden states and pooled output directly
        text_embedding_dict = self.text_encoder(denoising_dict["text_input_ids_one"].squeeze(1))
        text_embedding = text_embedding_dict["last_hidden_state"]
        pooled_text_embedding = text_embedding_dict["pooler_output"]

    if real_train_dict is not None:
        # SDXL-style conditioning (add_time_ids / text_embeds)
        real_text_embedding, real_pooled_text_embedding = self.text_encoder(real_train_dict)
        real_train_dict['text_embedding'] = real_text_embedding
        real_unet_added_conditions = {
            "time_ids": self.add_time_ids.repeat(len(real_text_embedding), 1),
            "text_embeds": real_pooled_text_embedding
        }
        real_train_dict['unet_added_conditions'] = real_unet_added_conditions

    if self.backward_simulation:
        # we overwrite the denoising timesteps
        # note: we also use uncorrelated noise
        clean_images, timesteps = self.sample_backward(torch.randn_like(noise), text_embedding, pooled_text_embedding)
    else:
        clean_images = denoising_dict['images'].to(noise.device)

    noisy_image = self.noise_scheduler.add_noise(
        clean_images, noise, timesteps
    )

    # set last timestep to pure noise
    pure_noise_mask = (timesteps == (self.num_train_timesteps - 1))
    noisy_image[pure_noise_mask] = noise[pure_noise_mask]

    return timesteps, text_embedding, pooled_text_embedding, real_train_dict, noisy_image
```
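The `add_noise` call and the `pure_noise_mask` lines in `prepare_denoising_data` follow the standard DDPM forward process. Here is a pure-Python sketch of those two steps, assuming the SD-style `scaled_linear` beta schedule (beta_start=0.00085, beta_end=0.012, 1000 steps). The helper names are hypothetical; this is not the repo's `noise_scheduler` object.

```python
import math
from typing import List

def scaled_linear_alphas_cumprod(
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,
    beta_end: float = 0.012,
) -> List[float]:
    """Cumulative product of (1 - beta_t) for an SD-style "scaled_linear" schedule."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    alphas_cumprod, prod = [], 1.0
    for i in range(num_train_timesteps):
        beta = (lo + (hi - lo) * i / (num_train_timesteps - 1)) ** 2
        prod *= 1.0 - beta
        alphas_cumprod.append(prod)
    return alphas_cumprod

def add_noise(clean, noise, t, alphas_cumprod):
    """Forward process: sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps."""
    a = alphas_cumprod[t]
    return [math.sqrt(a) * x + math.sqrt(1.0 - a) * n for x, n in zip(clean, noise)]

def make_noisy_input(clean, noise, t, alphas_cumprod):
    """Mirrors the pure_noise_mask step: the last timestep gets exactly the noise."""
    if t == len(alphas_cumprod) - 1:
        return list(noise)
    return add_noise(clean, noise, t, alphas_cumprod)

alphas_cumprod = scaled_linear_alphas_cumprod()
print(math.sqrt(alphas_cumprod[-1]))  # weight left on the clean image at t = 999
```

With this schedule, `alphas_cumprod[999]` is small but nonzero, so plain `add_noise` at t = 999 would still leak a little of the clean image. Overwriting that entry with pure noise presumably keeps the training input consistent with inference, which starts from pure Gaussian noise.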
Is it possible to train a 4-step (multi-step) model for SD v1.5 in this repo? I see it in the SDXL experiment docs, but not in the SD v1.5 experiment docs.
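On the step count: if `denoising_step_list` is an even stride over a 1000-step training schedule starting from the noisiest timestep (an assumption; check how the repo actually builds it), then the 250 asked about in this thread is simply 1000 // 4, the stride that yields 4 timesteps. A minimal sketch with a hypothetical helper:

```python
from typing import List

def make_denoising_step_list(num_train_timesteps: int = 1000,
                             num_denoising_step: int = 4) -> List[int]:
    """Hypothetical helper: evenly spaced timesteps, noisiest first."""
    stride = num_train_timesteps // num_denoising_step
    return list(range(num_train_timesteps - 1, 0, -stride))

print(make_denoising_step_list(1000, 4))  # [999, 749, 499, 249]
print(make_denoising_step_list(1000, 2))  # [999, 499]
```

Under this assumption, changing `num_denoising_step` changes the stride (and therefore the timesteps) rather than requiring a separate constant.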