garibida / ReNoise-Inversion

Official Implementation for "ReNoise: Real Image Inversion Through Iterative Noising"
https://garibida.github.io/ReNoise-Inversion/

Inversion with guidance scale > 1.0 #4

Open M4xim4l opened 6 months ago

M4xim4l commented 6 months ago

Hi,

I noticed that in the default run config, the guidance scale for reconstruction is set to 0.0, and the examples use a guidance scale of 1.0 for inference. I tried setting it to the usual value of 7 and got this error:

Traceback (most recent call last):
  File "/mnt/USER/ReNoise-Inversion-main/inversion_example_sd.py", line 30, in <module>
    _, inv_latent, _, all_latents = invert(input_image,
  File "/mnt/USER/ReNoise-Inversion-main/main.py", line 44, in run
    res = pipe_inversion(prompt = prompt,
  File "/mnt/USER/ReNoise-Inversion-main/src/pipes/sd_inversion_pipeline.py", line 153, in __call__
    latents = inversion_step(self,
  File "/mnt/USER/ReNoise-Inversion-main/src/renoise_inversion.py", line 146, in inversion_step
    noise_pred = noise_regularization(noise_pred, noise_pred_optimal, lambda_kl=pipe.cfg.noise_regularization_lambda_kl, lambda_ac=pipe.cfg.noise_regularization_lambda_ac, num_reg_steps=pipe.cfg.noise_regularization_num_reg_steps, num_ac_rolls=pipe.cfg.noise_regularization_num_ac_rolls, generator=generator)
  File "/mnt/USER/ReNoise-Inversion-main/src/renoise_inversion.py", line 11, in noise_regularization
    l_kld = patchify_latents_kl_divergence(_var, noise_pred_optimal)
  File "/mnt/USER/ReNoise-Inversion-main/src/renoise_inversion.py", line 69, in patchify_latents_kl_divergence
    kl = latents_kl_divergence(x0, x1).sum()
  File "/mnt/USER/ReNoise-Inversion-main/src/renoise_inversion.py", line 82, in latents_kl_divergence
    torch.log((var1 + EPSILON) / (var0 + EPSILON))
RuntimeError: The size of tensor a (512) must match the size of tensor b (256) at non-singleton dimension 0

Is it possible to use ReNoise inversion with classifier-free guidance (guidance scale > 1)? And why do you use a guidance value of 0 during inversion but 1.0 during inference in the examples?

Thank you!
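
For context, the size mismatch in the traceback comes from the batch dimension after patchification: with classifier-free guidance the U-Net batch is doubled, so one of the two tensors handed to the KL term yields twice as many patches as the other. Below is a minimal sketch that reproduces the same class of error; the latent shape, patch size, and patchify helper are illustrative assumptions, not the repository's exact code.

import torch

def patchify(latents, patch_size=4):
    # split each latent map into non-overlapping patches and stack them along dim 0
    b, c, h, w = latents.shape
    patches = latents.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    return patches.permute(0, 2, 3, 1, 4, 5).reshape(-1, c, patch_size, patch_size)

noise_pred_optimal = patchify(torch.randn(2, 4, 64, 64))  # CFG-doubled batch -> 512 patches
noise_pred = patchify(torch.randn(1, 4, 64, 64))          # single batch      -> 256 patches

var0 = noise_pred.var(dim=(1, 2, 3))
var1 = noise_pred_optimal.var(dim=(1, 2, 3))
try:
    torch.log(var1 / var0)
except RuntimeError as e:
    print(e)  # size 512 must match size 256 at non-singleton dimension 0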

Georgefwt commented 6 months ago

I have encountered the same problem. Do the authors have any instructions for handling the case where guidance_scale is greater than 1?

johndpope commented 6 months ago

I pasted renoise_inversion.py and the above error into Claude, and it spit out this answer:


The error you encountered is likely due to a mismatch in tensor sizes when using a non-zero guidance scale during the inversion process. In the provided code, the guidance scale is set to 0 during inversion and 1.0 during inference for specific reasons. Using a guidance scale greater than 1 during inversion can be problematic because it involves splitting the batch and performing classifier-free guidance, which requires careful handling of tensor sizes throughout the pipeline to ensure consistency.

The reason for using a guidance scale of 0 during inversion is to focus solely on reconstructing the input image without any additional conditioning or guidance. The goal is to find the latent representation that closely matches the input image, and introducing guidance during this process may interfere with the reconstruction quality. During inference, on the other hand, a guidance scale of 1.0 applies only a small amount of conditioning to the generated image, allowing some control over the output while maintaining a close resemblance to the original image.

To use a higher guidance scale (e.g., 7) during inversion, you would need to modify the code to handle tensor sizes correctly throughout the pipeline, ensuring that tensors are split and concatenated appropriately to match the expected dimensions at each step. Here are a few suggestions to address the issue:

1. Modify the unet_pass function to handle the case when pipe.do_classifier_free_guidance is True during inversion, ensuring that the latent model input and other tensors are split and concatenated correctly.
2. Update the inversion_step function to handle the splitting and concatenation of tensors when using classifier-free guidance during inversion, making sure that tensor sizes match at each step, especially when performing noise regularization.
3. Adjust the tensor sizes in the patchify_latents_kl_divergence and latents_kl_divergence functions to handle the case when the batch size is doubled due to classifier-free guidance.

It's important to note that using a high guidance scale during inversion may impact the quality of the reconstructed image, as it introduces additional conditioning that may deviate from the original input. Experimentation and fine-tuning may be necessary to find the right balance between reconstruction quality and guidance strength. If you choose to modify the code to support a higher guidance scale during inversion, be prepared to carefully debug and ensure that tensor sizes and shapes are consistent throughout the pipeline to avoid runtime errors.
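
In plain terms, the CFG-aware U-Net pass that suggestion 1 refers to is roughly the following. This is a sketch using diffusers-style argument names; the repository's actual unet_pass signature may differ.

import torch

def cfg_unet_pass(unet, latents, t, prompt_embeds, do_classifier_free_guidance):
    # duplicate the latents so the [uncond, cond] halves of prompt_embeds see the same input
    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
    noise_pred = unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
    return noise_pred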

UPDATE: I then asked it to provide code.

def noise_regularization(
    e_t, noise_pred_optimal, lambda_kl, lambda_ac, num_reg_steps, num_ac_rolls, generator=None, do_classifier_free_guidance=False
):
    for _outer in range(num_reg_steps):
        if lambda_kl > 0:
            _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
            if do_classifier_free_guidance:
                _var = torch.cat([_var] * 2)
                noise_pred_optimal = torch.cat([noise_pred_optimal] * 2)
            l_kld = patchify_latents_kl_divergence(_var, noise_pred_optimal)
            l_kld.backward()
            _grad = _var.grad.detach()
            _grad = torch.clip(_grad, -100, 100)
            if do_classifier_free_guidance:
                _grad = _grad.chunk(2)[0]
            e_t = e_t - lambda_kl * _grad
        if lambda_ac > 0:
            for _inner in range(num_ac_rolls):
                _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
                if do_classifier_free_guidance:
                    _var = torch.cat([_var] * 2)
                l_ac = auto_corr_loss(_var, generator=generator)
                l_ac.backward()
                _grad = _var.grad.detach() / num_ac_rolls
                if do_classifier_free_guidance:
                    _grad = _grad.chunk(2)[0]
                e_t = e_t - lambda_ac * _grad
        e_t = e_t.detach()

    return e_t

def inversion_step(
    pipe,
    z_t: torch.tensor,
    t: torch.tensor,
    prompt_embeds,
    added_cond_kwargs,
    num_renoise_steps: int = 100,
    first_step_max_timestep: int = 250,
    generator=None,
) -> torch.tensor:
    extra_step_kwargs = {}
    avg_range = pipe.cfg.average_first_step_range if t.item() < first_step_max_timestep else pipe.cfg.average_step_range
    num_renoise_steps = min(pipe.cfg.max_num_renoise_steps_first_step, num_renoise_steps) if t.item() < first_step_max_timestep else num_renoise_steps

    nosie_pred_avg = None
    noise_pred_optimal = None
    z_tp1_forward = pipe.scheduler.add_noise(pipe.z_0, pipe.noise, t.view((1))).detach()

    approximated_z_tp1 = z_t.clone()
    for i in range(num_renoise_steps + 1):

        with torch.no_grad():
            # if noise regularization is enabled, we need to double the batch size for the first step
            if pipe.cfg.noise_regularization_num_reg_steps > 0 and i == 0:
                approximated_z_tp1 = torch.cat([z_tp1_forward, approximated_z_tp1])
                prompt_embeds_in = torch.cat([prompt_embeds, prompt_embeds])
                if added_cond_kwargs is not None:
                    added_cond_kwargs_in = {}
                    added_cond_kwargs_in['text_embeds'] = torch.cat([added_cond_kwargs['text_embeds'], added_cond_kwargs['text_embeds']])
                    added_cond_kwargs_in['time_ids'] = torch.cat([added_cond_kwargs['time_ids'], added_cond_kwargs['time_ids']])
                else:
                    added_cond_kwargs_in = None
            else:
                prompt_embeds_in = prompt_embeds
                added_cond_kwargs_in = added_cond_kwargs

            noise_pred = unet_pass(pipe, approximated_z_tp1, t, prompt_embeds_in, added_cond_kwargs_in)

            # if noise regularization is enabled, we need to split the batch size for the first step
            if pipe.cfg.noise_regularization_num_reg_steps > 0 and i == 0:
                noise_pred_optimal, noise_pred = noise_pred.chunk(2)
                noise_pred_optimal = noise_pred_optimal.detach()

            # perform guidance
            if pipe.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + pipe.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Calculate average noise
            if  i >= avg_range[0] and i < avg_range[1]:
                j = i - avg_range[0]
                if nosie_pred_avg is None:
                    nosie_pred_avg = noise_pred.clone()
                else:
                    nosie_pred_avg = j * nosie_pred_avg / (j + 1) + noise_pred / (j + 1)

        if i >= avg_range[0] or (not pipe.cfg.average_latent_estimations and i > 0):
            noise_pred = noise_regularization(noise_pred, noise_pred_optimal, lambda_kl=pipe.cfg.noise_regularization_lambda_kl, lambda_ac=pipe.cfg.noise_regularization_lambda_ac, num_reg_steps=pipe.cfg.noise_regularization_num_reg_steps, num_ac_rolls=pipe.cfg.noise_regularization_num_ac_rolls, generator=generator, do_classifier_free_guidance=pipe.do_classifier_free_guidance)

        approximated_z_tp1 = pipe.scheduler.inv_step(noise_pred, t, z_t, **extra_step_kwargs, return_dict=False)[0].detach()

    # if average latents is enabled, we need to perform an additional step with the average noise
    if pipe.cfg.average_latent_estimations and nosie_pred_avg is not None:
        nosie_pred_avg = noise_regularization(nosie_pred_avg, noise_pred_optimal, lambda_kl=pipe.cfg.noise_regularization_lambda_kl, lambda_ac=pipe.cfg.noise_regularization_lambda_ac, num_reg_steps=pipe.cfg.noise_regularization_num_reg_steps, num_ac_rolls=pipe.cfg.noise_regularization_num_ac_rolls, generator=generator, do_classifier_free_guidance=pipe.do_classifier_free_guidance)
        approximated_z_tp1 = pipe.scheduler.inv_step(nosie_pred_avg, t, z_t, **extra_step_kwargs, return_dict=False)[0].detach()

    # perform noise correction
    if pipe.cfg.perform_noise_correction:
        noise_pred = unet_pass(pipe, approximated_z_tp1, t, prompt_embeds, added_cond_kwargs)

        # perform guidance
        if pipe.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + pipe.guidance_scale * (noise_pred_text - noise_pred_uncond)

        pipe.scheduler.step_and_update_noise(noise_pred, t, approximated_z_tp1, z_t, return_dict=False, optimize_epsilon_type=pipe.cfg.perform_noise_correction)

    return approximated_z_tp1

Georgefwt commented 6 months ago

I tested the code generated by Claude and found that it cannot be used directly. I had to make the following modifications to the noise_regularization function:

def noise_regularization(
    e_t, noise_pred_optimal, lambda_kl, lambda_ac, num_reg_steps, num_ac_rolls, generator=None, do_classifier_free_guidance=False
):
    for _outer in range(num_reg_steps):
        if lambda_kl > 0:
            _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
            if do_classifier_free_guidance:
                # duplicate only the optimization variable so its patch count matches
                # the CFG-doubled noise_pred_optimal; noise_pred_optimal itself is untouched
                __var = torch.cat([_var] * 2)
            else:
                __var = _var
            l_kld = patchify_latents_kl_divergence(__var, noise_pred_optimal)
            l_kld.backward()
            _grad = _var.grad.detach()
            _grad = torch.clip(_grad, -100, 100)
            if do_classifier_free_guidance:
                _grad = _grad.chunk(2)[0]
            e_t = e_t - lambda_kl * _grad
        if lambda_ac > 0:
            for _inner in range(num_ac_rolls):
                _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
                l_ac = auto_corr_loss(_var, generator=generator)
                l_ac.backward()
                _grad = _var.grad.detach() / num_ac_rolls
                if do_classifier_free_guidance:
                    _grad = _grad.chunk(2)[0]
                e_t = e_t - lambda_ac * _grad
        e_t = e_t.detach()

    return e_t

Now the inversion runs properly, but editability is still poor (I was unable to change the lion into a tiger).

garibida commented 6 months ago

Hi,

First, I pushed a commit that fixes the errors when using a CFG scale greater than 1.0. Second, regarding the question about using CFG = 0.0 in the config and CFG = 1.0 during inference: if the CFG value is less than or equal to 1.0, the diffusion pipeline does not perform CFG at all, so the two values behave identically.
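
In other words, diffusers-style pipelines gate classifier-free guidance on the scale being strictly greater than 1.0, which is why 0.0 and 1.0 behave the same during inversion. A minimal sketch of that gate; the function and variable names here are illustrative, not the repository's exact code.

import torch

def apply_cfg(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    do_classifier_free_guidance = guidance_scale > 1.0
    if not do_classifier_free_guidance:
        # guidance_scale <= 1.0: the batch was never doubled, so there is nothing to combine
        return noise_pred
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)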

Georgefwt commented 6 months ago

Sorry, but I tested the updated code with guidance_scale: float = 4.0 set in src/config.py, and this is the reconstruction result I got. Is there something I did wrong?