google / prompt-to-prompt

Apache License 2.0
3.07k stars 285 forks source link

DDS does not work with Stable Diffusion 3 #92

Closed carlpersson1 closed 1 month ago

carlpersson1 commented 2 months ago

Hello and thank you for the great work!

I have been trying to use DDS with Stable Diffusion 3, but applying the method directly does not yield satisfactory results. Since the latent space of Stable Diffusion 3 differs from that of versions 2.1 and 1.5, I adjusted the learning rate. I have tried many learning-rate values, none of which produce good results:

Here's an image produced by using the scheduler for Stable Diffusion 3, with a learning rate of 0.005 (target prompt is "a photo of a tiger" and the source prompt is "a photo of a cat"):

005

and lr = 0.0075:

0075

Both results are a far cry from what you have showcased using DDS. I also tried the Stable Diffusion 2 scheduler instead of the recommended Stable Diffusion 3 scheduler, which yields similar results (learning rates 0.01 and 0.015, respectively):

001

015

While testing, I found a configuration that works surprisingly well with the Stable Diffusion 3 scheduler. By setting the t_min and t_max variables to 20 and 400, adding weighting (in accordance with the scheduler), and running 250 iterations, I was able to generate this image:

tiger_success

However, when the algorithm is run for longer (e.g. ~500 iterations), the image becomes overly saturated. This is a particular problem when trying to change the season of an image, e.g. from summer to winter, where longer training is needed for larger changes:

oversat_tiger

It would be great if you could share any insights into why this might be the case and how it could be solved, such that DDS can be adapted for use with Stable Diffusion 3 (medium)!

I include the modified code that I used to create the images (configured for the successful tiger modification):


def init_pipe(device, dtype, unet, scheduler) -> Tuple[object, object, T, T]:
    """Freeze the denoiser and precompute the noise coefficients of ``scheduler``.

    Args:
        device: Device the coefficient tensors are moved to.
        dtype: dtype of the coefficient tensors.
        unet: Denoising network (in this script, the SD3 transformer). All of its
            parameters are set to ``requires_grad=False``.
        scheduler: Scheduler exposing ``alphas_cumprod``.

    Returns:
        ``(unet, scheduler, alphas, sigmas)`` where
        ``alphas = sqrt(alphas_cumprod)`` and ``sigmas = sqrt(1 - alphas_cumprod)``.
        The two coefficient tensors are created under ``torch.inference_mode`` and
        are therefore inference tensors.
    """
    with torch.inference_mode():
        alphas = torch.sqrt(scheduler.alphas_cumprod).to(device, dtype=dtype)
        sigmas = torch.sqrt(1 - scheduler.alphas_cumprod).to(device, dtype=dtype)
    # Only the latents are optimized; the network stays frozen.
    unet.requires_grad_(False)
    return unet, scheduler, alphas, sigmas

class DDSLoss:
    """Delta Denoising Score (DDS) / SDS losses driven by a Stable Diffusion 3 transformer.

    Latents are noised with the variance-preserving coefficients of the
    Stable Diffusion 2.1 DDIM schedule (``z_t = alpha_t * z + sigma_t * eps``),
    while the DDS gradient is weighted by the step size of
    ``pipe.scheduler.timesteps`` at the sampled noise level.

    NOTE(review): SD3 is presumably trained with a rectified-flow interpolation
    (``z_t = (1 - s) * z + s * eps``, see the scheduler's ``scale_noise``) and its
    transformer presumably predicts a velocity rather than noise. Feeding it
    VP-noised latents and treating its output as ``eps`` is a plausible cause of
    the saturation observed on long runs -- confirm before relying on this setup.
    """

    def __init__(self, device, pipe: StableDiffusion3DiffDiffPipeline, dtype=torch.float32):
        self.t_min = 20
        self.t_max = 400
        self.alpha_exp = 0
        self.sigma_exp = 0
        self.dtype = dtype
        self.device = device
        # VP noise coefficients come from the SD 2.1 DDIM schedule ...
        vp_scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="scheduler")
        self.unet, _, self.alphas, self.sigmas = init_pipe(device, dtype, pipe.transformer, vp_scheduler)
        # ... while the gradient weighting uses the pipeline's own scheduler.
        self.scheduler = pipe.scheduler
        self.timesteps = self.scheduler.timesteps.to(device)
        # Step-size table indexed by the integer noise level k. With the timesteps in
        # ascending order and a leading 0 prepended, entry k equals
        # ``timesteps[-(k + 1)] - timesteps[-k]`` for k >= 1 and stays well defined at
        # k == 0 (where negative indexing would wrap around).
        # Assumes len(self.timesteps) >= t_max, i.e. set_timesteps has not shortened it.
        ascending = self.timesteps.flip(0)
        self._timestep_deltas = torch.cat([ascending.new_zeros(1), ascending]).diff()
        self.prediction_type = "eps"

    def noise_input(self, z, eps=None, timestep: Optional[int] = None):
        """Noise ``z`` to a (random or given) integer timestep of the VP schedule.

        Returns:
            ``(z_t, eps, timestep, timestep_step)``. ``timestep`` is a ``(b,)`` int
            tensor. ``timestep_step`` is the same tensor and is used to index the
            scheduler step-size table.
        """
        b = z.shape[0]
        if timestep is None:
            timestep = torch.randint(
                low=self.t_min,
                high=min(self.t_max, self.alphas.shape[0]) - 1,  # Avoid the highest timestep.
                size=(b,),
                device=z.device, dtype=torch.int)
        elif not torch.is_tensor(timestep):
            # Accept a plain int as the signature advertises.
            timestep = torch.full((b,), int(timestep), device=z.device, dtype=torch.int)
        if eps is None:
            eps = randn_tensor(z.shape, device=z.device)
        alpha_t = self.alphas[timestep, None, None, None]
        sigma_t = self.sigmas[timestep, None, None, None]
        z_t = alpha_t * z + sigma_t * eps
        # Previously this stayed None, which crashed the weighting in get_dds_loss.
        timestep_step = timestep
        return z_t, eps, timestep, timestep_step

    def get_eps_prediction(self, z_t: T, timestep: T, text_embeddings: T, text_pool_embeddings: T, guidance_scale=7.5):
        """Run the transformer on ``[uncond, cond]`` copies of ``z_t`` and apply CFG.

        ``text_embeddings`` / ``text_pool_embeddings`` must stack the unconditional
        embeddings first along dim 0, matching the ``chunk(2)`` split below.
        The output is whatever the transformer predicts; it is treated as noise here.
        """
        latent_input = torch.cat([z_t] * 2)
        timestep = torch.cat([timestep] * 2)
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            e_t = self.unet(
                    hidden_states=latent_input,
                    timestep=timestep,
                    encoder_hidden_states=text_embeddings,
                    pooled_projections=text_pool_embeddings,
                    return_dict=False,
                )[0]
            e_t_uncond, e_t = e_t.chunk(2)
            e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond)
        return e_t

    def get_sds_loss(self, z: T, text_embeddings: T, eps: TN = None, mask=None, t=None,
                     timestep: Optional[int] = None, guidance_scale=7.5,
                     text_pool_embeddings: TN = None) -> TS:
        """Score Distillation Sampling surrogate loss for ``z``.

        ``mask`` and ``t`` are accepted for interface compatibility and are unused.
        ``text_pool_embeddings`` is required by the SD3 transformer.

        Returns:
            ``(loss, log_loss)`` where ``loss`` back-propagates ``e_t - eps`` into ``z``.

        Raises:
            ValueError: if ``text_pool_embeddings`` is not provided.
        """
        if text_pool_embeddings is None:
            raise ValueError("get_sds_loss requires text_pool_embeddings for the SD3 transformer")
        with torch.inference_mode():
            z_t, eps, timestep, _ = self.noise_input(z, eps=eps, timestep=timestep)
            e_t = self.get_eps_prediction(z_t, timestep, text_embeddings, text_pool_embeddings,
                                          guidance_scale=guidance_scale)
            grad_z = e_t - eps
            log_loss = (grad_z ** 2).mean()
        # clone() turns the inference tensor into a normal one usable by autograd.
        sds_loss = grad_z.clone() * z
        del grad_z
        return sds_loss.sum() / (z.shape[2] * z.shape[3]), log_loss

    def get_dds_loss(self, z_source: T, z_target: T, text_emb_source: T, text_emb_pool_source: T,
                     text_emb_target: T, text_emb_pool_target: T,
                     eps=None, reduction='mean', symmetric: bool = False, calibration_grad=None,
                     timestep: Optional[int] = None, guidance_scale=7.5, raw_log=False) -> TS:
        """DDS loss: push ``z_target`` along ``pred_target - pred_source``.

        Both latents share the same noise and timestep. ``reduction``, ``symmetric``,
        ``calibration_grad`` and ``raw_log`` are accepted for interface compatibility
        and are unused.

        Returns:
            ``(loss, log_loss)`` where ``loss.backward()`` deposits the weighted
            prediction difference as the gradient of ``z_target``.
        """
        with torch.inference_mode():
            z_t_source, eps, timestep, step = self.noise_input(z_source, eps, timestep)
            z_t_target, _, _, _ = self.noise_input(z_target, eps, timestep)
            eps_pred_source = self.get_eps_prediction(z_t_source,
                                                      timestep,
                                                      text_emb_source,
                                                      text_emb_pool_source,
                                                      guidance_scale=guidance_scale)
            eps_pred_target = self.get_eps_prediction(z_t_target,
                                                      timestep,
                                                      text_emb_target,
                                                      text_emb_pool_target,
                                                      guidance_scale=guidance_scale)
            # Per-sample scheduler step size, shaped (b, 1, 1, 1) so it broadcasts over
            # the batch axis rather than the width axis.
            weight = self._timestep_deltas[step.long()].view(-1, 1, 1, 1)
            grad = (eps_pred_target - eps_pred_source) * weight
            log_loss = (grad ** 2).mean()
        loss = z_target * grad.clone().detach()
        loss = loss.sum()
        return loss, log_loss

def _encode_cfg_prompt(pipe, prompt, prompt_2, prompt_3, device):
    """Encode a prompt with ``pipe`` for classifier-free guidance.

    Returns ``(prompt_embeds, pooled_prompt_embeds)`` with the negative
    (unconditional) embeddings stacked first along dim 0.
    """
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        do_classifier_free_guidance=True,
        device=device,
        clip_skip=None,
    )
    return (
        torch.cat([negative_prompt_embeds, prompt_embeds], dim=0),
        torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0),
    )


def DDS(dds_loss, pipe, src_image, num_iters, src_prompt, src_prompt_2, src_prompt_3, tar_prompt, tar_prompt_2, tar_prompt_3, device="cuda", guidance_scale=7.5, lr=0.005):
    """Edit ``src_image`` from the source prompt towards the target prompt with DDS.

    The image is encoded into the VAE latent space. A copy of the latents is
    optimized with SGD for ``num_iters`` steps using ``dds_loss.get_dds_loss``,
    and the result is decoded back to an image.

    Args:
        dds_loss: A ``DDSLoss`` instance.
        pipe: SD3 pipeline providing ``encode_prompt``, ``vae`` and ``image_processor``.
        src_image: Source image accepted by ``pipe.image_processor.preprocess``.
        num_iters: Number of optimization steps.
        src_prompt, src_prompt_2, src_prompt_3: Source prompts for the three text encoders.
        tar_prompt, tar_prompt_2, tar_prompt_3: Target prompts for the three text encoders.
        device: Device used for prompt encoding.
        guidance_scale: Classifier-free guidance scale.
        lr: SGD learning rate. It absorbs the constant loss factors, so it must be
            tuned per model.

    Returns:
        The edited image as produced by ``postprocess(..., output_type="np")``.
    """
    with torch.no_grad():
        src_prompt_embeds, src_pooled_prompt_embeds = _encode_cfg_prompt(
            pipe, src_prompt, src_prompt_2, src_prompt_3, device)
        tar_prompt_embeds, tar_pooled_prompt_embeds = _encode_cfg_prompt(
            pipe, tar_prompt, tar_prompt_2, tar_prompt_3, device)
        image = pipe.image_processor.preprocess(src_image)
        latents_src = pipe.vae.encode(image).latent_dist.sample()
        latents_src = (latents_src - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
    latents_tar = latents_src.clone()
    latents_tar.requires_grad = True

    optimizer = SGD(params=[latents_tar], lr=lr)

    for _ in tqdm(range(num_iters), "Training DDS...", leave=False):
        loss, _log_loss = dds_loss.get_dds_loss(latents_src, latents_tar, src_prompt_embeds, src_pooled_prompt_embeds,
                                                tar_prompt_embeds, tar_pooled_prompt_embeds, guidance_scale=guidance_scale)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # latents_tar requires grad; decode without recording an autograd graph
    # through the VAE decoder.
    with torch.no_grad():
        latents = (latents_tar / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
        image = pipe.vae.decode(latents, return_dict=False)[0]
    image = pipe.image_processor.postprocess(image.cpu().detach(), output_type="np")
    return image

if __name__ == "__main__":
    inpaint_pipe = StableDiffusion3DiffDiffPipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16).to("cuda")
    # Build the loss once; constructing it fetches a scheduler config and freezes the transformer.
    dds_loss = DDSLoss("cuda", inpaint_pipe)
    image_src = torch.from_numpy(load_512("../prompt-to-prompt/example_images/gnochi_mirror.jpeg").transpose(2, 0, 1)).to(device="cuda", dtype=torch.float16) / 255

    image = DDS(dds_loss, inpaint_pipe, image_src, 250, src_prompt="A photo of a cat.", src_prompt_2=None, src_prompt_3=None,
                tar_prompt="A photo of a tiger.", tar_prompt_2=None, tar_prompt_3=None, guidance_scale=7.5)

    # Top: source image; bottom: DDS-edited result.
    _, axarr = plt.subplots(2, 1)
    axarr[0].imshow(image_src.cpu().detach().numpy().transpose(1, 2, 0).astype(np.float32))
    axarr[1].imshow(image[0].astype(np.float32))
    plt.show()